// File: vojo/apps/ai-bot/provider_gemini.go (191 lines, 6.5 KiB, Go)

package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"time"
)
// provider_gemini.go is the Gemini backend. Two faces:
//
// - geminiClient: a thin LLMClient over the OpenAI-compatible endpoint, used for the
// cheap trivial route and the Layer-1 router classifier. Same wire format as Grok,
// so it reuses the shared transport (httpllm.go).
// - groundedSearch: a SEPARATE call against the NATIVE v1beta generateContent endpoint
// with the google_search tool. Grounding does NOT work on the OpenAI-compat layer —
// it is silently ignored THERE (F-EXT-3, an endpoint limitation, NOT a model-version
// one: the google_search tool is supported by current models including
// gemini-2.5-flash-lite per ai.google.dev). So the web layer that wants Gemini
// grounding must use this native path and VERIFY citations came back, else degrade.
type geminiClient struct {
	// http is the shared OpenAI-compatible transport that Complete delegates to.
	http *openAIClient
	// nativeBase is the native API root (…/v1beta), obtained from the
	// OpenAI-compatible base URL by dropping its /openai suffix.
	nativeBase string
	// key is the API key sent with native grounding requests.
	key string
	// model is the model name placed in the native generateContent URL.
	model string
	// httpc is the plain HTTP client used for native (non-OpenAI-compat) calls.
	httpc *http.Client
	// log is the logger handed to the exchange logging of native calls.
	log *slog.Logger
}
// NewGeminiClient builds the Gemini backend. base is the OpenAI-compatible endpoint
// (…/v1beta/openai); the native grounding root is derived from it by dropping the
// final /openai segment. Trailing slashes on base are ignored for that derivation,
// so "…/v1beta/openai/" and "…/v1beta/openai" produce the same nativeBase — without
// this the suffix would not match and native URLs would keep a stray "/openai/".
// base itself is passed to the OpenAI-compatible transport unchanged.
//
// The native HTTP client carries no client-level timeout; groundedSearch applies a
// per-request context deadline instead.
//
// It returns the concrete type (not just LLMClient) because the web layer needs
// groundedSearch too.
func NewGeminiClient(base, key, model string, logger *slog.Logger) *geminiClient {
	nativeBase := strings.TrimSuffix(strings.TrimRight(base, "/"), "/openai")
	return &geminiClient{
		http:       newOpenAIClient("gemini", base, key, nil, logger),
		nativeBase: nativeBase,
		key:        key,
		model:      model,
		httpc:      &http.Client{},
		log:        logger,
	}
}
// Complete sends req through the OpenAI-compatible endpoint (the path used by the
// trivial route and the classifier) and repackages the reply as an LLMResponse.
// Streaming is always off. An error from the transport is returned unchanged.
func (c *geminiClient) Complete(ctx context.Context, req LLMRequest) (*LLMResponse, error) {
	// Convert the request messages into the OpenAI wire shape, one for one.
	wire := make([]openAIMessage, 0, len(req.Messages))
	for _, msg := range req.Messages {
		wire = append(wire, openAIMessage{Role: msg.Role, Content: msg.Content})
	}

	payload := openAIRequest{
		Model:       req.Model,
		Messages:    wire,
		MaxTokens:   req.MaxTokens,
		Temperature: req.Temperature,
		Stream:      false,
	}
	reply, err := c.http.complete(ctx, payload, nil)
	if err != nil {
		return nil, err
	}

	text := reply.Text()
	usage := Usage{
		PromptTokens:     reply.Usage.PromptTokens,
		CachedTokens:     reply.Usage.PromptTokensDetails.CachedTokens,
		CompletionTokens: reply.Usage.CompletionTokens,
	}
	return &LLMResponse{Text: text, Usage: usage, ProviderRequestID: reply.ID}, nil
}
// --- native v1beta grounded search (google_search tool) ---------------------------

