gitcrawl/internal/github/client.go
2026-05-05 03:26:39 +01:00

390 lines
11 KiB
Go

package github
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
type Reporter func(message string)
type Client struct {
httpClient *http.Client
baseURL string
token string
userAgent string
pageDelay time.Duration
}
type Options struct {
Token string
BaseURL string
UserAgent string
HTTPClient *http.Client
PageDelay time.Duration
}
type ListIssuesOptions struct {
State string
Since string
Limit int
ExpectedTotal int
}
type ListWorkflowRunsOptions struct {
Branch string
HeadSHA string
Limit int
}
type RequestError struct {
Method string
URL string
Status int
Body string
Headers http.Header
}
func (e *RequestError) Error() string {
if e.Body == "" {
return fmt.Sprintf("github %s %s failed with status %d", e.Method, e.URL, e.Status)
}
return fmt.Sprintf("github %s %s failed with status %d: %s", e.Method, e.URL, e.Status, e.Body)
}
func New(options Options) *Client {
baseURL := strings.TrimRight(options.BaseURL, "/")
if baseURL == "" {
baseURL = "https://api.github.com"
}
httpClient := options.HTTPClient
if httpClient == nil {
httpClient = &http.Client{Timeout: 30 * time.Second}
}
userAgent := options.UserAgent
if userAgent == "" {
userAgent = "gitcrawl"
}
return &Client{
httpClient: httpClient,
baseURL: baseURL,
token: options.Token,
userAgent: userAgent,
pageDelay: options.PageDelay,
}
}
func (c *Client) GetRepo(ctx context.Context, owner, repo string, reporter Reporter) (map[string]any, error) {
var out map[string]any
if err := c.doJSON(ctx, http.MethodGet, fmt.Sprintf("/repos/%s/%s", pathEscape(owner), pathEscape(repo)), nil, reporter, &out); err != nil {
return nil, err
}
return out, nil
}
func (c *Client) GetIssue(ctx context.Context, owner, repo string, number int, reporter Reporter) (map[string]any, error) {
var out map[string]any
path := fmt.Sprintf("/repos/%s/%s/issues/%d", pathEscape(owner), pathEscape(repo), number)
if err := c.doJSON(ctx, http.MethodGet, path, nil, reporter, &out); err != nil {
return nil, err
}
return out, nil
}
func (c *Client) GetPull(ctx context.Context, owner, repo string, number int, reporter Reporter) (map[string]any, error) {
var out map[string]any
path := fmt.Sprintf("/repos/%s/%s/pulls/%d", pathEscape(owner), pathEscape(repo), number)
if err := c.doJSON(ctx, http.MethodGet, path, nil, reporter, &out); err != nil {
return nil, err
}
return out, nil
}
func (c *Client) ListRepositoryIssues(ctx context.Context, owner, repo string, options ListIssuesOptions, reporter Reporter) ([]map[string]any, error) {
values := url.Values{}
state := strings.TrimSpace(options.State)
if state == "" {
state = "open"
}
values.Set("state", state)
values.Set("sort", "updated")
values.Set("direction", "desc")
values.Set("per_page", "100")
if options.Since != "" {
values.Set("since", options.Since)
}
path := fmt.Sprintf("/repos/%s/%s/issues?%s", pathEscape(owner), pathEscape(repo), values.Encode())
return c.paginate(ctx, path, options.Limit, options.ExpectedTotal, reporter)
}
func (c *Client) ListIssueComments(ctx context.Context, owner, repo string, number int, reporter Reporter) ([]map[string]any, error) {
path := fmt.Sprintf("/repos/%s/%s/issues/%d/comments?per_page=100", pathEscape(owner), pathEscape(repo), number)
return c.paginate(ctx, path, 0, 0, reporter)
}
func (c *Client) ListPullReviews(ctx context.Context, owner, repo string, number int, reporter Reporter) ([]map[string]any, error) {
path := fmt.Sprintf("/repos/%s/%s/pulls/%d/reviews?per_page=100", pathEscape(owner), pathEscape(repo), number)
return c.paginate(ctx, path, 0, 0, reporter)
}
func (c *Client) ListPullReviewComments(ctx context.Context, owner, repo string, number int, reporter Reporter) ([]map[string]any, error) {
path := fmt.Sprintf("/repos/%s/%s/pulls/%d/comments?per_page=100", pathEscape(owner), pathEscape(repo), number)
return c.paginate(ctx, path, 0, 0, reporter)
}
func (c *Client) ListPullFiles(ctx context.Context, owner, repo string, number int, reporter Reporter) ([]map[string]any, error) {
path := fmt.Sprintf("/repos/%s/%s/pulls/%d/files?per_page=100", pathEscape(owner), pathEscape(repo), number)
return c.paginate(ctx, path, 0, 0, reporter)
}
func (c *Client) ListPullCommits(ctx context.Context, owner, repo string, number int, reporter Reporter) ([]map[string]any, error) {
path := fmt.Sprintf("/repos/%s/%s/pulls/%d/commits?per_page=100", pathEscape(owner), pathEscape(repo), number)
return c.paginate(ctx, path, 0, 0, reporter)
}
func (c *Client) ListCommitCheckRuns(ctx context.Context, owner, repo, ref string, reporter Reporter) ([]map[string]any, error) {
var payload struct {
CheckRuns []map[string]any `json:"check_runs"`
}
path := fmt.Sprintf("/repos/%s/%s/commits/%s/check-runs?per_page=100", pathEscape(owner), pathEscape(repo), pathEscape(ref))
if err := c.doJSON(ctx, http.MethodGet, path, nil, reporter, &payload); err != nil {
return nil, err
}
return payload.CheckRuns, nil
}
func (c *Client) ListWorkflowRuns(ctx context.Context, owner, repo string, options ListWorkflowRunsOptions, reporter Reporter) ([]map[string]any, error) {
values := url.Values{}
values.Set("per_page", "100")
if options.Branch != "" {
values.Set("branch", options.Branch)
}
if options.HeadSHA != "" {
values.Set("head_sha", options.HeadSHA)
}
path := fmt.Sprintf("/repos/%s/%s/actions/runs?%s", pathEscape(owner), pathEscape(repo), values.Encode())
var payload struct {
WorkflowRuns []map[string]any `json:"workflow_runs"`
}
if err := c.doJSON(ctx, http.MethodGet, path, nil, reporter, &payload); err != nil {
return nil, err
}
rows := payload.WorkflowRuns
if options.Limit > 0 && len(rows) > options.Limit {
rows = rows[:options.Limit]
}
return rows, nil
}
func (c *Client) paginate(ctx context.Context, firstPath string, limit int, expectedItems int, reporter Reporter) ([]map[string]any, error) {
var out []map[string]any
nextPath := firstPath
page := 0
totalPages := 0
if expectedItems > 0 {
totalPages = (expectedItems + 99) / 100
}
for nextPath != "" {
page++
var rows []map[string]any
resp, err := c.do(ctx, http.MethodGet, nextPath, nil, reporter)
if err != nil {
return nil, err
}
if err := json.NewDecoder(resp.Body).Decode(&rows); err != nil {
_ = resp.Body.Close()
return nil, fmt.Errorf("decode github page: %w", err)
}
_ = resp.Body.Close()
if limit > 0 && len(out)+len(rows) > limit {
rows = rows[:limit-len(out)]
}
out = append(out, rows...)
linkHeader := resp.Header.Get("Link")
if last := lastPage(linkHeader); last > totalPages {
totalPages = last
}
if totalPages > 0 && page > totalPages {
totalPages = page
}
if totalPages > 0 {
reporter.Printf("[github] page %d/%d fetched count=%d accumulated=%d", page, totalPages, len(rows), len(out))
} else {
reporter.Printf("[github] page %d fetched count=%d accumulated=%d", page, len(rows), len(out))
}
if limit > 0 && len(out) >= limit {
break
}
nextPath = nextPage(linkHeader)
if nextPath != "" && c.pageDelay > 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(c.pageDelay):
}
}
}
return out, nil
}
func (c *Client) doJSON(ctx context.Context, method, path string, body io.Reader, reporter Reporter, out any) error {
resp, err := c.do(ctx, method, path, body, reporter)
if err != nil {
return err
}
defer resp.Body.Close()
if err := json.NewDecoder(resp.Body).Decode(out); err != nil {
return fmt.Errorf("decode github response: %w", err)
}
return nil
}
func (c *Client) do(ctx context.Context, method, path string, body io.Reader, reporter Reporter) (*http.Response, error) {
resp, err := c.doOnce(ctx, method, path, body, reporter)
if err == nil {
return resp, nil
}
wait, ok := rateLimitWait(err)
if !ok {
return nil, err
}
reporter.Printf("[github] rate-limit retry wait=%s", wait.Round(time.Second))
timer := time.NewTimer(wait)
defer timer.Stop()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-timer.C:
}
return c.doOnce(ctx, method, path, body, reporter)
}
func (c *Client) doOnce(ctx context.Context, method, path string, body io.Reader, reporter Reporter) (*http.Response, error) {
fullURL := c.baseURL + path
req, err := http.NewRequestWithContext(ctx, method, fullURL, body)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github+json")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
req.Header.Set("User-Agent", c.userAgent)
if c.token != "" {
req.Header.Set("Authorization", "Bearer "+c.token)
}
reporter.Printf("[github] request %s %s", method, path)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("github request: %w", err)
}
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return resp, nil
}
defer resp.Body.Close()
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, &RequestError{
Method: method,
URL: path,
Status: resp.StatusCode,
Body: strings.TrimSpace(string(data)),
Headers: resp.Header,
}
}
func rateLimitWait(err error) (time.Duration, bool) {
reqErr, ok := err.(*RequestError)
if !ok {
return 0, false
}
if reqErr.Status != http.StatusForbidden && reqErr.Status != http.StatusTooManyRequests {
return 0, false
}
if v := strings.TrimSpace(reqErr.Headers.Get("Retry-After")); v != "" {
if secs, err := strconv.Atoi(v); err == nil && secs > 0 {
return time.Duration(secs) * time.Second, true
}
}
if reqErr.Headers.Get("X-RateLimit-Remaining") != "0" {
return 0, false
}
secs, err := strconv.ParseInt(strings.TrimSpace(reqErr.Headers.Get("X-RateLimit-Reset")), 10, 64)
if err != nil {
return 0, false
}
if wait := time.Until(time.Unix(secs, 0)); wait > 0 {
return wait, true
}
return time.Second, true
}
func nextPage(linkHeader string) string {
for _, part := range strings.Split(linkHeader, ",") {
sections := strings.Split(part, ";")
if len(sections) < 2 {
continue
}
if strings.TrimSpace(sections[1]) != `rel="next"` {
continue
}
rawURL := strings.Trim(strings.TrimSpace(sections[0]), "<>")
parsed, err := url.Parse(rawURL)
if err != nil {
return ""
}
if parsed.RawQuery == "" {
return parsed.Path
}
return parsed.Path + "?" + parsed.RawQuery
}
return ""
}
func lastPage(linkHeader string) int {
for _, part := range strings.Split(linkHeader, ",") {
sections := strings.Split(part, ";")
if len(sections) < 2 {
continue
}
if strings.TrimSpace(sections[1]) != `rel="last"` {
continue
}
rawURL := strings.Trim(strings.TrimSpace(sections[0]), "<>")
parsed, err := url.Parse(rawURL)
if err != nil {
return 0
}
page, _ := strconv.Atoi(parsed.Query().Get("page"))
return page
}
return 0
}
func pathEscape(value string) string {
return url.PathEscape(value)
}
func (r Reporter) Printf(format string, args ...any) {
if r != nil {
r(fmt.Sprintf(format, args...))
}
}
func intValue(value any) int {
switch typed := value.(type) {
case float64:
return int(typed)
case int:
return typed
case json.Number:
parsed, _ := strconv.Atoi(string(typed))
return parsed
default:
return 0
}
}