163 lines
4.2 KiB
Go
163 lines
4.2 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math/rand"
|
|
"net/http"
|
|
"time"
|
|
)
|
|
|
|
// XAIClient is a small client for an OpenAI-compatible Chat Completions API.
// Requests go to {base}/chat/completions and carry the key as a Bearer token.
type XAIClient struct {
	// base is the endpoint prefix; key is the API credential.
	base, key string
	// http is the client used to send each request.
	http *http.Client
	// maxTry bounds the number of attempts made per Complete call.
	maxTry int
}
|
|
|
|
func NewXAIClient(base, key string) *XAIClient {
|
|
return &XAIClient{
|
|
base: base,
|
|
key: key,
|
|
http: &http.Client{},
|
|
maxTry: 3,
|
|
}
|
|
}
|
|
|
|
// xaiMessage is a single chat message as it appears on the wire.
type xaiMessage struct {
	// Role is serialized under the "role" key.
	Role string `json:"role"`
	// Content is serialized under the "content" key.
	Content string `json:"content"`
}
|
|
|
|
type xaiRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []xaiMessage `json:"messages"`
|
|
MaxTokens int `json:"max_tokens"`
|
|
Temperature float64 `json:"temperature"`
|
|
Stream bool `json:"stream"`
|
|
}
|
|
|
|
// xaiUsage holds the token counts decoded from a response's "usage" object.
type xaiUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`

	// PromptTokensDetails is decoded from the nested "prompt_tokens_details"
	// object.
	PromptTokensDetails struct {
		CachedTokens int `json:"cached_tokens"`
	} `json:"prompt_tokens_details"`
}
|
|
|
|
type xaiResponse struct {
|
|
Choices []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
Usage xaiUsage `json:"usage"`
|
|
}
|
|
|
|
func (r *xaiResponse) Text() string {
|
|
if len(r.Choices) == 0 {
|
|
return ""
|
|
}
|
|
return r.Choices[0].Message.Content
|
|
}
|
|
|
|
// Complete calls Chat Completions with at-most-once retry on transient failures
|
|
// (429 / 5xx / network timeout, exponential backoff + jitter). Non-retryable 4xx
|
|
// fail immediately. The caller advances since/seen only AFTER this returns so a
|
|
// transient failure isn't silently swallowed by a moved cursor (F6).
|
|
func (x *XAIClient) Complete(ctx context.Context, model string, msgs []xaiMessage, maxTokens int, temp float64) (*xaiResponse, error) {
|
|
reqBody := xaiRequest{
|
|
Model: model,
|
|
Messages: msgs,
|
|
MaxTokens: maxTokens,
|
|
Temperature: temp,
|
|
Stream: false,
|
|
}
|
|
payload, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var lastErr error
|
|
for attempt := 0; attempt < x.maxTry; attempt++ {
|
|
if attempt > 0 {
|
|
// 0.5s, 1s, 2s … capped at 8s, plus up to 250ms jitter.
|
|
backoff := time.Duration(500<<uint(attempt-1)) * time.Millisecond
|
|
if backoff > 8*time.Second {
|
|
backoff = 8 * time.Second
|
|
}
|
|
backoff += time.Duration(rand.Intn(250)) * time.Millisecond
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-time.After(backoff):
|
|
}
|
|
}
|
|
|
|
resp, retryable, err := x.attempt(ctx, payload)
|
|
if err == nil {
|
|
return resp, nil
|
|
}
|
|
lastErr = err
|
|
if ctx.Err() != nil {
|
|
return nil, ctx.Err()
|
|
}
|
|
if !retryable {
|
|
return nil, err
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("xai: exhausted %d attempts: %w", x.maxTry, lastErr)
|
|
}
|
|
|
|
// attempt performs one HTTP call. It returns retryable=true for 429/5xx and
|
|
// network errors, false for other non-2xx (terminal 4xx).
|
|
func (x *XAIClient) attempt(ctx context.Context, payload []byte) (*xaiResponse, bool, error) {
|
|
// Per-attempt deadline so a hung connection doesn't block the whole loop.
|
|
attemptCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(attemptCtx, http.MethodPost, x.base+"/chat/completions", bytes.NewReader(payload))
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+x.key)
|
|
|
|
resp, err := x.http.Do(req)
|
|
if err != nil {
|
|
// Network error / timeout — retryable (unless the parent ctx is done).
|
|
return nil, ctx.Err() == nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
data, _ := io.ReadAll(resp.Body)
|
|
|
|
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500 {
|
|
return nil, true, fmt.Errorf("xai http %d: %s", resp.StatusCode, snippet(data))
|
|
}
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return nil, false, fmt.Errorf("xai http %d: %s", resp.StatusCode, snippet(data))
|
|
}
|
|
|
|
var out xaiResponse
|
|
if err := json.Unmarshal(data, &out); err != nil {
|
|
return nil, false, fmt.Errorf("xai decode: %w", err)
|
|
}
|
|
if out.Text() == "" {
|
|
return nil, false, fmt.Errorf("xai returned no choices")
|
|
}
|
|
return &out, false, nil
|
|
}
|
|
|
|
// snippet renders b as a string for inclusion in error messages. Input of at
// most 300 bytes is returned unchanged; longer input is cut to at most 300
// bytes and suffixed with "…". The cut is moved back by up to three bytes when
// it would otherwise land inside a multi-byte UTF-8 sequence, so well-formed
// text is never truncated into an invalid partial rune.
func snippet(b []byte) string {
	const limit = 300
	if len(b) <= limit {
		return string(b)
	}

	cut := limit
	// b[cut] is the first byte that would be dropped. If it is a UTF-8
	// continuation byte (10xxxxxx) the rune before it would be split, so step
	// back to that rune's lead byte. A rune has at most three continuation
	// bytes, which also bounds the walk on non-UTF-8 input.
	for back := 0; back < 3 && cut > 0 && b[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return string(b[:cut]) + "…"
}
|