// geminiGroundResult is what a successful groundedSearch call returns.
type geminiGroundResult struct {
	// Digest is the answer text: the first candidate's text parts joined and trimmed.
	Digest string
	// Citations lists the non-empty web source URIs from the grounding metadata.
	Citations []string
	// Usage carries the token counts reported in the response's usage metadata.
	Usage Usage
}
// Native generateContent wire types (only the fields we read/write).
type (
	// geminiNativeRequest is the JSON body posted to generateContent.
	geminiNativeRequest struct {
		Contents []geminiContent `json:"contents"`
		Tools    []geminiTool    `json:"tools"`
	}

	// geminiContent is one turn of the conversation; role is omitted from the JSON
	// when empty.
	geminiContent struct {
		Role  string       `json:"role,omitempty"`
		Parts []geminiPart `json:"parts"`
	}

	// geminiPart is a text fragment, used both in requests and in responses.
	geminiPart struct {
		Text string `json:"text"`
	}

	// geminiTool switches on a tool for the request.
	geminiTool struct {
		// google_search is the current grounding tool (all current models, incl. the
		// 2.5 family; legacy models used google_search_retrieval). The empty object
		// enables it.
		GoogleSearch struct{} `json:"google_search"`
	}

	// geminiNativeResponse is the subset of the generateContent reply that is decoded:
	// candidate text parts, web grounding chunks, and token usage.
	geminiNativeResponse struct {
		Candidates []struct {
			Content struct {
				Parts []geminiPart `json:"parts"`
			} `json:"content"`
			GroundingMetadata struct {
				GroundingChunks []struct {
					Web struct {
						URI   string `json:"uri"`
						Title string `json:"title"`
					} `json:"web"`
				} `json:"groundingChunks"`
			} `json:"groundingMetadata"`
		} `json:"candidates"`
		UsageMetadata struct {
			PromptTokenCount        int `json:"promptTokenCount"`
			CandidatesTokenCount    int `json:"candidatesTokenCount"`
			CachedContentTokenCount int `json:"cachedContentTokenCount"`
		} `json:"usageMetadata"`
	}
)
// groundedSearch runs one grounded generateContent against the native endpoint and
// returns the model's grounded answer plus the source URLs. It REQUIRES citations:
// if groundingMetadata has no web chunks the request was not actually grounded (the
// silent-ignore failure mode, F-EXT-3), so it errors and the caller degrades rather
// than passing off ungrounded — possibly stale — text as fresh.
//
// The call is bounded by a 15-second deadline layered on ctx. An error is returned
// for: request encoding or construction failure, transport failure, a non-2xx status
// (the message includes snippet(data) of the body), a 2xx body that could not be
// read in full, undecodable JSON, an empty candidate list, and a response without
// citations. On error the returned geminiGroundResult is the zero value.
func (c *geminiClient) groundedSearch(ctx context.Context, query string) (geminiGroundResult, error) {
	body, err := json.Marshal(geminiNativeRequest{
		Contents: []geminiContent{{Role: "user", Parts: []geminiPart{{Text: query}}}},
		Tools:    []geminiTool{{}},
	})
	if err != nil {
		return geminiGroundResult{}, err
	}

	// The key travels in the x-goog-api-key header rather than a ?key= query
	// parameter: http.Client.Do reports failures as *url.Error, whose message embeds
	// the full request URL, so a key in the query string would leak into every
	// transport error this function returns.
	endpoint := fmt.Sprintf("%s/models/%s:generateContent", c.nativeBase, url.PathEscape(c.model))

	reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second) // web/grounding budget (§8.2.2)
	defer cancel()
	req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return geminiGroundResult{}, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-goog-api-key", c.key)

	resp, err := c.httpc.Do(req)
	if err != nil {
		return geminiGroundResult{}, err
	}
	defer resp.Body.Close()

	// Keep whatever was read even when the read fails: it still feeds the exchange
	// log and the snippet attached to a non-2xx error.
	data, readErr := io.ReadAll(resp.Body)
	logLLMExchange(ctx, c.log, "gemini_grounding", body, resp.StatusCode, data)
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return geminiGroundResult{}, fmt.Errorf("gemini grounding http %d: %s", resp.StatusCode, snippet(data))
	}
	if readErr != nil {
		// A truncated 2xx body would otherwise surface as a misleading decode error.
		return geminiGroundResult{}, fmt.Errorf("gemini grounding read body: %w", readErr)
	}

	var out geminiNativeResponse
	if err := json.Unmarshal(data, &out); err != nil {
		return geminiGroundResult{}, fmt.Errorf("gemini grounding decode: %w", err)
	}
	if len(out.Candidates) == 0 {
		return geminiGroundResult{}, fmt.Errorf("gemini grounding: no candidates")
	}

	// Only the first candidate is used: its text parts form the digest and its
	// grounding chunks supply the citations.
	var sb strings.Builder
	for _, p := range out.Candidates[0].Content.Parts {
		sb.WriteString(p.Text)
	}
	var citations []string
	for _, ch := range out.Candidates[0].GroundingMetadata.GroundingChunks {
		if ch.Web.URI != "" {
			citations = append(citations, ch.Web.URI)
		}
	}
	// The verify-gate: no citations ⇒ not actually grounded ⇒ degrade.
	if len(citations) == 0 {
		return geminiGroundResult{}, fmt.Errorf("gemini grounding: no citations (ungrounded — degrade)")
	}

	return geminiGroundResult{
		Digest:    strings.TrimSpace(sb.String()),
		Citations: citations,
		Usage: Usage{
			PromptTokens:     out.UsageMetadata.PromptTokenCount,
			CachedTokens:     out.UsageMetadata.CachedContentTokenCount,
			CompletionTokens: out.UsageMetadata.CandidatesTokenCount,
		},
	}, nil
}