vojo/apps/ai-bot/store.go

338 lines
12 KiB
Go

package main
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// reserveResult is the verdict Store.Reserve returns before an xAI call is made.
type reserveResult int

// The reserveResult values, numbered explicitly (0, 1, 2) so that reordering the
// declarations can never silently renumber them.
const (
	// reserveOK: both gates passed and a request slot was taken.
	reserveOK reserveResult = 0
	// reserveDeniedUser: the per-user daily request cap is exhausted
	// (⏳ rate-limit reaction, F24).
	reserveDeniedUser reserveResult = 1
	// reserveDeniedGlobal: the global daily USD ceiling has been reached
	// (⏳ rate-limit reaction, F24).
	reserveDeniedGlobal reserveResult = 2
)
// maxProcessedEvent and maxProcessedTxn cap the processed_event and processed_txn
// dedup tables. Each trim keeps only this many of the most recent ids (the same
// bounds the former SQLite store used), so neither table grows without limit.
const (
	maxProcessedEvent = 20000
	maxProcessedTxn   = 5000
)
// opTimeout is the deadline given to each individual store operation. The old
// local SQLite file practically never blocked. Postgres sits across the docker
// network, though, and without a cap a stalled database could hang a per-room
// handler goroutine indefinitely.
const opTimeout time.Duration = 10 * time.Second
// Store holds the bot's durable operational state:
//   - appservice-transaction and event dedup,
//   - the per-day spend ledger,
//   - the set of encrypted rooms already warned about.
//
// It deliberately keeps no message content; the room timeline lives in Synapse.
// The data sits in its own Postgres database (`vojo_ai`), like the per-service
// bridge databases, so the spend ledger, dedup state and warned set are covered
// by the server's regular backup/restore.
type Store struct {
	pool *pgxpool.Pool // concurrent pgx connection pool; sized in OpenStore
}
// OpenStore opens the `vojo_ai` Postgres database named by the AI_BOT_DATABASE_URL
// DSN, applies any pending migrations and returns a ready Store. Creating the
// pool, pinging and migrating all share one opTimeout budget. On any failure after
// the pool exists, the pool is closed before the error is returned.
//
// The pool is intentionally tiny (MinConns 1, MaxConns 4; these override any pool
// sizing given in the DSN):
//   - transactions are processed serially;
//   - the per-room handler goroutines only issue brief statements;
//   - the shared server hosts many other databases.
//
// The former SQLite store pinned one connection to serialize everything; pgx
// gives a real pool instead.
func OpenStore(dsn string) (*Store, error) {
	cfg, err := pgxpool.ParseConfig(dsn)
	if err != nil {
		return nil, fmt.Errorf("parse AI_BOT_DATABASE_URL: %w", err)
	}
	cfg.MinConns = 1
	cfg.MaxConns = 4

	ctx, cancel := context.WithTimeout(context.Background(), opTimeout)
	defer cancel()

	pool, err := pgxpool.NewWithConfig(ctx, cfg)
	if err != nil {
		return nil, fmt.Errorf("connect vojo_ai: %w", err)
	}

	store := &Store{pool: pool}
	if err = pool.Ping(ctx); err != nil {
		err = fmt.Errorf("ping vojo_ai: %w", err)
	} else {
		err = store.migrate(ctx)
	}
	if err != nil {
		pool.Close()
		return nil, err
	}
	return store, nil
}
// Close shuts down the connection pool. pgxpool's Close reports no error, so the
// result is always nil; the signature simply matches io.Closer.
func (s *Store) Close() error {
	s.pool.Close()
	return nil
}
// migrationLockKey is the fixed key of the session advisory lock held while
// migrations run. It stops two instances that start together from both reading
// the same schema version and racing to apply the same step. The bot is meant to
// run as one instance, but this keeps startup robust. The value is just the ASCII
// bytes of "vojo"; any constant would do.
const migrationLockKey = 0x766f6a6f // 'v' 'o' 'j' 'o'
// migrations are applied in order; schema_version records the highest applied
// version so re-runs are no-ops. Every step is also idempotent (CREATE TABLE IF NOT
// EXISTS) so a half-applied database still converges.
//
// Entry i is schema version i+1. migrate only runs entries past the recorded
// version, so append new steps at the end; editing or reordering an entry that a
// database has already applied will never re-run it there.
var migrations = []string{
// v1: the operational schema — a 1:1 port of the former SQLite tables.
// processed_* carry a surrogate identity column because Postgres has no rowid:
// the LRU trim orders by it, and txn_id/event_id stay UNIQUE for the upsert.
// spend has one row per (UTC day as "YYYY-MM-DD", mxid): the reserved request
// count and the booked USD. warned_encrypted is a plain set of room ids.
`CREATE TABLE IF NOT EXISTS processed_txn (
id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
txn_id TEXT UNIQUE NOT NULL
);
CREATE TABLE IF NOT EXISTS processed_event (
id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
event_id TEXT UNIQUE NOT NULL
);
CREATE TABLE IF NOT EXISTS spend (
date TEXT NOT NULL,
mxid TEXT NOT NULL,
requests INTEGER NOT NULL DEFAULT 0,
usd DOUBLE PRECISION NOT NULL DEFAULT 0,
PRIMARY KEY (date, mxid)
);
CREATE TABLE IF NOT EXISTS warned_encrypted (room_id TEXT PRIMARY KEY);`,
}
// migrate brings the schema up to date. It:
//   - creates schema_version if needed;
//   - reads the highest recorded version;
//   - applies every later entry of migrations, each in its own transaction
//     together with its schema_version row.
//
// All of this runs on a single pooled connection that holds a session advisory
// lock (migrationLockKey), so concurrent starters apply steps one at a time. A
// database already at or past len(migrations) (e.g. written by a newer binary) is
// left untouched.
func (s *Store) migrate(ctx context.Context) error {
	conn, err := s.pool.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("migrate: acquire: %w", err)
	}
	defer conn.Release()
	if _, err := conn.Exec(ctx, `SELECT pg_advisory_lock($1)`, int64(migrationLockKey)); err != nil {
		return fmt.Errorf("migrate: lock: %w", err)
	}
	// Runs before conn.Release above (defers are LIFO).
	defer func() {
		// Unlock with a fresh deadline, not ctx. ctx is the caller's budget and may
		// already be spent by now, and pgx fails a query on a done context without
		// sending it. A session lock that is never released would ride this
		// connection back into the pool and stay held for the life of the process.
		unlockCtx, cancel := context.WithTimeout(context.Background(), opTimeout)
		defer cancel()
		if _, err := conn.Exec(unlockCtx, `SELECT pg_advisory_unlock($1)`, int64(migrationLockKey)); err != nil {
			// Last resort: ending the session drops its advisory locks, and
			// Release discards a closed connection instead of pooling it.
			_ = conn.Conn().Close(unlockCtx)
		}
	}()
	if _, err := conn.Exec(ctx, `CREATE TABLE IF NOT EXISTS schema_version (version INTEGER PRIMARY KEY)`); err != nil {
		return fmt.Errorf("migrate: schema_version: %w", err)
	}
	var current int
	if err := conn.QueryRow(ctx, `SELECT COALESCE(MAX(version), 0) FROM schema_version`).Scan(&current); err != nil {
		return fmt.Errorf("migrate: read version: %w", err)
	}
	for v := current; v < len(migrations); v++ {
		// The step and its version row commit together, so a crash mid-step
		// leaves the version unrecorded and the (idempotent) step re-runs.
		tx, err := conn.Begin(ctx)
		if err != nil {
			return fmt.Errorf("migrate: begin %d: %w", v+1, err)
		}
		if _, err := tx.Exec(ctx, migrations[v]); err != nil {
			_ = tx.Rollback(ctx)
			return fmt.Errorf("migrate: apply %d: %w", v+1, err)
		}
		if _, err := tx.Exec(ctx, `INSERT INTO schema_version (version) VALUES ($1)`, v+1); err != nil {
			_ = tx.Rollback(ctx)
			return fmt.Errorf("migrate: record %d: %w", v+1, err)
		}
		if err := tx.Commit(ctx); err != nil {
			return fmt.Errorf("migrate: commit %d: %w", v+1, err)
		}
	}
	return nil
}
// todayUTC returns the current calendar date in UTC formatted as "YYYY-MM-DD".
// It is the date key of the spend ledger's daily rows.
func todayUTC() string {
	const dayLayout = "2006-01-02"
	return time.Now().In(time.UTC).Format(dayLayout)
}
// opContext returns a context for one store operation. It is rooted in
// context.Background (not tied to any caller) and expires after opTimeout.
// Callers must defer the returned cancel.
func opContext() (context.Context, context.CancelFunc) {
	parent := context.Background()
	return context.WithTimeout(parent, opTimeout)
}
// HasTxn reports whether txnID is already recorded in processed_txn. Together with
// MarkTxn it gives appservice transactions idempotency across restarts: a
// transaction that Synapse retries (because our 200 was lost) is processed at most
// once. The table only keeps the most recent ids.
func (s *Store) HasTxn(txnID string) (bool, error) {
	ctx, cancel := opContext()
	defer cancel()
	var found int
	switch err := s.pool.QueryRow(ctx, `SELECT 1 FROM processed_txn WHERE txn_id = $1`, txnID).Scan(&found); {
	case err == nil:
		return true, nil
	case errors.Is(err, pgx.ErrNoRows):
		return false, nil
	default:
		return false, err
	}
}
// MarkTxn records txnID as processed and trims processed_txn back to the
// maxProcessedTxn most recent ids.
//
// The insert and the trim commit together in one transaction. A non-nil error
// therefore always means txnID was NOT recorded, and the caller may safely retry.
// The trim is skipped when txnID was already present: nothing was added, so the
// table did not grow.
func (s *Store) MarkTxn(txnID string) error {
	ctx, cancel := opContext()
	defer cancel()
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return err
	}
	// Undoes the insert on every error path; a no-op after a successful Commit.
	defer func() { _ = tx.Rollback(ctx) }()
	tag, err := tx.Exec(ctx,
		`INSERT INTO processed_txn (txn_id) VALUES ($1) ON CONFLICT DO NOTHING`, txnID)
	if err != nil {
		return err
	}
	if tag.RowsAffected() > 0 {
		if _, err := tx.Exec(ctx, `DELETE FROM processed_txn WHERE id NOT IN
(SELECT id FROM processed_txn ORDER BY id DESC LIMIT $1)`, maxProcessedTxn); err != nil {
			return err
		}
	}
	return tx.Commit(ctx)
}
// SeenEvent records an event id as handled and reports whether it was NEW (true)
// or already seen (false). It is the DURABLE equivalent of the in-memory dedup
// set: a crash/restart between handling an event and acking its transaction can't
// make the bot reprocess it (dup answer + double-bill + cap inflation). The table
// is bounded to the maxProcessedEvent most recent ids.
//
// INSERT … ON CONFLICT DO NOTHING affects 1 row on insert and 0 on conflict, so
// RowsAffected distinguishes new from already-seen.
//
// The insert and the LRU trim commit in ONE transaction. If the trim or the commit
// fails, the insert is rolled back too, so a non-nil error always means "nothing
// was recorded". A caller that retries the event after an error is therefore not
// told it was already seen, and the event is not silently dropped.
func (s *Store) SeenEvent(eventID string) (bool, error) {
	ctx, cancel := opContext()
	defer cancel()
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return false, err
	}
	// Undoes the insert on every error path; a no-op after a successful Commit.
	defer func() { _ = tx.Rollback(ctx) }()
	tag, err := tx.Exec(ctx,
		`INSERT INTO processed_event (event_id) VALUES ($1) ON CONFLICT DO NOTHING`, eventID)
	if err != nil {
		return false, err
	}
	if tag.RowsAffected() == 0 {
		return false, nil // already recorded → not new; nothing to commit
	}
	if _, err := tx.Exec(ctx, `DELETE FROM processed_event WHERE id NOT IN
(SELECT id FROM processed_event ORDER BY id DESC LIMIT $1)`, maxProcessedEvent); err != nil {
		return false, err
	}
	if err := tx.Commit(ctx); err != nil {
		return false, err
	}
	return true, nil
}
// SpentTodayUSD returns the total USD booked across all users for the current UTC
// day. With no rows for today, SUM yields NULL. That scans into a nil *float64 and
// is reported as 0.
func (s *Store) SpentTodayUSD() (float64, error) {
	ctx, cancel := opContext()
	defer cancel()
	var total *float64
	row := s.pool.QueryRow(ctx, `SELECT SUM(usd) FROM spend WHERE date = $1`, todayUTC())
	if err := row.Scan(&total); err != nil {
		return 0, err
	}
	if total != nil {
		return *total, nil
	}
	return 0, nil
}
// Reserve runs the two independent gates in one transaction, BEFORE the xAI call
// (F4): the global USD ceiling protects the wallet; the per-user request cap is
// anti-abuse. On success it increments the user's request count for today; the USD
// is booked after the response by Reconcile. Order: global first (cheapest to
// deny), then per-user.
//
// Results:
//   - reserveDeniedGlobal / reserveDeniedUser when a gate refuses (nothing is written);
//   - reserveOK with a nil error when the slot was taken;
//   - reserveOK with a non-nil error on any database failure, so callers must
//     check err first.
//
// A transaction-scoped advisory lock on (date, mxid) serializes concurrent
// reservations for the SAME user+day, so the per-user check-then-increment stays
// atomic. The former SQLite store got this for free (one connection serialized all
// callers). The pgx pool is concurrent, so without the lock the same user messaging
// from two rooms at once could slip past the per-user cap. Different users never
// contend, which also means the global ceiling is a soft bound: concurrent
// reservations by different users can each pass it.
//
// The global gate fails CLOSED. It admits only while spentToday < dailyUSDCeiling,
// so a NaN total (or a NaN ceiling) denies instead of admitting forever. Every
// comparison with NaN is false, so a `>=` test would never trigger.
func (s *Store) Reserve(mxid string, perUserCap int, dailyUSDCeiling float64) (reserveResult, error) {
	ctx, cancel := opContext()
	defer cancel()
	day := todayUTC()
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return reserveOK, err
	}
	// Rolls back every deny/error path, which also releases the advisory lock.
	// After a successful Commit it is a no-op whose error is irrelevant.
	defer func() { _ = tx.Rollback(ctx) }()
	// Key on date|mxid. The separator only needs to avoid cross-key ambiguity; a
	// hash collision would merely over-serialize two unrelated users, never corrupt a
	// count. (NUL is rejected by Postgres text, so use a printable separator.)
	if _, err := tx.Exec(ctx, `SELECT pg_advisory_xact_lock(hashtextextended($1, 0))`, day+"|"+mxid); err != nil {
		return reserveOK, err
	}
	// SUM over zero rows is NULL → nil pointer → treat as 0.0, exactly as the SQLite
	// store's sql.NullFloat64 did (and as SpentTodayUSD does). This keeps the gate 1:1
	// even at the degenerate dailyUSDCeiling == 0 (deny everything), where 0 < 0 fails.
	var global *float64
	if err := tx.QueryRow(ctx, `SELECT SUM(usd) FROM spend WHERE date = $1`, day).Scan(&global); err != nil {
		return reserveOK, err
	}
	spentToday := 0.0
	if global != nil {
		spentToday = *global
	}
	// Deliberately not `spentToday >= dailyUSDCeiling`: identical for ordinary
	// values, but this form also denies when either side is NaN.
	if !(spentToday < dailyUSDCeiling) {
		return reserveDeniedGlobal, nil
	}
	// A missing row (ErrNoRows) leaves requests at 0: the user hasn't reserved today.
	var requests int
	err = tx.QueryRow(ctx, `SELECT requests FROM spend WHERE date = $1 AND mxid = $2`, day, mxid).Scan(&requests)
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return reserveOK, err
	}
	if requests >= perUserCap {
		return reserveDeniedUser, nil
	}
	if _, err := tx.Exec(ctx,
		`INSERT INTO spend (date, mxid, requests, usd) VALUES ($1, $2, 1, 0)
ON CONFLICT (date, mxid) DO UPDATE SET requests = spend.requests + 1`,
		day, mxid); err != nil {
		return reserveOK, err
	}
	if err := tx.Commit(ctx); err != nil {
		return reserveOK, err
	}
	return reserveOK, nil
}
// RefundRequest hands back one reserved request slot for mxid on the current UTC
// day. It is for a call that ultimately failed (e.g. an xAI outage), so a transient
// failure doesn't eat into the user's daily cap. The count is clamped at zero.
// Being a single UPDATE, it is atomic by itself, so concurrent refunds settle
// correctly without extra locking.
//
// NOTE(review): the day is re-read here rather than passed in. A refund that lands
// after UTC midnight therefore decrements the NEW day's row (if one exists)
// instead of the row the slot was reserved on.
func (s *Store) RefundRequest(mxid string) error {
	ctx, cancel := opContext()
	defer cancel()
	const refund = `UPDATE spend SET requests = GREATEST(0, requests - 1) WHERE date = $1 AND mxid = $2`
	_, err := s.pool.Exec(ctx, refund, todayUTC(), mxid)
	return err
}
// Reconcile books the actual USD cost of a completed call against the user's
// daily row (and thus the global total). The accumulating upsert is atomic and
// commutative, so concurrent reconciles for the same user sum correctly.
//
// usd must be finite and non-negative; anything else is rejected with an error
// before the ledger is touched. A single NaN would make today's SUM(usd) NaN for
// the rest of the UTC day. Every comparison against a NaN total is false, so the
// global ceiling check could no longer work as intended. Negative or infinite
// amounts would likewise corrupt the day's total.
func (s *Store) Reconcile(mxid string, usd float64) error {
	// One range test rejects negatives, NaN (all NaN comparisons are false) and
	// ±Inf. maxFiniteUSD is math.MaxFloat64 written out.
	const maxFiniteUSD = 1.7976931348623157e308
	if !(usd >= 0 && usd <= maxFiniteUSD) {
		return fmt.Errorf("reconcile %q: invalid usd amount %v", mxid, usd)
	}
	ctx, cancel := opContext()
	defer cancel()
	_, err := s.pool.Exec(ctx,
		`INSERT INTO spend (date, mxid, requests, usd) VALUES ($1, $2, 0, $3)
ON CONFLICT (date, mxid) DO UPDATE SET usd = spend.usd + excluded.usd`,
		todayUTC(), mxid, usd)
	return err
}
// HasWarnedEncrypted reports whether roomID is in warned_encrypted, i.e. whether
// the bot already reacted 🔒 there because it can't read encryption.
// SetWarnedEncrypted records that flag. Persisting it means a restart doesn't
// re-react on every message (F5). The bot's own reaction can't loop back:
// m.reaction is not an m.room.message, so it never re-enters handleMessage.
func (s *Store) HasWarnedEncrypted(roomID string) (bool, error) {
	ctx, cancel := opContext()
	defer cancel()
	var found int
	switch err := s.pool.QueryRow(ctx, `SELECT 1 FROM warned_encrypted WHERE room_id = $1`, roomID).Scan(&found); {
	case err == nil:
		return true, nil
	case errors.Is(err, pgx.ErrNoRows):
		return false, nil
	default:
		return false, err
	}
}
// SetWarnedEncrypted adds roomID to warned_encrypted. An already-present room is
// left as-is (ON CONFLICT DO NOTHING), so calling it twice is harmless.
func (s *Store) SetWarnedEncrypted(roomID string) error {
	ctx, cancel := opContext()
	defer cancel()
	const markWarned = `INSERT INTO warned_encrypted (room_id) VALUES ($1) ON CONFLICT DO NOTHING`
	if _, err := s.pool.Exec(ctx, markWarned, roomID); err != nil {
		return err
	}
	return nil
}