fix(ai-bot): drop reasoning_effort for models that reject it, healing once and caching it instead of failing every request
This commit is contained in:
parent
9beb5a19bd
commit
b56a47db4d
2 changed files with 194 additions and 7 deletions
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -27,6 +29,13 @@ type openAIClient struct {
|
||||||
maxTry int
|
maxTry int
|
||||||
headers map[string]string // extra static headers (provider-specific), may be nil
|
headers map[string]string // extra static headers (provider-specific), may be nil
|
||||||
log *slog.Logger
|
log *slog.Logger
|
||||||
|
|
||||||
|
// noReasoningEffort remembers models that 400'd on the reasoning_effort param so we
|
||||||
|
// drop it up front on every later call instead of paying the 400+heal each time. A
|
||||||
|
// reasoning_effort/model mismatch is a config error (operator set both); we heal it
|
||||||
|
// ONCE (and WARN once, in markNoReasoningEffort) rather than per-message. Guarded by mu.
|
||||||
|
mu sync.Mutex
|
||||||
|
noReasoningEffort map[string]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOpenAIClient(name, base, key string, headers map[string]string, logger *slog.Logger) *openAIClient {
|
func newOpenAIClient(name, base, key string, headers map[string]string, logger *slog.Logger) *openAIClient {
|
||||||
|
|
@ -38,6 +47,27 @@ func newOpenAIClient(name, base, key string, headers map[string]string, logger *
|
||||||
maxTry: 3,
|
maxTry: 3,
|
||||||
headers: headers,
|
headers: headers,
|
||||||
log: logger,
|
log: logger,
|
||||||
|
noReasoningEffort: map[string]bool{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rejectsReasoningEffort reports whether a prior call already learned this model 400s on
|
||||||
|
// the reasoning_effort param (so we omit it up front).
|
||||||
|
func (c *openAIClient) rejectsReasoningEffort(model string) bool {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.noReasoningEffort[model]
|
||||||
|
}
|
||||||
|
|
||||||
|
// markNoReasoningEffort records that a model rejects reasoning_effort and WARNs exactly
|
||||||
|
// once (the first time), so the operator sees the config mismatch without a per-message log.
|
||||||
|
func (c *openAIClient) markNoReasoningEffort(ctx context.Context, model string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
first := !c.noReasoningEffort[model]
|
||||||
|
c.noReasoningEffort[model] = true
|
||||||
|
c.mu.Unlock()
|
||||||
|
if first && c.log != nil {
|
||||||
|
c.log.WarnContext(ctx, c.name+": model rejects reasoning_effort; dropping it for this model — unset GROK_REASONING_EFFORT or use a model that supports it", "model", model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -104,6 +134,11 @@ func (r *openAIResponse) Text() string {
|
||||||
// per-request headers (e.g. x-grok-conv-id) merged on top of the static ones; nil is
|
// per-request headers (e.g. x-grok-conv-id) merged on top of the static ones; nil is
|
||||||
// fine.
|
// fine.
|
||||||
func (c *openAIClient) complete(ctx context.Context, reqBody openAIRequest, reqHeaders map[string]string) (*openAIResponse, error) {
|
func (c *openAIClient) complete(ctx context.Context, reqBody openAIRequest, reqHeaders map[string]string) (*openAIResponse, error) {
|
||||||
|
// If a prior call already learned this model rejects reasoning_effort, drop it up front
|
||||||
|
// so we never pay the 400+heal again (healed once below; this is the steady state after).
|
||||||
|
if reqBody.ReasoningEffort != "" && c.rejectsReasoningEffort(reqBody.Model) {
|
||||||
|
reqBody.ReasoningEffort = ""
|
||||||
|
}
|
||||||
payload, err := json.Marshal(reqBody)
|
payload, err := json.Marshal(reqBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -134,6 +169,24 @@ func (c *openAIClient) complete(ctx context.Context, reqBody openAIRequest, reqH
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
}
|
}
|
||||||
if !retryable {
|
if !retryable {
|
||||||
|
// Self-heal (first time only): a model that doesn't support reasoning_effort rejects
|
||||||
|
// it with a 400. Remember the model (so every later call drops the param up front —
|
||||||
|
// see top of complete), then strip it and retry ONCE, immediately and inline — the
|
||||||
|
// error is deterministic, so a backoff buys nothing, and the retry must NOT depend on
|
||||||
|
// a remaining loop slot (else a 400 on the final attempt would never be re-tried).
|
||||||
|
// markNoReasoningEffort WARNs once. This lets switching XAI_MODEL to such a model
|
||||||
|
// degrade gracefully instead of hard-failing every request into a react.
|
||||||
|
if reqBody.ReasoningEffort != "" && isReasoningEffortUnsupported(err) {
|
||||||
|
reqBody.ReasoningEffort = ""
|
||||||
|
c.markNoReasoningEffort(ctx, reqBody.Model)
|
||||||
|
if p, mErr := json.Marshal(reqBody); mErr == nil {
|
||||||
|
resp, _, rErr := c.attempt(ctx, p, reqHeaders)
|
||||||
|
if rErr != nil {
|
||||||
|
return nil, rErr
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if c.log != nil {
|
if c.log != nil {
|
||||||
|
|
@ -195,6 +248,18 @@ func (c *openAIClient) attempt(ctx context.Context, payload []byte, reqHeaders m
|
||||||
return &out, false, nil
|
return &out, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isReasoningEffortUnsupported reports whether err looks like the rejection a
// non-reasoning model returns when it is sent reasoning_effort (for example
// "...does not support parameter reasoningEffort").
//
// The match is deliberately loose: the error text is lower-cased and must contain all
// three of "reasoning", "effort" and "support", in any order and at any position, so
// both the reasoning_effort and reasoningEffort spellings trip it. A nil err never
// matches. The result gates the one-shot strip-and-retry in complete().
func isReasoningEffortUnsupported(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	for _, word := range [...]string{"reasoning", "effort", "support"} {
		if !strings.Contains(msg, word) {
			return false
		}
	}
	return true
}
|
||||||
|
|
||||||
func snippet(b []byte) string {
|
func snippet(b []byte) string {
|
||||||
const max = 300
|
const max = 300
|
||||||
if len(b) > max {
|
if len(b) > max {
|
||||||
|
|
|
||||||
122
apps/ai-bot/httpllm_test.go
Normal file
122
apps/ai-bot/httpllm_test.go
Normal file
|
|
@ -0,0 +1,122 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestCompleteReasoningEffortSelfHeal verifies that when a model rejects the
|
||||||
|
// reasoning_effort param (the HTTP 400 a non-reasoning Grok model returns), the transport
|
||||||
|
// strips the param and retries once — so switching XAI_MODEL to a non-reasoning model
|
||||||
|
// degrades gracefully instead of hard-failing every request into a react.
|
||||||
|
func TestCompleteReasoningEffortSelfHeal(t *testing.T) {
|
||||||
|
var calls int
|
||||||
|
var sawEffortFirst, sawEffortSecond bool
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
hasEffort := strings.Contains(string(body), "reasoning_effort")
|
||||||
|
calls++
|
||||||
|
if calls == 1 {
|
||||||
|
sawEffortFirst = hasEffort
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
io.WriteString(w, `{"code":"Client specified an invalid argument","error":"Model grok-x-non-reasoning does not support parameter reasoningEffort."}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sawEffortSecond = hasEffort
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
io.WriteString(w, `{"id":"ok","choices":[{"message":{"content":"hi"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := newOpenAIClient("xai", srv.URL, "key", nil, discardLog())
|
||||||
|
resp, err := c.complete(context.Background(), openAIRequest{
|
||||||
|
Model: "grok-x-non-reasoning",
|
||||||
|
Messages: []openAIMessage{{Role: "user", Content: "hi"}},
|
||||||
|
MaxTokens: 10,
|
||||||
|
Temperature: 0.6,
|
||||||
|
ReasoningEffort: "low",
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("complete returned error, want self-heal success: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Text() != "hi" {
|
||||||
|
t.Fatalf("got %q, want %q", resp.Text(), "hi")
|
||||||
|
}
|
||||||
|
if calls != 2 {
|
||||||
|
t.Fatalf("expected exactly 2 calls (400, then stripped retry), got %d", calls)
|
||||||
|
}
|
||||||
|
if !sawEffortFirst {
|
||||||
|
t.Fatal("first call should have sent reasoning_effort")
|
||||||
|
}
|
||||||
|
if sawEffortSecond {
|
||||||
|
t.Fatal("retry must NOT send reasoning_effort")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCompleteReasoningEffortCachedAfterFirst: after the first 400+heal, the client
|
||||||
|
// remembers the model rejects reasoning_effort and drops the param UP FRONT on later calls —
|
||||||
|
// so a misconfigured GROK_REASONING_EFFORT costs one 400 per process, not one per message.
|
||||||
|
func TestCompleteReasoningEffortCachedAfterFirst(t *testing.T) {
|
||||||
|
var calls, rejected int
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
calls++
|
||||||
|
if strings.Contains(string(body), "reasoning_effort") {
|
||||||
|
rejected++
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
io.WriteString(w, `{"error":"Model m does not support parameter reasoningEffort."}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
io.WriteString(w, `{"id":"ok","choices":[{"message":{"content":"hi"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := newOpenAIClient("xai", srv.URL, "key", nil, discardLog())
|
||||||
|
req := func() openAIRequest {
|
||||||
|
return openAIRequest{Model: "m", Messages: []openAIMessage{{Role: "user", Content: "hi"}}, ReasoningEffort: "low"}
|
||||||
|
}
|
||||||
|
// First call: 400 (with effort) then a stripped retry → 2 HTTP calls, 1 rejection.
|
||||||
|
if _, err := c.complete(context.Background(), req(), nil); err != nil {
|
||||||
|
t.Fatalf("first complete: %v", err)
|
||||||
|
}
|
||||||
|
// Second call: the param is dropped up front → exactly 1 HTTP call, no new rejection.
|
||||||
|
if _, err := c.complete(context.Background(), req(), nil); err != nil {
|
||||||
|
t.Fatalf("second complete: %v", err)
|
||||||
|
}
|
||||||
|
if calls != 3 {
|
||||||
|
t.Fatalf("want 3 HTTP calls total (400+retry, then cached single), got %d", calls)
|
||||||
|
}
|
||||||
|
if rejected != 1 {
|
||||||
|
t.Fatalf("want exactly 1 reasoning_effort rejection (cached after), got %d", rejected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCompleteTerminal4xxNoSelfHeal guards that the strip-and-retry is scoped to the
|
||||||
|
// reasoning_effort 400 only: an unrelated 400 still fails fast (no spurious retry).
|
||||||
|
func TestCompleteTerminal4xxNoSelfHeal(t *testing.T) {
|
||||||
|
var calls int
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
calls++
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
io.WriteString(w, `{"error":"some other invalid argument"}`)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
c := newOpenAIClient("xai", srv.URL, "key", nil, discardLog())
|
||||||
|
_, err := c.complete(context.Background(), openAIRequest{
|
||||||
|
Model: "grok-x-non-reasoning",
|
||||||
|
Messages: []openAIMessage{{Role: "user", Content: "hi"}},
|
||||||
|
ReasoningEffort: "low",
|
||||||
|
}, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected a terminal error on an unrelated 400")
|
||||||
|
}
|
||||||
|
if calls != 1 {
|
||||||
|
t.Fatalf("unrelated 400 must fail fast (1 call), got %d", calls)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue