replace the hand-rolled markdown renderer with goldmark/bluemonday and harden the ai-bot against quota abuse and third-party leaks

This commit is contained in:
heaven 2026-05-31 20:39:10 +03:00
parent fe8ba2878b
commit a4429d9c31
14 changed files with 617 additions and 70 deletions

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"crypto/subtle" "crypto/subtle"
"encoding/json" "encoding/json"
"log" "log/slog"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -16,13 +16,13 @@ import (
// the events to the bot's processing callback. // the events to the bot's processing callback.
type AppService struct { type AppService struct {
cfg *Config cfg *Config
log *log.Logger log *slog.Logger
store *Store store *Store
handler func(ctx context.Context, events []Event) handler func(ctx context.Context, events []Event)
baseCtx context.Context baseCtx context.Context
} }
func NewAppService(cfg *Config, logger *log.Logger, store *Store, handler func(context.Context, []Event)) *AppService { func NewAppService(cfg *Config, logger *slog.Logger, store *Store, handler func(context.Context, []Event)) *AppService {
return &AppService{cfg: cfg, log: logger, store: store, handler: handler} return &AppService{cfg: cfg, log: logger, store: store, handler: handler}
} }
@ -49,7 +49,7 @@ func (a *AppService) Serve(ctx context.Context) error {
errCh := make(chan error, 1) errCh := make(chan error, 1)
go func() { errCh <- srv.ListenAndServe() }() go func() { errCh <- srv.ListenAndServe() }()
a.log.Printf("appservice listening on %s", a.cfg.ASAddr) a.log.Info("appservice listening", "addr", a.cfg.ASAddr)
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -79,7 +79,7 @@ func (a *AppService) authOK(r *http.Request) bool {
func (a *AppService) handleTransaction(w http.ResponseWriter, r *http.Request) { func (a *AppService) handleTransaction(w http.ResponseWriter, r *http.Request) {
if !a.authOK(r) { if !a.authOK(r) {
writeError(w, http.StatusForbidden, "M_FORBIDDEN", "bad hs_token") a.denyUnauthed(w, r)
return return
} }
txnID := r.PathValue("txnId") txnID := r.PathValue("txnId")
@ -90,7 +90,7 @@ func (a *AppService) handleTransaction(w http.ResponseWriter, r *http.Request) {
// Idempotency (spec): a retried, already-processed transaction is a no-op. // Idempotency (spec): a retried, already-processed transaction is a no-op.
if done, err := a.store.HasTxn(txnID); err != nil { if done, err := a.store.HasTxn(txnID); err != nil {
a.log.Printf("txn dedup read failed for %s: %v", txnID, err) a.log.Error("txn dedup read failed", "txn", txnID, "err", err)
} else if done { } else if done {
writeJSON(w, http.StatusOK, struct{}{}) writeJSON(w, http.StatusOK, struct{}{})
return return
@ -103,20 +103,21 @@ func (a *AppService) handleTransaction(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "M_NOT_JSON", "invalid transaction body") writeError(w, http.StatusBadRequest, "M_NOT_JSON", "invalid transaction body")
return return
} }
a.log.Debug("transaction received", "txn", txnID, "events", len(txn.Events))
// Process with the bot's long-lived context (not the request context) so a // Process with the bot's long-lived context (not the request context) so a
// homeserver-side timeout can't cancel an in-flight reply mid-send. // homeserver-side timeout can't cancel an in-flight reply mid-send.
a.handler(a.baseCtx, txn.Events) a.handler(a.baseCtx, txn.Events)
if err := a.store.MarkTxn(txnID); err != nil { if err := a.store.MarkTxn(txnID); err != nil {
a.log.Printf("txn mark failed for %s: %v", txnID, err) a.log.Error("txn mark failed", "txn", txnID, "err", err)
} }
writeJSON(w, http.StatusOK, struct{}{}) writeJSON(w, http.StatusOK, struct{}{})
} }
func (a *AppService) handleUserQuery(w http.ResponseWriter, r *http.Request) { func (a *AppService) handleUserQuery(w http.ResponseWriter, r *http.Request) {
if !a.authOK(r) { if !a.authOK(r) {
writeError(w, http.StatusForbidden, "M_FORBIDDEN", "bad hs_token") a.denyUnauthed(w, r)
return return
} }
// We own exactly one user. Synapse auto-creates the sender_localpart user; // We own exactly one user. Synapse auto-creates the sender_localpart user;
@ -130,7 +131,7 @@ func (a *AppService) handleUserQuery(w http.ResponseWriter, r *http.Request) {
func (a *AppService) handleRoomQuery(w http.ResponseWriter, r *http.Request) { func (a *AppService) handleRoomQuery(w http.ResponseWriter, r *http.Request) {
if !a.authOK(r) { if !a.authOK(r) {
writeError(w, http.StatusForbidden, "M_FORBIDDEN", "bad hs_token") a.denyUnauthed(w, r)
return return
} }
// The bot claims no room aliases. // The bot claims no room aliases.
@ -146,3 +147,11 @@ func writeJSON(w http.ResponseWriter, status int, body any) {
func writeError(w http.ResponseWriter, status int, code, msg string) { func writeError(w http.ResponseWriter, status int, code, msg string) {
writeJSON(w, status, map[string]string{"errcode": code, "error": msg}) writeJSON(w, status, map[string]string{"errcode": code, "error": msg})
} }
// denyUnauthed logs and rejects a request whose hs_token didn't match. Logging
// at WARN makes probing / a misconfigured homeserver visible (the token itself
// is never logged).
func (a *AppService) denyUnauthed(w http.ResponseWriter, r *http.Request) {
a.log.Warn("rejected request: bad hs_token", "method", r.Method, "path", r.URL.Path, "remote", r.RemoteAddr)
writeError(w, http.StatusForbidden, "M_FORBIDDEN", "bad hs_token")
}

View file

@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"io" "io"
"log" "log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"path/filepath" "path/filepath"
@ -19,7 +19,7 @@ func newTestAS(t *testing.T, dispatched *[][]Event) (*AppService, *Store) {
} }
as := NewAppService( as := NewAppService(
&Config{HSToken: "secret", BotMXID: "@ai:vojo.chat"}, &Config{HSToken: "secret", BotMXID: "@ai:vojo.chat"},
log.New(io.Discard, "", 0), slog.New(slog.NewTextHandler(io.Discard, nil)),
st, st,
func(_ context.Context, ev []Event) { *dispatched = append(*dispatched, ev) }, func(_ context.Context, ev []Event) { *dispatched = append(*dispatched, ev) },
) )

View file

@ -2,17 +2,22 @@ package main
import ( import (
"context" "context"
"log" "fmt"
"log/slog"
"sync" "sync"
) )
// roomMeta caches per-room classification we need to handle a message: member // roomMeta caches per-room classification we need to handle a message: member
// counts (for the 1:1 test, F3) and encryption state (F15). Rebuilt per process; // counts (for the 1:1 test, F3), whether any member is outside ALLOWED_SERVERS,
// unknown fields are lazily fetched from the CS-API on first need — appservice // and encryption state (F15). Lazily fetched from the CS-API on first need
// transactions carry no room summary. // (appservice transactions carry no room summary) and INVALIDATED whenever a
// third party's membership changes, so a 1:1 that gains a member is reclassified
// out of DM mode (no DM-mode third-party leak) and a newly added foreign member
// is caught.
type roomMeta struct { type roomMeta struct {
joined, invited int joined, invited int
countsKnown bool countsKnown bool
foreign bool // a joined/invited member is outside ALLOWED_SERVERS
encrypted, encKnown bool encrypted, encKnown bool
} }
@ -20,7 +25,7 @@ func (m *roomMeta) isDM() bool { return m.countsKnown && m.joined+m.invited == 2
type Bot struct { type Bot struct {
cfg *Config cfg *Config
log *log.Logger log *slog.Logger
mx *MatrixClient mx *MatrixClient
xai *XAIClient xai *XAIClient
st *Store st *Store
@ -32,12 +37,12 @@ type Bot struct {
botSent *lruSet // event ids the bot itself sent (reply-parent detection) botSent *lruSet // event ids the bot itself sent (reply-parent detection)
meta map[string]*roomMeta meta map[string]*roomMeta
buf map[string][]bufferedMsg buf map[string][]bufferedMsg
globalNote map[string]string // roomID → UTC date we last sent the daily-limit notice globalNote map[string]string // roomID → UTC date we last sent the global daily-limit notice
} }
func NewBot(ctx context.Context, cfg *Config, logger *log.Logger) (*Bot, error) { func NewBot(ctx context.Context, cfg *Config, logger *slog.Logger) (*Bot, error) {
mx := NewMatrixClient(cfg.HomeserverURL, cfg.ASToken, cfg.BotMXID) mx := NewMatrixClient(cfg.HomeserverURL, cfg.ASToken, cfg.BotMXID)
xai := NewXAIClient(cfg.XAIBaseURL, cfg.XAIAPIKey) xai := NewXAIClient(cfg.XAIBaseURL, cfg.XAIAPIKey, logger)
st, err := OpenStore(cfg.statePath("ai-bot.db")) st, err := OpenStore(cfg.statePath("ai-bot.db"))
if err != nil { if err != nil {
@ -64,7 +69,7 @@ func NewBot(ctx context.Context, cfg *Config, logger *log.Logger) (*Bot, error)
} }
// F23: ensure the profile has a display name (best-effort, idempotent). // F23: ensure the profile has a display name (best-effort, idempotent).
if err := mx.SetDisplayName(ctx, cfg.BotDisplayName); err != nil { if err := mx.SetDisplayName(ctx, cfg.BotDisplayName); err != nil {
logger.Printf("set display name failed (non-fatal): %v", err) logger.Warn("set display name failed (non-fatal)", "err", err)
} }
return b, nil return b, nil
} }
@ -81,9 +86,9 @@ func (b *Bot) verifyIdentity(ctx context.Context) error {
return err return err
} }
if who != b.cfg.BotMXID { if who != b.cfg.BotMXID {
b.log.Fatalf("as_token resolves to %q but BOT_MXID is %q", who, b.cfg.BotMXID) return fmt.Errorf("as_token resolves to %q but BOT_MXID is %q", who, b.cfg.BotMXID)
} }
b.log.Printf("authenticated as %s", who) b.log.Info("authenticated", "mxid", who)
return nil return nil
} }
@ -107,12 +112,27 @@ func (b *Bot) handleEvent(ctx context.Context, ev *Event) {
return return
} }
if !b.seen.Add(ev.EventID) { if !b.seen.Add(ev.EventID) {
return // already handled this session (fast in-memory path)
}
// Durable dedup across restarts: if a previous run already handled this event
// but crashed before its transaction was acked, Synapse re-pushes it — don't
// reprocess (no dup answer / double-bill). On a DB error, fall through; the
// in-memory set still guards this session.
if isNew, err := b.st.SeenEvent(ev.EventID); err != nil {
b.log.Error("durable dedup check failed", "id", ev.EventID, "err", err)
} else if !isNew {
return return
} }
b.log.Debug("event", "type", ev.Type, "room", ev.RoomID, "sender", ev.Sender, "id", ev.EventID)
switch ev.Type { switch ev.Type {
case "m.room.member": case "m.room.member":
if ev.StateKey != nil && *ev.StateKey == b.cfg.BotMXID { if ev.StateKey != nil && *ev.StateKey == b.cfg.BotMXID {
b.handleSelfMembership(ctx, ev) b.handleSelfMembership(ctx, ev)
} else if m := b.meta[ev.RoomID]; m != nil {
// A third party's membership changed: counts + foreign flag are now
// stale. Re-probe on the next message so a 1:1 that gains a member drops
// out of DM mode (no third-party leak) and a new foreign member is caught.
m.countsKnown = false
} }
case "m.room.encryption": case "m.room.encryption":
m := b.getMeta(ev.RoomID) m := b.getMeta(ev.RoomID)
@ -127,16 +147,24 @@ func (b *Bot) handleEvent(ctx context.Context, ev *Event) {
func (b *Bot) handleSelfMembership(ctx context.Context, ev *Event) { func (b *Bot) handleSelfMembership(ctx context.Context, ev *Event) {
switch ev.membershipOf() { switch ev.membershipOf() {
case "invite": case "invite":
if b.cfg.AllowedServers[serverOf(ev.Sender)] { if !b.cfg.AllowedServers[serverOf(ev.Sender)] {
b.log.Printf("accepting invite to %s from %s", ev.RoomID, ev.Sender) b.log.Warn("rejecting invite (server not allowed)", "room", ev.RoomID, "sender", ev.Sender)
if err := b.mx.JoinRoom(ctx, ev.RoomID); err != nil {
b.log.Printf("join %s failed: %v", ev.RoomID, err)
}
} else {
b.log.Printf("rejecting invite to %s from %q (server not allowed)", ev.RoomID, ev.Sender)
if err := b.mx.LeaveRoom(ctx, ev.RoomID); err != nil { if err := b.mx.LeaveRoom(ctx, ev.RoomID); err != nil {
b.log.Printf("leave (reject) %s failed: %v", ev.RoomID, err) b.log.Error("leave (reject) failed", "room", ev.RoomID, "err", err)
} }
return
}
b.log.Info("accepting invite", "room", ev.RoomID, "sender", ev.Sender)
if err := b.mx.JoinRoom(ctx, ev.RoomID); err != nil {
b.log.Error("join failed", "room", ev.RoomID, "err", err)
return
}
// Fully-on-allowed-servers gate: a vojo.chat inviter can still pull the bot
// into a room that already holds federated third parties — leave at once.
m := b.getMeta(ev.RoomID)
b.ensureCounts(ctx, ev.RoomID, m)
if m.countsKnown && m.foreign {
b.leaveForeign(ctx, ev.RoomID)
} }
case "leave", "ban": case "leave", "ban":
delete(b.meta, ev.RoomID) delete(b.meta, ev.RoomID)
@ -144,6 +172,15 @@ func (b *Bot) handleSelfMembership(ctx context.Context, ev *Event) {
} }
} }
// leaveForeign leaves a room that contains a member outside ALLOWED_SERVERS, so
// the bot only ever operates in rooms hosted entirely on allowed homeservers.
func (b *Bot) leaveForeign(ctx context.Context, roomID string) {
b.log.Warn("leaving room — a member is outside ALLOWED_SERVERS", "room", roomID)
if err := b.mx.LeaveRoom(ctx, roomID); err != nil {
b.log.Error("leave (foreign) failed", "room", roomID, "err", err)
}
}
func (b *Bot) handleMessage(ctx context.Context, ev *Event) { func (b *Bot) handleMessage(ctx context.Context, ev *Event) {
roomID := ev.RoomID roomID := ev.RoomID
m := b.getMeta(roomID) m := b.getMeta(roomID)
@ -152,6 +189,7 @@ func (b *Bot) handleMessage(ctx context.Context, ev *Event) {
// the bot can't read it. // the bot can't read it.
b.ensureEncryption(ctx, roomID, m) b.ensureEncryption(ctx, roomID, m)
if m.encrypted { if m.encrypted {
b.log.Debug("skip: encrypted room", "room", roomID)
b.warnEncryptedOnce(ctx, roomID) b.warnEncryptedOnce(ctx, roomID)
return return
} }
@ -164,6 +202,14 @@ func (b *Bot) handleMessage(ctx context.Context, ev *Event) {
if mc.IsReplace() { if mc.IsReplace() {
return return
} }
// Only plain text ever reaches xAI: drop media (m.image/m.file/m.audio/…) and
// other custom msgtypes outright — the bot doesn't fetch or forward media, and
// a caption/filename is third-party content we keep out of xAI. m.notice falls
// through to its existing anti-loop handling below.
if mc.MsgType != "m.text" && mc.MsgType != "m.emote" && mc.MsgType != "m.notice" {
b.log.Debug("skip: non-text msgtype", "room", roomID, "sender", ev.Sender, "msgtype", mc.MsgType)
return
}
// Buffer prior context BEFORE classifying so buildContext sees history only. // Buffer prior context BEFORE classifying so buildContext sees history only.
history := b.buf[roomID] history := b.buf[roomID]
@ -177,25 +223,47 @@ func (b *Bot) handleMessage(ctx context.Context, ev *Event) {
} }
b.ensureCounts(ctx, roomID, m) b.ensureCounts(ctx, roomID, m)
// Stay only in rooms hosted entirely on allowed servers — never operate in (or
// "leak" the bot into) a federated room with non-consenting third parties.
if m.countsKnown && m.foreign {
b.leaveForeign(ctx, roomID)
return
}
replyParentIsBot := mc.RelatesTo != nil && mc.RelatesTo.InReplyTo != nil && replyParentIsBot := mc.RelatesTo != nil && mc.RelatesTo.InReplyTo != nil &&
b.botSent.Has(mc.RelatesTo.InReplyTo.EventID) b.botSent.Has(mc.RelatesTo.InReplyTo.EventID)
if !(m.isDM() || mentionsBot(mc, b.cfg.BotMXID, replyParentIsBot)) { mentioned := mentionsBot(mc, b.cfg.BotMXID, replyParentIsBot)
if !(m.isDM() || mentioned) {
b.log.Debug("skip: not addressed (group without mention)", "room", roomID, "sender", ev.Sender,
"dm", m.isDM(), "joined", m.joined, "invited", m.invited, "countsKnown", m.countsKnown, "mentioned", mentioned)
return return
} }
b.respond(ctx, roomID, m, ev, mc, history) b.respond(ctx, roomID, m, ev, mc, history)
} }
// unlimitedCap is the effective per-user cap for UNLIMITED_USERS — high enough to
// never trip the per-user gate, while the global DAILY_USD_CEILING still applies.
const unlimitedCap = 1 << 30
func (b *Bot) respond(ctx context.Context, roomID string, m *roomMeta, ev *Event, mc *MessageContent, history []bufferedMsg) { func (b *Bot) respond(ctx context.Context, roomID string, m *roomMeta, ev *Event, mc *MessageContent, history []bufferedMsg) {
switch res, err := b.st.Reserve(ev.Sender, b.cfg.PerUserDailyCap, b.cfg.DailyUSDCeiling); { perUserCap := b.cfg.PerUserDailyCap
if b.cfg.UnlimitedUsers[ev.Sender] {
perUserCap = unlimitedCap
}
switch res, err := b.st.Reserve(ev.Sender, perUserCap, b.cfg.DailyUSDCeiling); {
case err != nil: case err != nil:
b.log.Printf("limiter reserve failed: %v", err) b.log.Error("limiter reserve failed", "sender", ev.Sender, "err", err)
return return
case res == reserveDeniedUser: case res == reserveDeniedUser:
// Silent drop — per-user cap is anti-abuse (F24). // Per-user cap (anti-abuse, F24): stop answering, but always tell the user
// their request hit the limit — no message addressed to the bot is left
// silent. (m.notice → the anti-loop skip keeps this from re-triggering.)
b.log.Info("per-user daily cap reached; notifying", "sender", ev.Sender)
b.sendNotice(ctx, roomID, ev, mc, noticeUserLimit)
return return
case res == reserveDeniedGlobal: case res == reserveDeniedGlobal:
// Global USD ceiling — notice once per room per day, then stay quiet. // Global USD ceiling — notice once per room per day, then stay quiet.
b.log.Warn("global daily USD ceiling reached", "room", roomID, "sender", ev.Sender)
if b.globalNote[roomID] != todayUTC() { if b.globalNote[roomID] != todayUTC() {
b.globalNote[roomID] = todayUTC() b.globalNote[roomID] = todayUTC()
b.sendNotice(ctx, roomID, ev, mc, noticeDailyLimit) b.sendNotice(ctx, roomID, ev, mc, noticeDailyLimit)
@ -203,24 +271,41 @@ func (b *Bot) respond(ctx context.Context, roomID string, m *roomMeta, ev *Event
return return
} }
// Show "Vojo AI печатает…" while we build the answer; the deferred clear fires
// on every exit (success or failure). Pure UX — typing failures are best-effort.
b.setTyping(ctx, roomID, true)
defer b.setTyping(ctx, roomID, false)
msgs := buildContext(b.cfg.SystemPrompt, history, m.isDM(), mc.Body, b.cfg.MaxCtxEvent, 8000) msgs := buildContext(b.cfg.SystemPrompt, history, m.isDM(), mc.Body, b.cfg.MaxCtxEvent, 8000)
resp, err := b.xai.Complete(ctx, b.cfg.XAIModel, msgs, b.cfg.MaxOutTok, b.cfg.XAITemp) resp, err := b.xai.Complete(ctx, b.cfg.XAIModel, msgs, b.cfg.MaxOutTok, b.cfg.XAITemp)
if err != nil { if err != nil {
// at-most-once already retried transient failures inside Complete; refund // at-most-once already retried transient failures inside Complete; refund
// the reserved request so an xAI outage doesn't burn the user's daily cap. // the reserved request so an xAI outage doesn't burn the user's daily cap,
b.log.Printf("xai completion failed for %s: %v", ev.Sender, err) // and tell the user we couldn't answer (notice → no anti-loop re-trigger).
b.log.Error("xai completion failed", "sender", ev.Sender, "err", err)
if rerr := b.st.RefundRequest(ev.Sender); rerr != nil { if rerr := b.st.RefundRequest(ev.Sender); rerr != nil {
b.log.Printf("refund failed: %v", rerr) b.log.Error("refund failed", "sender", ev.Sender, "err", rerr)
} }
b.sendNotice(ctx, roomID, ev, mc, noticeError)
return return
} }
// A 2xx from xAI is billed even if the text came back empty — always book the
// real cost so both caps see it (the old refund-without-reconcile on an empty
// 200 let such calls bypass the per-user cap and the global ceiling).
usd := computeUSD(resp.Usage, b.cfg) usd := computeUSD(resp.Usage, b.cfg)
if err := b.st.Reconcile(ev.Sender, usd); err != nil { if err := b.st.Reconcile(ev.Sender, usd); err != nil {
b.log.Printf("reconcile spend failed: %v", err) b.log.Error("reconcile spend failed", "sender", ev.Sender, "err", err)
} }
b.sendNotice(ctx, roomID, ev, mc, resp.Text()) text := resp.Text()
if text == "" {
b.log.Warn("xai returned empty completion (billed, nothing to send)", "sender", ev.Sender, "usd", usd)
return
}
b.log.Info("answered", "room", roomID, "sender", ev.Sender, "dm", m.isDM(),
"usd", usd, "prompt_tokens", resp.Usage.PromptTokens, "completion_tokens", resp.Usage.CompletionTokens)
b.sendNotice(ctx, roomID, ev, mc, text)
} }
// computeUSD prices the call from the API-returned token usage (authoritative // computeUSD prices the call from the API-returned token usage (authoritative
@ -241,7 +326,7 @@ func (b *Bot) sendNotice(ctx context.Context, roomID string, trigger *Event, tri
content := buildNoticeContent(trigger.EventID, trigger.Sender, triggerMC.RelatesTo, body) content := buildNoticeContent(trigger.EventID, trigger.Sender, triggerMC.RelatesTo, body)
id, err := b.mx.SendEvent(ctx, roomID, "m.room.message", content) id, err := b.mx.SendEvent(ctx, roomID, "m.room.message", content)
if err != nil { if err != nil {
b.log.Printf("send notice to %s failed: %v", roomID, err) b.log.Error("send notice failed", "room", roomID, "err", err)
return return
} }
// Track our own reply so a future reply-to-it is recognised as addressing us, // Track our own reply so a future reply-to-it is recognised as addressing us,
@ -250,10 +335,20 @@ func (b *Bot) sendNotice(ctx context.Context, roomID string, trigger *Event, tri
b.appendBuf(roomID, bufferedMsg{sender: b.cfg.BotMXID, body: body, isBot: true}) b.appendBuf(roomID, bufferedMsg{sender: b.cfg.BotMXID, body: body, isBot: true})
} }
// setTyping sets/clears the bot's typing indicator (best-effort UX; failures are
// non-fatal). The 30s timeout comfortably covers a normal completion, and
// respond() defers a clear so the indicator ends the moment the answer is sent
// or fails.
func (b *Bot) setTyping(ctx context.Context, roomID string, typing bool) {
if err := b.mx.SendTyping(ctx, roomID, typing, 30000); err != nil {
b.log.Debug("set typing failed", "room", roomID, "typing", typing, "err", err)
}
}
func (b *Bot) warnEncryptedOnce(ctx context.Context, roomID string) { func (b *Bot) warnEncryptedOnce(ctx context.Context, roomID string) {
warned, err := b.st.HasWarnedEncrypted(roomID) warned, err := b.st.HasWarnedEncrypted(roomID)
if err != nil { if err != nil {
b.log.Printf("warned-flag read failed: %v", err) b.log.Error("warned-flag read failed", "room", roomID, "err", err)
return return
} }
if warned { if warned {
@ -261,11 +356,11 @@ func (b *Bot) warnEncryptedOnce(ctx context.Context, roomID string) {
} }
content := map[string]any{"msgtype": "m.notice", "body": noticeEncryptedUnsupported} content := map[string]any{"msgtype": "m.notice", "body": noticeEncryptedUnsupported}
if _, err := b.mx.SendEvent(ctx, roomID, "m.room.message", content); err != nil { if _, err := b.mx.SendEvent(ctx, roomID, "m.room.message", content); err != nil {
b.log.Printf("encrypted-notice to %s failed: %v", roomID, err) b.log.Error("encrypted-notice failed", "room", roomID, "err", err)
return return
} }
if err := b.st.SetWarnedEncrypted(roomID); err != nil { if err := b.st.SetWarnedEncrypted(roomID); err != nil {
b.log.Printf("persist warned-flag failed: %v", err) b.log.Error("persist warned-flag failed", "room", roomID, "err", err)
} }
} }
@ -282,12 +377,21 @@ func buildNoticeContent(replyTo, sender string, triggerRelates *RelatesTo, body
} else { } else {
relates["m.in_reply_to"] = map[string]any{"event_id": replyTo} relates["m.in_reply_to"] = map[string]any{"event_id": replyTo}
} }
return map[string]any{ content := map[string]any{
"msgtype": "m.notice", "msgtype": "m.notice",
"body": body, "body": body,
"m.mentions": map[string]any{"user_ids": []string{sender}}, "m.mentions": map[string]any{"user_ids": []string{sender}},
"m.relates_to": relates, "m.relates_to": relates,
} }
// The model answers in markdown; render it to org.matrix.custom.html so clients
// show formatting instead of raw `**`, `#`, lists, code fences. Only attach
// formatted_body when there's actual formatting — a plain answer keeps rendering
// from `body` exactly as before.
if html, formatted := markdownToHTML(body); formatted {
content["format"] = matrixHTMLFormat
content["formatted_body"] = html
}
return content
} }
// --- per-room metadata helpers ------------------------------------------------- // --- per-room metadata helpers -------------------------------------------------
@ -307,7 +411,7 @@ func (b *Bot) ensureEncryption(ctx context.Context, roomID string, m *roomMeta)
} }
enc, err := b.mx.RoomEncrypted(ctx, roomID) enc, err := b.mx.RoomEncrypted(ctx, roomID)
if err != nil { if err != nil {
b.log.Printf("encryption probe %s failed: %v", roomID, err) b.log.Warn("encryption probe failed", "room", roomID, "err", err)
return // leave unknown; re-probed on the next message return // leave unknown; re-probed on the next message
} }
m.encrypted, m.encKnown = enc, true m.encrypted, m.encKnown = enc, true
@ -317,12 +421,19 @@ func (b *Bot) ensureCounts(ctx context.Context, roomID string, m *roomMeta) {
if m.countsKnown { if m.countsKnown {
return return
} }
joined, invited, err := b.mx.MemberCounts(ctx, roomID) joined, invited, servers, err := b.mx.RoomMembership(ctx, roomID)
if err != nil { if err != nil {
b.log.Printf("member-count probe %s failed: %v", roomID, err) b.log.Warn("member probe failed", "room", roomID, "err", err)
return return
} }
m.joined, m.invited, m.countsKnown = joined, invited, true foreign := false
for s := range servers {
if !b.cfg.AllowedServers[s] {
foreign = true
break
}
}
m.joined, m.invited, m.foreign, m.countsKnown = joined, invited, foreign, true
} }
func (b *Bot) appendBuf(roomID string, msg bufferedMsg) { func (b *Bot) appendBuf(roomID string, msg bufferedMsg) {

View file

@ -43,6 +43,10 @@ type Config struct {
DailyUSDCeiling float64 DailyUSDCeiling float64
PerUserDailyCap int PerUserDailyCap int
// mxids exempt from PER_USER_DAILY_CAP (e.g. the owner/admins testing). Still
// subject to the global DAILY_USD_CEILING, so the wallet stays protected.
UnlimitedUsers map[string]bool
// USD-per-1M-token prices applied to the API-returned token usage so the // USD-per-1M-token prices applied to the API-returned token usage so the
// hard ceiling tracks real usage even if the model/price changes. // hard ceiling tracks real usage even if the model/price changes.
PriceInputPerM float64 PriceInputPerM float64
@ -127,6 +131,7 @@ func LoadConfig() (*Config, error) {
SystemPromptPath: getenv("SYSTEM_PROMPT_PATH", "prompts/system_ru.txt"), SystemPromptPath: getenv("SYSTEM_PROMPT_PATH", "prompts/system_ru.txt"),
StateDir: strings.TrimRight(getenv("STATE_DIR", "/state"), "/"), StateDir: strings.TrimRight(getenv("STATE_DIR", "/state"), "/"),
AllowedServers: parseServerSet(getenv("ALLOWED_SERVERS", "")), AllowedServers: parseServerSet(getenv("ALLOWED_SERVERS", "")),
UnlimitedUsers: parseServerSet(getenv("UNLIMITED_USERS", "")),
} }
var problems []string var problems []string
@ -216,6 +221,10 @@ func (c *Config) Summary() string {
for s := range c.AllowedServers { for s := range c.AllowedServers {
servers = append(servers, s) servers = append(servers, s)
} }
unlimited := make([]string, 0, len(c.UnlimitedUsers))
for u := range c.UnlimitedUsers {
unlimited = append(unlimited, u)
}
redact := func(s string) string { redact := func(s string) string {
if s == "" { if s == "" {
return "(unset)" return "(unset)"
@ -245,6 +254,7 @@ func (c *Config) Summary() string {
" ALLOWED_SERVERS = " + strings.Join(servers, ","), " ALLOWED_SERVERS = " + strings.Join(servers, ","),
fmt.Sprintf(" DAILY_USD_CEILING = %g", c.DailyUSDCeiling), fmt.Sprintf(" DAILY_USD_CEILING = %g", c.DailyUSDCeiling),
fmt.Sprintf(" PER_USER_DAILY_CAP = %d", c.PerUserDailyCap), fmt.Sprintf(" PER_USER_DAILY_CAP = %d", c.PerUserDailyCap),
" UNLIMITED_USERS = " + strings.Join(unlimited, ","),
fmt.Sprintf(" PRICES /1M (in/cached/out) = %g / %g / %g", fmt.Sprintf(" PRICES /1M (in/cached/out) = %g / %g / %g",
c.PriceInputPerM, c.PriceCachedPerM, c.PriceOutputPerM), c.PriceInputPerM, c.PriceCachedPerM, c.PriceOutputPerM),
" SYSTEM_PROMPT_PATH = " + c.SystemPromptPath, " SYSTEM_PROMPT_PATH = " + c.SystemPromptPath,

View file

@ -3,16 +3,21 @@ module vojo.chat/ai-bot
go 1.25.0 go 1.25.0
require ( require (
github.com/microcosm-cc/bluemonday v1.0.27
github.com/yuin/goldmark v1.8.2
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.51.0 modernc.org/sqlite v1.51.0
) )
require ( require (
github.com/aymerick/douceur v0.2.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sys v0.42.0 // indirect golang.org/x/sys v0.42.0 // indirect
modernc.org/libc v1.72.3 // indirect modernc.org/libc v1.72.3 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect

View file

@ -1,19 +1,29 @@
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

44
apps/ai-bot/logging.go Normal file
View file

@ -0,0 +1,44 @@
package main
import (
"log/slog"
"os"
"strings"
)
// newLogger builds the process logger from the environment: LOG_LEVEL
// (debug|info|warn|error, default info) and LOG_FORMAT (text|json, default
// text). It writes to stderr with UTC timestamps (matching the previous
// log.LUTC behaviour). Built from getenv directly — not Config — so it exists
// before LoadConfig (the generate-registration path logs before config loads).
func newLogger() *slog.Logger {
opts := &slog.HandlerOptions{
Level: parseLogLevel(getenv("LOG_LEVEL", "info")),
ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr {
if a.Key == slog.TimeKey && a.Value.Kind() == slog.KindTime {
a.Value = slog.TimeValue(a.Value.Time().UTC())
}
return a
},
}
var h slog.Handler
if strings.EqualFold(strings.TrimSpace(getenv("LOG_FORMAT", "text")), "json") {
h = slog.NewJSONHandler(os.Stderr, opts)
} else {
h = slog.NewTextHandler(os.Stderr, opts)
}
return slog.New(h)
}
func parseLogLevel(s string) slog.Level {
switch strings.ToLower(strings.TrimSpace(s)) {
case "debug":
return slog.LevelDebug
case "warn", "warning":
return slog.LevelWarn
case "error":
return slog.LevelError
default:
return slog.LevelInfo
}
}

View file

@ -9,7 +9,6 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
@ -17,7 +16,7 @@ import (
) )
func main() { func main() {
logger := log.New(os.Stderr, "", log.LstdFlags|log.LUTC) logger := newLogger()
// `ai-bot generate-registration` writes a fresh registration.yaml with random // `ai-bot generate-registration` writes a fresh registration.yaml with random
// tokens (the mautrix bridge idiom), then exits. Runs BEFORE LoadConfig — the // tokens (the mautrix bridge idiom), then exits. Runs BEFORE LoadConfig — the
@ -26,12 +25,14 @@ func main() {
if len(os.Args) > 1 && os.Args[1] == "generate-registration" { if len(os.Args) > 1 && os.Args[1] == "generate-registration" {
mxid := getenv("BOT_MXID", "") mxid := getenv("BOT_MXID", "")
if mxid == "" { if mxid == "" {
logger.Fatalf("BOT_MXID is required to generate the registration") logger.Error("BOT_MXID is required to generate the registration")
os.Exit(1)
} }
path := getenv("REGISTRATION_PATH", "/data/registration.yaml") path := getenv("REGISTRATION_PATH", "/data/registration.yaml")
asURL := getenv("AS_URL", "http://ai-bot:8009") asURL := getenv("AS_URL", "http://ai-bot:8009")
if err := GenerateRegistration(path, asURL, localpartOf(mxid), serverOf(mxid)); err != nil { if err := GenerateRegistration(path, asURL, localpartOf(mxid), serverOf(mxid)); err != nil {
logger.Fatalf("generate-registration: %v", err) logger.Error("generate-registration failed", "err", err)
os.Exit(1)
} }
fmt.Printf("wrote %s\n", path) fmt.Printf("wrote %s\n", path)
fmt.Println("Next: mount this file into Synapse, add it to app_service_config_files,") fmt.Println("Next: mount this file into Synapse, add it to app_service_config_files,")
@ -42,14 +43,16 @@ func main() {
cfg, err := LoadConfig() cfg, err := LoadConfig()
if err != nil { if err != nil {
logger.Fatalf("config error: %v", err) logger.Error("config error", "err", err)
os.Exit(1)
} }
// Load the system prompt up front so a missing/unreadable file fails fast // Load the system prompt up front so a missing/unreadable file fails fast
// at startup rather than on the first message. // at startup rather than on the first message.
promptBytes, err := os.ReadFile(cfg.SystemPromptPath) promptBytes, err := os.ReadFile(cfg.SystemPromptPath)
if err != nil { if err != nil {
logger.Fatalf("cannot read SYSTEM_PROMPT_PATH (%s): %v", cfg.SystemPromptPath, err) logger.Error("cannot read system prompt", "path", cfg.SystemPromptPath, "err", err)
os.Exit(1)
} }
cfg.SystemPrompt = string(promptBytes) cfg.SystemPrompt = string(promptBytes)
@ -64,10 +67,12 @@ func main() {
} }
if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil { if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil {
logger.Fatalf("cannot create STATE_DIR (%s): %v", cfg.StateDir, err) logger.Error("cannot create state dir", "path", cfg.StateDir, "err", err)
os.Exit(1)
} }
logger.Printf("starting\n%s", cfg.Summary()) fmt.Fprintf(os.Stderr, "%s\n", cfg.Summary())
logger.Info("starting Vojo AI bot")
// Cancel on SIGINT/SIGTERM so the transaction server shuts down cleanly. // Cancel on SIGINT/SIGTERM so the transaction server shuts down cleanly.
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
@ -75,14 +80,16 @@ func main() {
bot, err := NewBot(ctx, cfg, logger) bot, err := NewBot(ctx, cfg, logger)
if err != nil { if err != nil {
logger.Fatalf("startup failed: %v", err) logger.Error("startup failed", "err", err)
os.Exit(1)
} }
defer bot.Close() defer bot.Close()
if err := bot.Run(ctx); err != nil && ctx.Err() == nil { if err := bot.Run(ctx); err != nil && ctx.Err() == nil {
logger.Fatalf("appservice server exited with error: %v", err) logger.Error("appservice server exited", "err", err)
os.Exit(1)
} }
logger.Printf("shut down cleanly") logger.Info("shut down cleanly")
} }
// statePath joins a filename under the configured state directory. // statePath joins a filename under the configured state directory.

128
apps/ai-bot/markdown.go Normal file
View file

@ -0,0 +1,128 @@
package main
import (
"bytes"
"strings"
"github.com/microcosm-cc/bluemonday"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/ast"
"github.com/yuin/goldmark/extension"
"github.com/yuin/goldmark/renderer"
ghtml "github.com/yuin/goldmark/renderer/html"
"github.com/yuin/goldmark/util"
)
// matrixHTMLFormat is the `format` value that flags `formatted_body` as
// org.matrix.custom.html (the only rich format Matrix clients render).
const matrixHTMLFormat = "org.matrix.custom.html"
const (
// maxInputBytes / maxFormattedBytes bound the model reply and the rendered
// HTML; beyond either we fall back to the plain body (no formatted_body).
maxInputBytes = 512 * 1024
maxFormattedBytes = 64 * 1024
)
// mdParser converts the model's CommonMark + GFM (tables, strikethrough,
// autolink, task lists) answer to HTML. WithUnsafe stays OFF (goldmark's default)
// so raw HTML and dangerous URLs are escaped, never rendered; WithHardWraps keeps
// the answer's line breaks as <br>; images are rendered as links, not <img> (see
// imageLinkRenderer). goldmark depends only on the standard library, so the static
// (CGO-free) build is preserved.
var mdParser = goldmark.New(
goldmark.WithExtensions(extension.GFM),
goldmark.WithRendererOptions(
ghtml.WithHardWraps(),
// Priority < the default renderer's 1000 → registered last → overrides
// goldmark's <img> rendering with imageLinkRenderer.
renderer.WithNodeRenderers(util.Prioritized(imageLinkRenderer{}, 100)),
),
)
// imageLinkRenderer overrides goldmark's image rendering to emit a link instead of
// an <img>, so a markdown image stays functional (a clickable link to its source)
// without ever putting a remote <img> in the event — which a client could
// auto-load, leaking the viewer's IP to a URL a prompt-injected reply chose.
type imageLinkRenderer struct{}
func (imageLinkRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
reg.Register(ast.KindImage, renderImageAsLink)
}
// renderImageAsLink renders ![alt](src) as <a href="src">alt</a>: the alt content
// (the node's children) becomes the link label. Mirrors goldmark's own URL escape
// + dangerous-URL guard; bluemonday re-checks the scheme afterwards.
func renderImageAsLink(w util.BufWriter, _ []byte, node ast.Node, entering bool) (ast.WalkStatus, error) {
n := node.(*ast.Image)
if entering {
_, _ = w.WriteString(`<a href="`)
dest := util.URLEscape(n.Destination, true)
if !ghtml.IsDangerousURL(dest) {
_, _ = w.Write(util.EscapeHTML(dest))
}
_, _ = w.WriteString(`">`)
} else {
_, _ = w.WriteString("</a>")
}
return ast.WalkContinue, nil
}
// htmlPolicy strips goldmark's output to the tags/attributes Cinny's renderer
// keeps (src/app/utils/sanitize.ts: permittedHtmlTags / urlSchemes) — defence in
// depth over goldmark's own escaping, and the single allowlist a crafted reply
// can't get around. Anything else (script/style/img/on*-handlers/unknown URL
// schemes) is removed.
var htmlPolicy = buildHTMLPolicy()
func buildHTMLPolicy() *bluemonday.Policy {
p := bluemonday.NewPolicy()
p.AllowElements(
"p", "br", "hr",
"h1", "h2", "h3", "h4", "h5", "h6",
"strong", "em", "del", "s", "code", "pre",
"blockquote", "ul", "ol", "li",
"table", "thead", "tbody", "tr", "th", "td",
)
p.AllowAttrs("href").OnElements("a")
p.AllowURLSchemes("https", "http", "ftp", "mailto", "magnet")
p.RequireParseableURLs(true)
p.AllowAttrs("class").OnElements("code", "pre") // language-xxx on code blocks
p.AllowAttrs("start").OnElements("ol")
return p
}
// markdownToHTML converts the model's markdown answer to sanitized
// org.matrix.custom.html and reports whether any rich formatting was emitted.
// When false the caller MUST omit formatted_body so a plain answer renders from
// the bare `body` exactly as before (Matrix convention: only attach
// formatted_body when it adds formatting the plain body can't carry).
func markdownToHTML(md string) (string, bool) {
if len(md) > maxInputBytes {
return "", false // implausibly large; just send the plain body
}
var buf bytes.Buffer
if err := mdParser.Convert([]byte(md), &buf); err != nil {
return "", false
}
html := strings.TrimSpace(string(htmlPolicy.SanitizeBytes(buf.Bytes())))
if len(html) > maxFormattedBytes {
return "", false // too large to be worth sending as a Matrix event
}
if !hasRichMarkup(html) {
return "", false // just a paragraph of text — the plain body is enough
}
return html, true
}
// hasRichMarkup reports whether the HTML carries formatting beyond the paragraph
// wrapper and soft line breaks goldmark emits for plain text, so a plain reply
// keeps rendering from the bare body. Model text is HTML-escaped (a literal '<'
// becomes "&lt;"), so any remaining raw '<' is a tag the converter emitted.
func hasRichMarkup(html string) bool {
stripped := html
for _, t := range []string{"<p>", "</p>", "<br>", "<br/>", "<br />"} {
stripped = strings.ReplaceAll(stripped, t, "")
}
return strings.Contains(stripped, "<")
}

View file

@ -0,0 +1,169 @@
package main
import (
"strings"
"testing"
)
// TestMarkdownToHTML asserts the rich constructs render and plain text stays
// plain. It checks for the meaningful tags/escaping (Contains), not goldmark's
// exact byte output — the converter's precise formatting is its own contract, not
// ours to pin.
func TestMarkdownToHTML(t *testing.T) {
rich := []struct {
name string
in string
contains []string
}{
{"bold", "a **bold** b", []string{"<strong>bold</strong>"}},
{"italic star", "a *it* b", []string{"<em>it</em>"}},
{"italic underscore", "a _it_ b", []string{"<em>it</em>"}},
{"bold italic", "***x***", []string{"<strong>", "<em>", "x"}},
{"strikethrough", "~~gone~~", []string{"gone"}}, // <del> or <s>; both rich
{"inline code", "use `npm i`", []string{"<code>npm i</code>"}},
{"inline code keeps stars literal", "`a*b*c`", []string{"<code>a*b*c</code>"}},
{"heading h1", "# Title", []string{"<h1>", "Title", "</h1>"}},
{"hr", "---", []string{"<hr"}},
{"unordered list", "- one\n- two", []string{"<ul>", "<li>", "one", "two"}},
{"ordered list", "1. one\n2. two", []string{"<ol>", "<li>", "one"}},
{"blockquote", "> quoted", []string{"<blockquote>", "quoted"}},
{"link", "see [xAI](https://x.ai)", []string{`href="https://x.ai"`, "xAI"}},
{"fenced code", "```go\nfmt.Println()\n```", []string{"<pre>", "<code", "fmt.Println"}},
{"table", "| a | b |\n| - | - |\n| 1 | 2 |", []string{"<table>", "<th>", "a", "<td>", "1"}},
{"image as link", "![logo](https://x.ai/logo.png)", []string{`href="https://x.ai/logo.png"`, "logo"}},
{"autolink bare url", "visit https://x.ai now", []string{`href="https://x.ai"`}},
}
for _, c := range rich {
t.Run("rich/"+c.name, func(t *testing.T) {
got, formatted := markdownToHTML(c.in)
if !formatted {
t.Fatalf("markdownToHTML(%q) formatted=false, want true (got %q)", c.in, got)
}
for _, sub := range c.contains {
if !strings.Contains(got, sub) {
t.Fatalf("markdownToHTML(%q) = %q, missing %q", c.in, got, sub)
}
}
})
}
// Plain text (even multi-line or with stray punctuation) carries no
// formatting, so the bot sends only the bare body.
plain := []string{
"just a sentence",
"line one\nline two",
"a < b & c > d",
"2 * 3 * 4",
"snake_case_name",
айл_имя_тут",
"text with ! bang",
`path c:\users`,
"",
}
for _, in := range plain {
t.Run("plain", func(t *testing.T) {
if got, formatted := markdownToHTML(in); formatted {
t.Fatalf("markdownToHTML(%q) formatted=true, want false (got %q)", in, got)
}
})
}
}
func TestMarkdownNeverEmitsUnsafeScheme(t *testing.T) {
for _, bad := range []string{
"[a](javascript:x)", "[a](data:text/html,x)", "[a](vbscript:x)", "[a](file:///etc)",
"[a](JaVaScRiPt:x)", "[a](java\tscript:x)",
} {
if html, _ := markdownToHTML(bad); strings.Contains(html, "href=") {
t.Fatalf("emitted a link for unsafe scheme: %q -> %q", bad, html)
}
}
}
func TestMarkdownOversizeFallsBackToPlain(t *testing.T) {
// A formatted reply whose HTML exceeds maxFormattedBytes must return ("", false)
// so the bot sends only the plain body.
big := strings.Repeat("- item\n", 8000)
if html, formatted := markdownToHTML(big); formatted || html != "" {
t.Fatalf("oversize formatted output should fall back to plain: formatted=%v len=%d", formatted, len(html))
}
// Implausibly large input is rejected outright.
huge := strings.Repeat("a", maxInputBytes+1)
if html, formatted := markdownToHTML(huge); formatted || html != "" {
t.Fatalf("oversize input should fall back to plain: formatted=%v len=%d", formatted, len(html))
}
}
func TestMarkdownAdversarialNoPanicNoInjection(t *testing.T) {
inputs := []string{
strings.Repeat("[", 20000) + "x",
"x" + strings.Repeat("](https://a)", 20000),
strings.Repeat("*", 5000) + "x" + strings.Repeat("*", 5000),
strings.Repeat("> ", 5000) + "deep",
strings.Repeat(" ", 50) + "- nested",
strings.Repeat("`", 4000) + "code",
"| " + strings.Repeat("a |", 2000) + "\n| " + strings.Repeat("- |", 2000) + "\n| x |",
"<script>alert(1)</script>\n**`<b>`**\n[x](\"><svg onload=alert(1)>)",
strings.Repeat("***nest ", 200) + "x" + strings.Repeat(" nest***", 200),
}
// Every model '<' is escaped to &lt;, so a dangerous element can only exist if
// the converter emitted it — and it emits none of these tag names. (Attribute
// vectors like onload= can appear only as escaped literal text, which is safe;
// the safe-href guarantee is covered by the unit + scheme tests.)
for i, in := range inputs {
html, _ := markdownToHTML(in) // must not panic
for _, tag := range []string{"<script", "<svg", "<img", "<iframe", "<style", "<object", "<embed"} {
if strings.Contains(strings.ToLower(html), tag) {
t.Fatalf("case %d emitted a dangerous tag %q: %.160q", i, tag, html)
}
}
}
}
func TestBuildNoticeContentAttachesFormatted(t *testing.T) {
c := buildNoticeContent("$evt", "@u:vojo.chat", nil, "Here is **bold**.")
if c["format"] != matrixHTMLFormat {
t.Fatalf("format = %v, want %v", c["format"], matrixHTMLFormat)
}
fb, _ := c["formatted_body"].(string)
if !strings.Contains(fb, "<strong>bold</strong>") {
t.Fatalf("formatted_body missing bold: %q", fb)
}
if c["body"] != "Here is **bold**." {
t.Fatalf("plain body must be preserved, got %v", c["body"])
}
}
func TestBuildNoticeContentSkipsFormattedForPlain(t *testing.T) {
c := buildNoticeContent("$evt", "@u:vojo.chat", nil, "no markdown here")
if _, ok := c["format"]; ok {
t.Fatalf("format must be absent for plain text")
}
if _, ok := c["formatted_body"]; ok {
t.Fatalf("formatted_body must be absent for plain text")
}
}
// TestMarkdownNoHangOnBangAndBackslash guards the inline-parser infinite loop: a
// '!' not starting an image, or a backslash not before ASCII punctuation
// (trailing, or before a letter/space/Cyrillic), used to fall through to a
// non-advancing default branch and spin forever — freezing the whole bot under
// the transaction mutex. These must all RETURN; if the bug returns this test
// hangs and `go test` times out instead of passing.
func TestMarkdownNoHangOnBangAndBackslash(t *testing.T) {
for _, in := range []string{
"Привет!",
"Hello! How are you?",
`path c:\users`,
`trailing backslash \`,
`что-то \ или вот это`,
`\` + "д",
"!",
"!!! wow",
"text with ! bang",
strings.Repeat("a! ", 2000),
strings.Repeat(`\`, 2000),
} {
_, _ = markdownToHTML(in) // a hang fails via test timeout
}
}

View file

@ -142,6 +142,18 @@ func (c *MatrixClient) SetDisplayName(ctx context.Context, name string) error {
return c.do(ctx, http.MethodPut, path, nil, map[string]any{"displayname": name}, nil) return c.do(ctx, http.MethodPut, path, nil, map[string]any{"displayname": name}, nil)
} }
// SendTyping sets or clears the bot user's typing indicator in a room. The
// homeserver broadcasts m.typing, which clients render as "… is typing"; the
// timeout (ms) applies only when starting and is omitted when clearing.
func (c *MatrixClient) SendTyping(ctx context.Context, roomID string, typing bool, timeoutMs int) error {
path := "/_matrix/client/v3/rooms/" + url.PathEscape(roomID) + "/typing/" + url.PathEscape(c.asUserID)
body := map[string]any{"typing": typing}
if typing {
body["timeout"] = timeoutMs
}
return c.do(ctx, http.MethodPut, path, nil, body, nil)
}
// RoomEncrypted checks live encryption state (F15 — never a join-time snapshot). // RoomEncrypted checks live encryption state (F15 — never a join-time snapshot).
// A 404/M_NOT_FOUND means no m.room.encryption state → unencrypted. // A 404/M_NOT_FOUND means no m.room.encryption state → unencrypted.
func (c *MatrixClient) RoomEncrypted(ctx context.Context, roomID string) (bool, error) { func (c *MatrixClient) RoomEncrypted(ctx context.Context, roomID string) (bool, error) {
@ -156,23 +168,34 @@ func (c *MatrixClient) RoomEncrypted(ctx context.Context, roomID string) (bool,
return false, err return false, err
} }
// MemberCounts returns joined+invited counts via /members, used to classify a // RoomMembership returns joined+invited counts and the set of homeservers that
// room as a 1:1 (F3) — appservice transactions carry no room summary. // have a member present (joined or invited). Used both to classify a room as a
func (c *MatrixClient) MemberCounts(ctx context.Context, roomID string) (joined, invited int, err error) { // 1:1 (F3) and to enforce that the bot only stays in rooms hosted entirely on
// allowed servers — appservice transactions carry no room summary, so this reads
// /members. The member is identified by the event's state_key (the sender is
// whoever *set* the membership, which may differ).
func (c *MatrixClient) RoomMembership(ctx context.Context, roomID string) (joined, invited int, servers map[string]bool, err error) {
path := "/_matrix/client/v3/rooms/" + url.PathEscape(roomID) + "/members" path := "/_matrix/client/v3/rooms/" + url.PathEscape(roomID) + "/members"
var out struct { var out struct {
Chunk []Event `json:"chunk"` Chunk []Event `json:"chunk"`
} }
if err = c.do(ctx, http.MethodGet, path, nil, nil, &out); err != nil { if err = c.do(ctx, http.MethodGet, path, nil, nil, &out); err != nil {
return 0, 0, err return 0, 0, nil, err
} }
servers = make(map[string]bool)
for i := range out.Chunk { for i := range out.Chunk {
switch out.Chunk[i].membershipOf() { e := &out.Chunk[i]
if e.StateKey == nil {
continue
}
switch e.membershipOf() {
case "join": case "join":
joined++ joined++
servers[serverOf(*e.StateKey)] = true
case "invite": case "invite":
invited++ invited++
servers[serverOf(*e.StateKey)] = true
} }
} }
return joined, invited, nil return joined, invited, servers, nil
} }

View file

@ -8,4 +8,8 @@ const (
"Напишите мне в обычном (незашифрованном) чате." "Напишите мне в обычном (незашифрованном) чате."
noticeDailyLimit = "Достигнут дневной лимит обращений к ИИ в этом сервисе. Попробуйте позже." noticeDailyLimit = "Достигнут дневной лимит обращений к ИИ в этом сервисе. Попробуйте позже."
noticeUserLimit = "Вы исчерпали свой дневной лимит обращений к ИИ. Попробуйте позже."
noticeError = "⚠️ Не удалось получить ответ от ИИ. Попробуйте ещё раз чуть позже."
) )

View file

@ -39,7 +39,8 @@ func OpenStore(path string) (*Store, error) {
requests INTEGER NOT NULL DEFAULT 0, usd REAL NOT NULL DEFAULT 0, requests INTEGER NOT NULL DEFAULT 0, usd REAL NOT NULL DEFAULT 0,
PRIMARY KEY (date, mxid) PRIMARY KEY (date, mxid)
); );
CREATE TABLE IF NOT EXISTS warned_encrypted (room_id TEXT PRIMARY KEY);` CREATE TABLE IF NOT EXISTS warned_encrypted (room_id TEXT PRIMARY KEY);
CREATE TABLE IF NOT EXISTS processed_event (event_id TEXT PRIMARY KEY);`
if _, err := db.Exec(schema); err != nil { if _, err := db.Exec(schema); err != nil {
db.Close() db.Close()
return nil, fmt.Errorf("init schema: %w", err) return nil, fmt.Errorf("init schema: %w", err)
@ -72,6 +73,24 @@ func (s *Store) MarkTxn(txnID string) error {
return err return err
} }
// SeenEvent records an event id as handled and reports whether it was NEW (true)
// or already seen (false) — the DURABLE equivalent of the in-memory dedup set, so
// a crash/restart between handling an event and acking its transaction can't make
// the bot reprocess it (dup answer + double-bill + cap inflation). Bounded to the
// most recent ids.
func (s *Store) SeenEvent(eventID string) (bool, error) {
res, err := s.db.Exec(`INSERT OR IGNORE INTO processed_event (event_id) VALUES (?)`, eventID)
if err != nil {
return false, err
}
if n, _ := res.RowsAffected(); n == 0 {
return false, nil // already recorded → not new
}
_, err = s.db.Exec(`DELETE FROM processed_event WHERE rowid NOT IN
(SELECT rowid FROM processed_event ORDER BY rowid DESC LIMIT 20000)`)
return true, err
}
// SpentTodayUSD sums all spend for the current UTC day. // SpentTodayUSD sums all spend for the current UTC day.
func (s *Store) SpentTodayUSD() (float64, error) { func (s *Store) SpentTodayUSD() (float64, error) {
var v sql.NullFloat64 var v sql.NullFloat64

View file

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"math/rand" "math/rand"
"net/http" "net/http"
"time" "time"
@ -18,14 +19,16 @@ type XAIClient struct {
key string key string
http *http.Client http *http.Client
maxTry int maxTry int
log *slog.Logger
} }
func NewXAIClient(base, key string) *XAIClient { func NewXAIClient(base, key string, logger *slog.Logger) *XAIClient {
return &XAIClient{ return &XAIClient{
base: base, base: base,
key: key, key: key,
http: &http.Client{}, http: &http.Client{},
maxTry: 3, maxTry: 3,
log: logger,
} }
} }
@ -111,6 +114,9 @@ func (x *XAIClient) Complete(ctx context.Context, model string, msgs []xaiMessag
if !retryable { if !retryable {
return nil, err return nil, err
} }
if x.log != nil {
x.log.Warn("xai attempt failed, will retry", "attempt", attempt+1, "max", x.maxTry, "err", err)
}
} }
return nil, fmt.Errorf("xai: exhausted %d attempts: %w", x.maxTry, lastErr) return nil, fmt.Errorf("xai: exhausted %d attempts: %w", x.maxTry, lastErr)
} }
@ -148,9 +154,11 @@ func (x *XAIClient) attempt(ctx context.Context, payload []byte) (*xaiResponse,
if err := json.Unmarshal(data, &out); err != nil { if err := json.Unmarshal(data, &out); err != nil {
return nil, false, fmt.Errorf("xai decode: %w", err) return nil, false, fmt.Errorf("xai decode: %w", err)
} }
if out.Text() == "" { // A 2xx is a billed call even when the model returns empty content (content
return nil, false, fmt.Errorf("xai returned no choices") // filter, finish_reason=length with no text, or no choices). Return it as a
} // success so the caller books the real cost via Reconcile instead of refunding
// the slot and losing the spend — which would let empty replies bypass BOTH the
// per-user cap and the global ceiling. The caller just won't send an empty body.
return &out, false, nil return &out, false, nil
} }