package main

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"math/rand"
	"net/http"
	"strings"
	"time"
	"unicode/utf8"
)

// httpllm.go is the shared OpenAI-compatible Chat Completions transport: one
// HTTP+retry implementation reused by every provider adapter. Grok and Gemini both
// expose this wire format, so the retry/backoff classification (429/5xx/network =
// retryable, other 4xx = terminal) lives once here, parameterised by base/key/
// headers, instead of being copied per provider.

// openAIClient performs OpenAI-compatible /chat/completions calls with retry.
type openAIClient struct {
	name    string            // provider label for logs/errors ("xai", "gemini")
	base    string            // API base URL; "/chat/completions" is appended per call
	key     string            // bearer token sent in the Authorization header
	http    *http.Client      // no client-level timeout; each attempt sets its own deadline
	maxTry  int               // total attempts per complete call; values below 1 are treated as 1
	headers map[string]string // extra static headers (provider-specific), may be nil
	log     *slog.Logger      // optional; nil disables retry warnings
}

// newOpenAIClient builds a client for one provider with a fresh http.Client and a
// budget of three attempts per call. headers and logger may be nil.
func newOpenAIClient(name, base, key string, headers map[string]string, logger *slog.Logger) *openAIClient {
	return &openAIClient{
		name:    name,
		base:    base,
		key:     key,
		http:    &http.Client{},
		maxTry:  3,
		headers: headers,
		log:     logger,
	}
}

// --- OpenAI-compatible wire types -------------------------------------------------

// openAIMessage is one chat turn in the request body.
type openAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// openAITool is the wire shape of a model tool (e.g. web search). Only serialized
// when the request carries tools, so a plain completion's body is unchanged.
type openAITool struct {
	Type string `json:"type"`
}

// openAIRequest is the JSON body posted to /chat/completions.
type openAIRequest struct {
	Model       string          `json:"model"`
	Messages    []openAIMessage `json:"messages"`
	MaxTokens   int             `json:"max_tokens"`
	Temperature float64         `json:"temperature"`
	Stream      bool            `json:"stream"`
	// Optional; omitempty keeps the grok_direct body byte-identical to before.
	Tools           []openAITool `json:"tools,omitempty"`
	ReasoningEffort string       `json:"reasoning_effort,omitempty"`
	// SearchParameters drives xAI Live Search on chat/completions (the web route's
	// grok_web_search provider). nil for every non-web call, so it serializes away.
	SearchParameters any `json:"search_parameters,omitempty"`
}

// openAIUsage is the token accounting block of a response.
type openAIUsage struct {
	PromptTokens        int `json:"prompt_tokens"`
	CompletionTokens    int `json:"completion_tokens"`
	PromptTokensDetails struct {
		CachedTokens int `json:"cached_tokens"`
	} `json:"prompt_tokens_details"`
}

// openAIResponse is the subset of the /chat/completions response body decoded here.
type openAIResponse struct {
	ID      string `json:"id"`
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage openAIUsage `json:"usage"`
	// Citations is the source list xAI Live Search returns by default (absent on a
	// non-web call → nil).
	Citations []string `json:"citations"`
}

// Text returns the content of the first choice, or "" when the response is nil or
// carries no choices.
func (r *openAIResponse) Text() string {
	if r == nil || len(r.Choices) == 0 {
		return ""
	}
	return r.Choices[0].Message.Content
}

// openAIBackoff returns the base delay (before jitter) to wait ahead of the given
// retry: 0.5s for retry 1, doubling each time and capped at 8s. Retry numbers below
// 1 are treated as 1. The shift is bounded, so a large retry number cannot overflow.
func openAIBackoff(attempt int) time.Duration {
	const (
		base     = 500 * time.Millisecond
		ceiling  = 8 * time.Second
		maxShift = 4 // base<<4 == ceiling
	)
	if attempt < 1 {
		attempt = 1
	}
	if shift := attempt - 1; shift < maxShift {
		return base << shift
	}
	return ceiling
}

