gogcli/internal/googleapi/transport.go
2025-12-31 19:47:32 +01:00

239 lines
4.7 KiB
Go

package googleapi
import (
"context"
"fmt"
"io"
"log/slog"
"math/rand/v2"
"net/http"
"strconv"
"time"
)
// RetryTransport wraps an http.RoundTripper with retry logic for
// rate limits (429) and server errors (5xx).
type RetryTransport struct {
Base http.RoundTripper
MaxRetries429 int
MaxRetries5xx int
BaseDelay time.Duration
CircuitBreaker *CircuitBreaker
}
// NewRetryTransport creates a RetryTransport with sensible defaults.
func NewRetryTransport(base http.RoundTripper) *RetryTransport {
if base == nil {
base = http.DefaultTransport
}
return &RetryTransport{
Base: base,
MaxRetries429: MaxRateLimitRetries,
MaxRetries5xx: Max5xxRetries,
BaseDelay: RateLimitBaseDelay,
CircuitBreaker: NewCircuitBreaker(),
}
}
// RoundTrip implements http.RoundTripper with retry logic.
func (t *RetryTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.CircuitBreaker != nil && t.CircuitBreaker.IsOpen() {
return nil, &CircuitBreakerError{}
}
if err := ensureReplayableBody(req); err != nil {
return nil, err
}
var resp *http.Response
var err error
retries429 := 0
retries5xx := 0
for {
// Reset body for retry
if req.GetBody != nil {
if req.Body != nil {
_ = req.Body.Close()
}
if body, getErr := req.GetBody(); getErr != nil {
return nil, fmt.Errorf("reset request body: %w", getErr)
} else {
req.Body = body
}
}
resp, err = t.Base.RoundTrip(req)
if err != nil {
return nil, fmt.Errorf("round trip: %w", err)
}
// Success
if resp.StatusCode < 400 {
if t.CircuitBreaker != nil {
t.CircuitBreaker.RecordSuccess()
}
return resp, nil
}
// Rate limit (429)
if resp.StatusCode == http.StatusTooManyRequests {
if retries429 >= t.MaxRetries429 {
return resp, nil // Return the 429 response after max retries
}
delay := t.calculateBackoff(retries429, resp)
slog.Debug("rate limited, retrying",
"delay", delay,
"attempt", retries429+1,
"max_retries", t.MaxRetries429)
drainAndClose(resp.Body)
if err := t.sleep(req.Context(), delay); err != nil {
return nil, err
}
retries429++
continue
}
// Server error (5xx)
if resp.StatusCode >= 500 {
if t.CircuitBreaker != nil {
t.CircuitBreaker.RecordFailure()
}
if retries5xx >= t.MaxRetries5xx {
return resp, nil
}
slog.Debug("server error, retrying",
"status", resp.StatusCode,
"attempt", retries5xx+1)
drainAndClose(resp.Body)
if err := t.sleep(req.Context(), ServerErrorRetryDelay); err != nil {
return nil, err
}
retries5xx++
continue
}
// Other errors (4xx except 429): don't retry
return resp, nil
}
}
func (t *RetryTransport) calculateBackoff(attempt int, resp *http.Response) time.Duration {
// Check Retry-After header
if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" {
if seconds, err := strconv.Atoi(retryAfter); err == nil {
if seconds < 0 {
return 0
}
return time.Duration(seconds) * time.Second
}
if t, err := http.ParseTime(retryAfter); err == nil {
d := time.Until(t)
if d < 0 {
return 0
}
return d
}
}
// Exponential backoff with jitter: 1s, 2s, 4s...
if t.BaseDelay <= 0 {
return 0
}
var baseDelay time.Duration
if bd := t.BaseDelay * time.Duration(1<<attempt); bd <= 0 {
return 0
} else {
baseDelay = bd
}
jitterRange := baseDelay / 2
if jitterRange <= 0 {
return baseDelay
}
jitter := time.Duration(rand.Int64N(int64(jitterRange))) //nolint:gosec // non-crypto jitter
return baseDelay + jitter
}
func (t *RetryTransport) sleep(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-timer.C:
return nil
case <-ctx.Done():
return fmt.Errorf("sleep interrupted: %w", ctx.Err())
}
}
// bytesReader is a simple bytes.Reader replacement to avoid import
type bytesReader struct {
data []byte
pos int
}
func newBytesReader(data []byte) *bytesReader {
return &bytesReader{data: data}
}
func (r *bytesReader) Read(p []byte) (n int, err error) {
if r.pos >= len(r.data) {
return 0, io.EOF
}
n = copy(p, r.data[r.pos:])
r.pos += n
return n, nil
}
func ensureReplayableBody(req *http.Request) error {
if req == nil || req.Body == nil || req.GetBody != nil {
return nil
}
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return fmt.Errorf("read request body: %w", err)
}
_ = req.Body.Close()
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(newBytesReader(bodyBytes)), nil
}
req.Body = io.NopCloser(newBytesReader(bodyBytes))
return nil
}
func drainAndClose(body io.ReadCloser) {
if body == nil {
return
}
_, _ = io.Copy(io.Discard, io.LimitReader(body, 1<<20))
_ = body.Close()
}