200 lines
6.5 KiB
Go
200 lines
6.5 KiB
Go
package main
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"math/rand"
	"net/http"
	"time"
	"unicode/utf8"
)
|
|
|
|
// httpllm.go is the shared OpenAI-compatible Chat Completions transport: one
// HTTP+retry implementation reused by every provider adapter. Grok and Gemini both
// expose this wire format, so the retry/backoff classification (429/5xx/network =
// retryable, other 4xx = terminal) lives here once, parameterised by base URL, API
// key, and headers, instead of being copied per provider.
|
|
|
|
// openAIClient issues OpenAI-compatible POST {base}/chat/completions requests
// and retries transient failures. Build one with newOpenAIClient.
type openAIClient struct {
	// name labels the provider in log lines and error messages ("xai", "gemini").
	name string
	// base is the API root that "/chat/completions" is appended to; key is sent
	// as a Bearer token in the Authorization header.
	base, key string
	// http is the transport used for every attempt.
	http *http.Client
	// maxTry is the total number of attempts complete makes for one request.
	maxTry int
	// headers are extra static, provider-specific headers set on every request;
	// may be nil.
	headers map[string]string
	// log receives retry warnings; nil disables them.
	log *slog.Logger
}

// newOpenAIClient returns a client for the provider labelled name, rooted at
// base and authenticated with key. headers (may be nil) are sent on every
// request, and logger (may be nil) receives retry warnings. The client makes up
// to three attempts per request. Its http.Client has no client-level timeout;
// each attempt is bounded by its own context deadline instead.
func newOpenAIClient(name, base, key string, headers map[string]string, logger *slog.Logger) *openAIClient {
	c := new(openAIClient)
	c.name, c.base, c.key = name, base, key
	c.http = new(http.Client)
	c.maxTry = 3
	c.headers, c.log = headers, logger
	return c
}
|
|
|
|
// --- OpenAI-compatible wire types -------------------------------------------------
|
|
|
|
// openAIMessage is one element of a request's "messages" array.
type openAIMessage struct {
	// Role is serialized as "role"; the values used are chosen by callers and
	// are not visible in this file.
	Role string `json:"role"`
	// Content is the message text, serialized as "content".
	Content string `json:"content"`
}
|
|
|
|
// openAITool is the wire shape of a model tool (e.g. web search). It is only
// serialized when the request carries tools (openAIRequest.Tools is omitempty),
// so a plain completion's body is unchanged.
type openAITool struct {
	// Type is the tool kind, serialized as "type"; the concrete values are
	// chosen by callers and are not visible in this file.
	Type string `json:"type"`
}
|
|
|
|
// openAIRequest is the JSON body POSTed to /chat/completions. encoding/json
// emits struct fields in declaration order, so reordering these fields changes
// the bytes sent on the wire.
type openAIRequest struct {
	Model    string          `json:"model"`
	Messages []openAIMessage `json:"messages"`
	// MaxTokens, Temperature and Stream have no omitempty: they are always sent,
	// even when zero / false.
	MaxTokens   int     `json:"max_tokens"`
	Temperature float64 `json:"temperature"`
	// NOTE(review): the response is decoded as a single JSON document, so
	// presumably Stream is always false for this client — confirm with callers.
	Stream bool `json:"stream"`
	// Optional; omitempty keeps the grok_direct body byte-identical to before.
	Tools []openAITool `json:"tools,omitempty"`
	// ReasoningEffort is omitted from the body when empty.
	ReasoningEffort string `json:"reasoning_effort,omitempty"`
	// SearchParameters drives xAI Live Search on chat/completions (the web route's
	// grok_web_search provider). nil for every non-web call, so it serializes away.
	SearchParameters any `json:"search_parameters,omitempty"`
}
|
|
|
|
// openAIUsage holds the token counts decoded from a response's "usage" object.
// Fields absent from the JSON decode as 0.
type openAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	// PromptTokensDetails mirrors "prompt_tokens_details". CachedTokens is
	// presumably the portion of PromptTokens served from the provider's prompt
	// cache — confirm against the provider documentation.
	PromptTokensDetails struct {
		CachedTokens int `json:"cached_tokens"`
	} `json:"prompt_tokens_details"`
}
|
|
|
|
// openAIResponse is the decoded body of a 2xx /chat/completions reply. JSON
// members not declared here are ignored by encoding/json.
type openAIResponse struct {
	ID string `json:"id"`
	// Choices holds the returned completions; Text reads only the first one.
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	// Usage carries the token counts; presumably what callers book against the
	// cost ledger — confirm at the call sites.
	Usage openAIUsage `json:"usage"`
	// Citations is the source list xAI Live Search returns by default (absent on a
	// non-web call → nil).
	Citations []string `json:"citations"`
}
|
|
|
|
func (r *openAIResponse) Text() string {
|
|
if len(r.Choices) == 0 {
|
|
return ""
|
|
}
|
|
return r.Choices[0].Message.Content
|
|
}
|
|
|
|
// complete calls Chat Completions with retry on transient failures (429 / 5xx /
|
|
// network timeout, exponential backoff + jitter). Non-retryable 4xx fail
|
|
// immediately. On exhaustion the caller refunds the reserved request and notifies
|
|
// the user, so a transient failure is never silently swallowed (F6). reqHeaders are
|
|
// per-request headers (e.g. x-grok-conv-id) merged on top of the static ones; nil is
|
|
// fine.
|
|
func (c *openAIClient) complete(ctx context.Context, reqBody openAIRequest, reqHeaders map[string]string) (*openAIResponse, error) {
|
|
payload, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var lastErr error
|
|
for attempt := 0; attempt < c.maxTry; attempt++ {
|
|
if attempt > 0 {
|
|
// 0.5s, 1s, 2s … capped at 8s, plus up to 250ms jitter.
|
|
backoff := time.Duration(500<<uint(attempt-1)) * time.Millisecond
|
|
if backoff > 8*time.Second {
|
|
backoff = 8 * time.Second
|
|
}
|
|
backoff += time.Duration(rand.Intn(250)) * time.Millisecond
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-time.After(backoff):
|
|
}
|
|
}
|
|
|
|
resp, retryable, err := c.attempt(ctx, payload, reqHeaders)
|
|
if err == nil {
|
|
return resp, nil
|
|
}
|
|
lastErr = err
|
|
if ctx.Err() != nil {
|
|
return nil, ctx.Err()
|
|
}
|
|
if !retryable {
|
|
return nil, err
|
|
}
|
|
if c.log != nil {
|
|
c.log.Warn(c.name+" attempt failed, will retry", "attempt", attempt+1, "max", c.maxTry, "err", err)
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("%s: exhausted %d attempts: %w", c.name, c.maxTry, lastErr)
|
|
}
|
|
|
|
// attempt performs one HTTP call. It returns retryable=true for 429/5xx and
|
|
// network errors, false for other non-2xx (terminal 4xx). The per-attempt deadline
|
|
// bounds a single hung connection; the overall per-request deadline (set by the
|
|
// caller via ctx) bounds the whole retry loop so a cascade can't accrete minutes.
|
|
func (c *openAIClient) attempt(ctx context.Context, payload []byte, reqHeaders map[string]string) (*openAIResponse, bool, error) {
|
|
attemptCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(attemptCtx, http.MethodPost, c.base+"/chat/completions", bytes.NewReader(payload))
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+c.key)
|
|
for k, v := range c.headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
for k, v := range reqHeaders {
|
|
req.Header.Set(k, v)
|
|
}
|
|
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
// Network error / timeout — retryable (unless the parent ctx is done).
|
|
return nil, ctx.Err() == nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
data, _ := io.ReadAll(resp.Body)
|
|
|
|
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500 {
|
|
return nil, true, fmt.Errorf("%s http %d: %s", c.name, resp.StatusCode, snippet(data))
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, false, fmt.Errorf("%s http %d: %s", c.name, resp.StatusCode, snippet(data))
|
|
}
|
|
|
|
var out openAIResponse
|
|
if err := json.Unmarshal(data, &out); err != nil {
|
|
return nil, false, fmt.Errorf("%s decode: %w", c.name, err)
|
|
}
|
|
// A 2xx is a billed call even when the model returns empty content (content
|
|
// filter, finish_reason=length with no text, or no choices). Return it as a
|
|
// success so the caller books the real cost via the ledger instead of refunding
|
|
// the slot and losing the spend — which would let empty replies bypass BOTH the
|
|
// per-user cap and the global ceiling. The caller just won't send an empty body.
|
|
return &out, false, nil
|
|
}
|
|
|
|
// snippet renders a response body for an error message. Bodies up to 300 bytes
// are returned whole; longer ones are cut to at most 300 bytes and suffixed
// with "…". The cut is moved back to a UTF-8 rune boundary so a multi-byte
// character is never split into invalid UTF-8 in the error text.
func snippet(b []byte) string {
	const limit = 300 // named limit rather than max to avoid shadowing the built-in
	if len(b) <= limit {
		return string(b)
	}
	cut := limit
	// b[cut] is the first dropped byte. If it is a UTF-8 continuation byte, its
	// rune began at most utf8.UTFMax-1 bytes earlier, so back up to that start.
	// The lower bound stops invalid input (long runs of continuation bytes)
	// from eating the whole prefix.
	for cut > limit-utf8.UTFMax && !utf8.RuneStart(b[cut]) {
		cut--
	}
	return string(b[:cut]) + "…"
}
|