// complete calls Chat Completions with retry on transient failures (429 / 5xx /
// network timeout, exponential backoff + jitter). Non-retryable 4xx fail
// immediately. On exhaustion the caller refunds the reserved request and notifies
// the user, so a transient failure is never silently swallowed (F6). reqHeaders are
// per-request headers (e.g. x-grok-conv-id) merged on top of the static ones; nil is
// fine.
//
// It returns ctx.Err() as soon as ctx is done, the attempt's own error for a
// terminal failure, and an "exhausted N attempts" error wrapping the last failure
// once the attempt budget is spent.
func (c *openAIClient) complete(ctx context.Context, reqBody openAIRequest, reqHeaders map[string]string) (*openAIResponse, error) {
	payload, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("%s encode request: %w", c.name, err)
	}

	// A zero or negative budget would skip the loop and wrap a nil error below, so
	// always make at least one attempt.
	tries := c.maxTry
	if tries < 1 {
		tries = 1
	}

	var lastErr error
	for attempt := 0; attempt < tries; attempt++ {
		if attempt > 0 {
			// 0.5s, 1s, 2s … capped at 8s, plus up to 250ms jitter.
			wait := openAIBackoff(attempt) + time.Duration(rand.Intn(250))*time.Millisecond
			timer := time.NewTimer(wait)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, ctx.Err()
			case <-timer.C:
			}
		}

		resp, retryable, err := c.attempt(ctx, payload, reqHeaders)
		if err == nil {
			return resp, nil
		}
		lastErr = err
		if ctx.Err() != nil {
			return nil, ctx.Err()
		}
		if !retryable {
			return nil, err
		}
		// Only announce a retry when one will actually happen; the final failure is
		// reported through the returned error instead.
		if c.log != nil && attempt+1 < tries {
			c.log.Warn(c.name+" attempt failed, will retry", "attempt", attempt+1, "max", tries, "err", err)
		}
	}
	return nil, fmt.Errorf("%s: exhausted %d attempts: %w", c.name, tries, lastErr)
}

// attempt performs one HTTP call. It returns retryable=true for 429/5xx and
// network errors, false for other non-2xx (terminal 4xx). The per-attempt deadline
// bounds a single hung connection; the overall per-request deadline (set by the
// caller via ctx) bounds the whole retry loop so a cascade can't accrete minutes.
func (c *openAIClient) attempt(ctx context.Context, payload []byte, reqHeaders map[string]string) (*openAIResponse, bool, error) {
	attemptCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
	defer cancel()

	// Tolerate a base URL configured with a trailing slash so the path never
	// becomes "//chat/completions".
	endpoint := strings.TrimRight(c.base, "/") + "/chat/completions"
	req, err := http.NewRequestWithContext(attemptCtx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, false, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.key)
	// Static provider headers first, then per-request headers so the latter win.
	for k, v := range c.headers {
		req.Header.Set(k, v)
	}
	for k, v := range reqHeaders {
		req.Header.Set(k, v)
	}

	resp, err := c.http.Do(req)
	if err != nil {
		// Network error / timeout — retryable (unless the parent ctx is done).
		return nil, ctx.Err() == nil, err
	}
	defer resp.Body.Close()

	// The body is best-effort for error statuses (it only feeds the snippet), so a
	// read failure is reported only once the status says the call succeeded.
	data, readErr := io.ReadAll(resp.Body)

	if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500 {
		return nil, true, fmt.Errorf("%s http %d: %s", c.name, resp.StatusCode, snippet(data))
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, false, fmt.Errorf("%s http %d: %s", c.name, resp.StatusCode, snippet(data))
	}
	if readErr != nil {
		// A 2xx whose body could not be read in full is a network failure, not a
		// malformed response — classify it like any other network error.
		return nil, ctx.Err() == nil, fmt.Errorf("%s read body: %w", c.name, readErr)
	}

	var out openAIResponse
	if err := json.Unmarshal(data, &out); err != nil {
		return nil, false, fmt.Errorf("%s decode: %w", c.name, err)
	}
	// A 2xx is a billed call even when the model returns empty content (content
	// filter, finish_reason=length with no text, or no choices). Return it as a
	// success so the caller books the real cost via the ledger instead of refunding
	// the slot and losing the spend — which would let empty replies bypass BOTH the
	// per-user cap and the global ceiling. The caller just won't send an empty body.
	return &out, false, nil
}

// snippet returns b as a string, truncated to at most 300 bytes plus "…" when it is
// longer. The cut backs up to a rune boundary so a multi-byte UTF-8 character is
// never split in an error message.
func snippet(b []byte) string {
	const maxLen = 300
	if len(b) <= maxLen {
		return string(b)
	}
	cut := maxLen
	// Step back over at most one rune's worth of continuation bytes; the bound
	// keeps non-UTF-8 input from shrinking the snippet further.
	for i := 0; i < utf8.UTFMax-1 && cut > 0 && !utf8.RuneStart(b[cut]); i++ {
		cut--
	}
	return string(b[:cut]) + "…"
}