62 lines
1.9 KiB
Go
62 lines
1.9 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
)
|
|
|
|
// provider_xai.go is the thin adapter for xAI's Grok backend. xAI speaks the
|
|
// OpenAI Chat Completions wire format, so this is a shell over the shared
|
|
// openAIClient transport (httpllm.go): it only maps the neutral LLMRequest/
|
|
// LLMResponse to/from the wire types. Any xAI-specific request shaping would live
|
|
// here, but Grok needs none today.
|
|
type xaiClient struct {
|
|
http *openAIClient
|
|
}
|
|
|
|
// NewXAIClient builds the Grok backend. Returns the neutral LLMClient so the bot
|
|
// holds no vendor type.
|
|
func NewXAIClient(base, key string, logger *slog.Logger) LLMClient {
|
|
return &xaiClient{http: newOpenAIClient("xai", base, key, nil, logger)}
|
|
}
|
|
|
|
func (c *xaiClient) Complete(ctx context.Context, req LLMRequest) (*LLMResponse, error) {
|
|
msgs := make([]openAIMessage, len(req.Messages))
|
|
for i, m := range req.Messages {
|
|
msgs[i] = openAIMessage{Role: m.Role, Content: m.Content}
|
|
}
|
|
var tools []openAITool
|
|
for _, t := range req.Tools {
|
|
tools = append(tools, openAITool{Type: t.Type})
|
|
}
|
|
|
|
// x-grok-conv-id pins this conversation to one backend to raise the prompt-cache
|
|
// hit rate (caching itself is automatic on xAI). Only sent when set, so the
|
|
// default path adds no header.
|
|
var headers map[string]string
|
|
if req.ConvID != "" {
|
|
headers = map[string]string{"x-grok-conv-id": req.ConvID}
|
|
}
|
|
|
|
resp, err := c.http.complete(ctx, openAIRequest{
|
|
Model: req.Model,
|
|
Messages: msgs,
|
|
MaxTokens: req.MaxTokens,
|
|
Temperature: req.Temperature,
|
|
Stream: false,
|
|
Tools: tools,
|
|
ReasoningEffort: req.ReasoningEffort,
|
|
}, headers)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &LLMResponse{
|
|
Text: resp.Text(),
|
|
Usage: Usage{
|
|
PromptTokens: resp.Usage.PromptTokens,
|
|
CachedTokens: resp.Usage.PromptTokensDetails.CachedTokens,
|
|
CompletionTokens: resp.Usage.CompletionTokens,
|
|
},
|
|
ProviderRequestID: resp.ID,
|
|
}, nil
|
|
}
|