feat(github): retry on X-RateLimit-Remaining and Retry-After in client.do (#5)

Single chokepoint Client.do now inspects the response on 403/429 and,
when the headers say we're rate-limited (X-RateLimit-Remaining=0 or a
Retry-After value), sleeps until the reset and retries once. The sleep
honors ctx cancellation, so callers bound the wait with
context.WithTimeout if they want a deadline.

RequestError gains a Headers field so the retry decision can read the
underlying response without re-fetching.
This commit is contained in:
Dallin Romney 2026-05-01 14:32:52 +08:00 committed by GitHub
parent b1411e8760
commit 330f492666
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 148 additions and 7 deletions

View File

@ -37,10 +37,11 @@ type ListIssuesOptions struct {
}
type RequestError struct {
Method string
URL string
Status int
Body string
Method string
URL string
Status int
Body string
Headers http.Header
}
func (e *RequestError) Error() string {
@ -63,13 +64,12 @@ func New(options Options) *Client {
if userAgent == "" {
userAgent = "gitcrawl"
}
pageDelay := options.PageDelay
return &Client{
httpClient: httpClient,
baseURL: baseURL,
token: options.Token,
userAgent: userAgent,
pageDelay: pageDelay,
pageDelay: options.PageDelay,
}
}
@ -185,6 +185,26 @@ func (c *Client) doJSON(ctx context.Context, method, path string, body io.Reader
}
func (c *Client) do(ctx context.Context, method, path string, body io.Reader, reporter Reporter) (*http.Response, error) {
resp, err := c.doOnce(ctx, method, path, body, reporter)
if err == nil {
return resp, nil
}
wait, ok := rateLimitWait(err)
if !ok {
return nil, err
}
reporter.Printf("[github] rate-limit retry wait=%s", wait.Round(time.Second))
timer := time.NewTimer(wait)
defer timer.Stop()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-timer.C:
}
return c.doOnce(ctx, method, path, body, reporter)
}
func (c *Client) doOnce(ctx context.Context, method, path string, body io.Reader, reporter Reporter) (*http.Response, error) {
fullURL := c.baseURL + path
req, err := http.NewRequestWithContext(ctx, method, fullURL, body)
if err != nil {
@ -206,7 +226,39 @@ func (c *Client) do(ctx context.Context, method, path string, body io.Reader, re
}
defer resp.Body.Close()
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return nil, &RequestError{Method: method, URL: path, Status: resp.StatusCode, Body: strings.TrimSpace(string(data))}
return nil, &RequestError{
Method: method,
URL: path,
Status: resp.StatusCode,
Body: strings.TrimSpace(string(data)),
Headers: resp.Header,
}
}
func rateLimitWait(err error) (time.Duration, bool) {
reqErr, ok := err.(*RequestError)
if !ok {
return 0, false
}
if reqErr.Status != http.StatusForbidden && reqErr.Status != http.StatusTooManyRequests {
return 0, false
}
if v := strings.TrimSpace(reqErr.Headers.Get("Retry-After")); v != "" {
if secs, err := strconv.Atoi(v); err == nil && secs > 0 {
return time.Duration(secs) * time.Second, true
}
}
if reqErr.Headers.Get("X-RateLimit-Remaining") != "0" {
return 0, false
}
secs, err := strconv.ParseInt(strings.TrimSpace(reqErr.Headers.Get("X-RateLimit-Reset")), 10, 64)
if err != nil {
return 0, false
}
if wait := time.Until(time.Unix(secs, 0)); wait > 0 {
return wait, true
}
return time.Second, true
}
func nextPage(linkHeader string) string {

View File

@ -3,10 +3,14 @@ package github
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
)
func TestListRepositoryIssuesPaginatesAndLimits(t *testing.T) {
@ -173,6 +177,91 @@ func TestClientErrorAndHelperBranches(t *testing.T) {
}
}
func TestRateLimitRetriesOn403WithRemainingZero(t *testing.T) {
var calls int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if atomic.AddInt32(&calls, 1) == 1 {
w.Header().Set("X-RateLimit-Remaining", "0")
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Unix()))
http.Error(w, "rate limited", http.StatusForbidden)
return
}
_ = json.NewEncoder(w).Encode(map[string]any{"id": 1})
}))
defer server.Close()
client := New(Options{BaseURL: server.URL, PageDelay: -1})
row, err := client.GetRepo(context.Background(), "openclaw", "gitcrawl", nil)
if err != nil {
t.Fatalf("get repo: %v", err)
}
if intValue(row["id"]) != 1 {
t.Fatalf("row = %#v", row)
}
if got := atomic.LoadInt32(&calls); got != 2 {
t.Fatalf("calls = %d want 2", got)
}
}
func TestRateLimitRetriesOn429WithRetryAfter(t *testing.T) {
var calls int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if atomic.AddInt32(&calls, 1) == 1 {
w.Header().Set("Retry-After", "1")
http.Error(w, "slow down", http.StatusTooManyRequests)
return
}
_ = json.NewEncoder(w).Encode(map[string]any{"id": 2})
}))
defer server.Close()
client := New(Options{BaseURL: server.URL, PageDelay: -1})
row, err := client.GetRepo(context.Background(), "openclaw", "gitcrawl", nil)
if err != nil {
t.Fatalf("get repo: %v", err)
}
if intValue(row["id"]) != 2 {
t.Fatalf("row = %#v", row)
}
}
func TestRateLimitRespectsContextCancellation(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-RateLimit-Remaining", "0")
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", time.Now().Add(time.Hour).Unix()))
http.Error(w, "rate limited", http.StatusForbidden)
}))
defer server.Close()
client := New(Options{BaseURL: server.URL, PageDelay: -1})
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, err := client.GetRepo(ctx, "openclaw", "gitcrawl", nil)
if err == nil {
t.Fatal("expected error")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("err = %v", err)
}
}
func TestNonRateLimit403IsNotRetried(t *testing.T) {
var calls int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&calls, 1)
http.Error(w, "forbidden", http.StatusForbidden)
}))
defer server.Close()
client := New(Options{BaseURL: server.URL, PageDelay: -1})
if _, err := client.GetRepo(context.Background(), "openclaw", "gitcrawl", nil); err == nil {
t.Fatal("expected error")
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("calls = %d want 1", got)
}
}
func serverURL(r *http.Request) string {
scheme := "http"
if r.TLS != nil {