diff --git a/cmd/eval/eval.go b/cmd/eval/eval.go index 5a6b39c2..902ca4ca 100644 --- a/cmd/eval/eval.go +++ b/cmd/eval/eval.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/MakeNowJust/heredoc" "github.com/github/gh-models/internal/azuremodels" @@ -80,6 +81,8 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command { By default, results are displayed in a human-readable format. Use the --json flag to output structured JSON data for programmatic use or integration with CI/CD pipelines. + This command will automatically retry on rate limiting errors, waiting for the specified + duration before retrying the request. See https://docs.github.com/github-models/use-github-models/storing-prompts-in-github-repositories#supported-file-format for more information. `), @@ -327,36 +330,65 @@ func (h *evalCommandHandler) templateString(templateStr string, data map[string] return prompt.TemplateString(templateStr, data) } -func (h *evalCommandHandler) callModel(ctx context.Context, messages []azuremodels.ChatMessage) (string, error) { - req := h.evalFile.BuildChatCompletionOptions(messages) - - resp, err := h.client.GetChatCompletionStream(ctx, req, h.org) - if err != nil { - return "", err - } +// callModelWithRetry makes an API call with automatic retry on rate limiting +func (h *evalCommandHandler) callModelWithRetry(ctx context.Context, req azuremodels.ChatCompletionOptions) (string, error) { + const maxRetries = 3 - // For non-streaming requests, we should get a single response - var content strings.Builder - for { - completion, err := resp.Reader.Read() + for attempt := 0; attempt <= maxRetries; attempt++ { + resp, err := h.client.GetChatCompletionStream(ctx, req, h.org) if err != nil { - if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") { - break + var rateLimitErr *azuremodels.RateLimitError + if errors.As(err, &rateLimitErr) { + if attempt < maxRetries { + if !h.jsonOutput { + h.cfg.WriteToOut(fmt.Sprintf(" Rate limited, waiting %v before retry (attempt %d/%d)...\n", + rateLimitErr.RetryAfter, attempt+1, maxRetries+1)) + } + + // Wait for the specified duration + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(rateLimitErr.RetryAfter): + continue + } + } + return "", fmt.Errorf("rate limit exceeded after %d attempts: %w", attempt+1, err) } + // For non-rate-limit errors, return immediately return "", err } - for _, choice := range completion.Choices { - if choice.Delta != nil && choice.Delta.Content != nil { - content.WriteString(*choice.Delta.Content) + var content strings.Builder + for { + completion, err := resp.Reader.Read() + if err != nil { + if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") { + break + } + return "", err } - if choice.Message != nil && choice.Message.Content != nil { - content.WriteString(*choice.Message.Content) + + for _, choice := range completion.Choices { + if choice.Delta != nil && choice.Delta.Content != nil { + content.WriteString(*choice.Delta.Content) + } + if choice.Message != nil && choice.Message.Content != nil { + content.WriteString(*choice.Message.Content) + } } } + + return strings.TrimSpace(content.String()), nil } - return strings.TrimSpace(content.String()), nil + // This should never be reached, but just in case + return "", errors.New("unexpected error calling model") +} + +func (h *evalCommandHandler) callModel(ctx context.Context, messages []azuremodels.ChatMessage) (string, error) { + req := h.evalFile.BuildChatCompletionOptions(messages) + return h.callModelWithRetry(ctx, req) } func (h *evalCommandHandler) runEvaluators(ctx context.Context, testCase map[string]interface{}, response string) ([]EvaluationResult, error) { @@ -437,7 +469,6 @@ func (h *evalCommandHandler) runStringEvaluator(name string, eval prompt.StringE } func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, eval prompt.LLMEvaluator, testCase map[string]interface{}, response string) (EvaluationResult, error) { - // Template the evaluation prompt evalData := make(map[string]interface{}) for k, v := range testCase { evalData[k] = v @@ -449,7 +480,6 @@ func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, e return EvaluationResult{}, fmt.Errorf("failed to template evaluation prompt: %w", err) } - // Prepare messages for evaluation var messages []azuremodels.ChatMessage if eval.SystemPrompt != "" { messages = append(messages, azuremodels.ChatMessage{ @@ -462,40 +492,19 @@ func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, e Content: util.Ptr(promptContent), }) - // Call the evaluation model req := azuremodels.ChatCompletionOptions{ Messages: messages, Model: eval.ModelID, Stream: false, } - resp, err := h.client.GetChatCompletionStream(ctx, req, h.org) + evalResponseText, err := h.callModelWithRetry(ctx, req) if err != nil { return EvaluationResult{}, fmt.Errorf("failed to call evaluation model: %w", err) } - var evalResponse strings.Builder - for { - completion, err := resp.Reader.Read() - if err != nil { - if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") { - break - } - return EvaluationResult{}, err - } - - for _, choice := range completion.Choices { - if choice.Delta != nil && choice.Delta.Content != nil { - evalResponse.WriteString(*choice.Delta.Content) - } - if choice.Message != nil && choice.Message.Content != nil { - evalResponse.WriteString(*choice.Message.Content) - } - } - } - // Match response to choices - evalResponseText := strings.TrimSpace(strings.ToLower(evalResponse.String())) + evalResponseText = strings.TrimSpace(strings.ToLower(evalResponseText)) for _, choice := range eval.Choices { if strings.Contains(evalResponseText, strings.ToLower(choice.Choice)) { return EvaluationResult{ diff --git a/internal/azuremodels/azure_client.go b/internal/azuremodels/azure_client.go index 76eb537d..3f8c0beb 100644 --- a/internal/azuremodels/azure_client.go +++ b/internal/azuremodels/azure_client.go @@ -10,7 +10,9 @@ import ( "io" "net/http" "slices" + "strconv" "strings" + "time" "github.com/cli/go-gh/v2/pkg/api" "github.com/github/gh-models/internal/modelkey" @@ -259,6 +261,42 @@ func (c *AzureClient) handleHTTPError(resp *http.Response) error { return err } + case http.StatusTooManyRequests: + // Handle rate limiting + retryAfter := time.Duration(0) + + // Check for x-ratelimit-timeremaining header (in seconds) + if timeRemainingStr := resp.Header.Get("x-ratelimit-timeremaining"); timeRemainingStr != "" { + if seconds, parseErr := strconv.Atoi(timeRemainingStr); parseErr == nil { + retryAfter = time.Duration(seconds) * time.Second + } + } + + // Fall back to standard Retry-After header if x-ratelimit-timeremaining is not available + if retryAfter == 0 { + if retryAfterStr := resp.Header.Get("Retry-After"); retryAfterStr != "" { + if seconds, parseErr := strconv.Atoi(retryAfterStr); parseErr == nil { + retryAfter = time.Duration(seconds) * time.Second + } + } + } + + // Default to 60 seconds if no retry-after information is provided + if retryAfter == 0 { + retryAfter = 60 * time.Second + } + + body, _ := io.ReadAll(resp.Body) + message := "rate limit exceeded" + if len(body) > 0 { + message = string(body) + } + + return &RateLimitError{ + RetryAfter: retryAfter, + Message: strings.TrimSpace(message), + } + default: _, err = sb.WriteString("unexpected response from the server: " + resp.Status) if err != nil { @@ -286,3 +324,13 @@ func (c *AzureClient) handleHTTPError(resp *http.Response) error { return errors.New(sb.String()) } + +// RateLimitError represents a rate limiting error from the API +type RateLimitError struct { + RetryAfter time.Duration + Message string +} + +func (e *RateLimitError) Error() string { + return fmt.Sprintf("rate limited: %s (retry after %v)", e.Message, e.RetryAfter) +} diff --git a/internal/azuremodels/rate_limit_test.go b/internal/azuremodels/rate_limit_test.go new file mode 100644 index 00000000..10792016 --- /dev/null +++ b/internal/azuremodels/rate_limit_test.go @@ -0,0 +1,109 @@ +package azuremodels + +import ( + "net/http" + "strings" + "testing" + "time" +) + +func TestRateLimitError(t *testing.T) { + err := &RateLimitError{ + RetryAfter: 30 * time.Second, + Message: "Too many requests", + } + + expected := "rate limited: Too many requests (retry after 30s)" + if err.Error() != expected { + t.Errorf("Expected error message %q, got %q", expected, err.Error()) + } +} + +func TestHandleHTTPError_RateLimit(t *testing.T) { + client := &AzureClient{} + + tests := []struct { + name string + statusCode int + headers map[string]string + expectedRetryAfter time.Duration + }{ + { + name: "Rate limit with x-ratelimit-timeremaining header", + statusCode: http.StatusTooManyRequests, + headers: map[string]string{ + "x-ratelimit-timeremaining": "45", + }, + expectedRetryAfter: 45 * time.Second, + }, + { + name: "Rate limit with Retry-After header", + statusCode: http.StatusTooManyRequests, + headers: map[string]string{ + "Retry-After": "60", + }, + expectedRetryAfter: 60 * time.Second, + }, + { + name: "Rate limit with both headers - x-ratelimit-timeremaining takes precedence", + statusCode: http.StatusTooManyRequests, + headers: map[string]string{ + "x-ratelimit-timeremaining": "30", + "Retry-After": "90", + }, + expectedRetryAfter: 30 * time.Second, + }, + { + name: "Rate limit with no headers - default to 60s", + statusCode: http.StatusTooManyRequests, + headers: map[string]string{}, + expectedRetryAfter: 60 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &http.Response{ + StatusCode: tt.statusCode, + Header: make(http.Header), + Body: &mockReadCloser{reader: strings.NewReader("rate limit exceeded")}, + } + + for key, value := range tt.headers { + resp.Header.Set(key, value) + } + + err := client.handleHTTPError(resp) + + var rateLimitErr *RateLimitError + if !isRateLimitError(err, &rateLimitErr) { + t.Fatalf("Expected RateLimitError, got %T: %v", err, err) + } + + if rateLimitErr.RetryAfter != tt.expectedRetryAfter { + t.Errorf("Expected RetryAfter %v, got %v", tt.expectedRetryAfter, rateLimitErr.RetryAfter) + } + }) + } +} + +// Helper function to check if error is a RateLimitError (mimics errors.As) +func isRateLimitError(err error, target **RateLimitError) bool { + if rateLimitErr, ok := err.(*RateLimitError); ok { + *target = rateLimitErr + return true + } + return false +} + +type mockReadCloser struct { + reader *strings.Reader +} + +func (m *mockReadCloser) Read(p []byte) (n int, err error) { + return m.reader.Read(p) +} + +func (m *mockReadCloser) Close() error { + return nil +}