replace the hand-rolled markdown renderer with goldmark/bluemonday and harden the ai-bot against quota abuse and third-party leaks
This commit is contained in:
parent
fe8ba2878b
commit
a4429d9c31
14 changed files with 617 additions and 70 deletions
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -16,13 +16,13 @@ import (
|
||||||
// the events to the bot's processing callback.
|
// the events to the bot's processing callback.
|
||||||
type AppService struct {
|
type AppService struct {
|
||||||
cfg *Config
|
cfg *Config
|
||||||
log *log.Logger
|
log *slog.Logger
|
||||||
store *Store
|
store *Store
|
||||||
handler func(ctx context.Context, events []Event)
|
handler func(ctx context.Context, events []Event)
|
||||||
baseCtx context.Context
|
baseCtx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAppService(cfg *Config, logger *log.Logger, store *Store, handler func(context.Context, []Event)) *AppService {
|
func NewAppService(cfg *Config, logger *slog.Logger, store *Store, handler func(context.Context, []Event)) *AppService {
|
||||||
return &AppService{cfg: cfg, log: logger, store: store, handler: handler}
|
return &AppService{cfg: cfg, log: logger, store: store, handler: handler}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -49,7 +49,7 @@ func (a *AppService) Serve(ctx context.Context) error {
|
||||||
|
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
go func() { errCh <- srv.ListenAndServe() }()
|
go func() { errCh <- srv.ListenAndServe() }()
|
||||||
a.log.Printf("appservice listening on %s", a.cfg.ASAddr)
|
a.log.Info("appservice listening", "addr", a.cfg.ASAddr)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|
@ -79,7 +79,7 @@ func (a *AppService) authOK(r *http.Request) bool {
|
||||||
|
|
||||||
func (a *AppService) handleTransaction(w http.ResponseWriter, r *http.Request) {
|
func (a *AppService) handleTransaction(w http.ResponseWriter, r *http.Request) {
|
||||||
if !a.authOK(r) {
|
if !a.authOK(r) {
|
||||||
writeError(w, http.StatusForbidden, "M_FORBIDDEN", "bad hs_token")
|
a.denyUnauthed(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
txnID := r.PathValue("txnId")
|
txnID := r.PathValue("txnId")
|
||||||
|
|
@ -90,7 +90,7 @@ func (a *AppService) handleTransaction(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// Idempotency (spec): a retried, already-processed transaction is a no-op.
|
// Idempotency (spec): a retried, already-processed transaction is a no-op.
|
||||||
if done, err := a.store.HasTxn(txnID); err != nil {
|
if done, err := a.store.HasTxn(txnID); err != nil {
|
||||||
a.log.Printf("txn dedup read failed for %s: %v", txnID, err)
|
a.log.Error("txn dedup read failed", "txn", txnID, "err", err)
|
||||||
} else if done {
|
} else if done {
|
||||||
writeJSON(w, http.StatusOK, struct{}{})
|
writeJSON(w, http.StatusOK, struct{}{})
|
||||||
return
|
return
|
||||||
|
|
@ -103,20 +103,21 @@ func (a *AppService) handleTransaction(w http.ResponseWriter, r *http.Request) {
|
||||||
writeError(w, http.StatusBadRequest, "M_NOT_JSON", "invalid transaction body")
|
writeError(w, http.StatusBadRequest, "M_NOT_JSON", "invalid transaction body")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
a.log.Debug("transaction received", "txn", txnID, "events", len(txn.Events))
|
||||||
|
|
||||||
// Process with the bot's long-lived context (not the request context) so a
|
// Process with the bot's long-lived context (not the request context) so a
|
||||||
// homeserver-side timeout can't cancel an in-flight reply mid-send.
|
// homeserver-side timeout can't cancel an in-flight reply mid-send.
|
||||||
a.handler(a.baseCtx, txn.Events)
|
a.handler(a.baseCtx, txn.Events)
|
||||||
|
|
||||||
if err := a.store.MarkTxn(txnID); err != nil {
|
if err := a.store.MarkTxn(txnID); err != nil {
|
||||||
a.log.Printf("txn mark failed for %s: %v", txnID, err)
|
a.log.Error("txn mark failed", "txn", txnID, "err", err)
|
||||||
}
|
}
|
||||||
writeJSON(w, http.StatusOK, struct{}{})
|
writeJSON(w, http.StatusOK, struct{}{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppService) handleUserQuery(w http.ResponseWriter, r *http.Request) {
|
func (a *AppService) handleUserQuery(w http.ResponseWriter, r *http.Request) {
|
||||||
if !a.authOK(r) {
|
if !a.authOK(r) {
|
||||||
writeError(w, http.StatusForbidden, "M_FORBIDDEN", "bad hs_token")
|
a.denyUnauthed(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// We own exactly one user. Synapse auto-creates the sender_localpart user;
|
// We own exactly one user. Synapse auto-creates the sender_localpart user;
|
||||||
|
|
@ -130,7 +131,7 @@ func (a *AppService) handleUserQuery(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
func (a *AppService) handleRoomQuery(w http.ResponseWriter, r *http.Request) {
|
func (a *AppService) handleRoomQuery(w http.ResponseWriter, r *http.Request) {
|
||||||
if !a.authOK(r) {
|
if !a.authOK(r) {
|
||||||
writeError(w, http.StatusForbidden, "M_FORBIDDEN", "bad hs_token")
|
a.denyUnauthed(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// The bot claims no room aliases.
|
// The bot claims no room aliases.
|
||||||
|
|
@ -146,3 +147,11 @@ func writeJSON(w http.ResponseWriter, status int, body any) {
|
||||||
func writeError(w http.ResponseWriter, status int, code, msg string) {
|
func writeError(w http.ResponseWriter, status int, code, msg string) {
|
||||||
writeJSON(w, status, map[string]string{"errcode": code, "error": msg})
|
writeJSON(w, status, map[string]string{"errcode": code, "error": msg})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// denyUnauthed logs and rejects a request whose hs_token didn't match. Logging
|
||||||
|
// at WARN makes probing / a misconfigured homeserver visible (the token itself
|
||||||
|
// is never logged).
|
||||||
|
func (a *AppService) denyUnauthed(w http.ResponseWriter, r *http.Request) {
|
||||||
|
a.log.Warn("rejected request: bad hs_token", "method", r.Method, "path", r.URL.Path, "remote", r.RemoteAddr)
|
||||||
|
writeError(w, http.StatusForbidden, "M_FORBIDDEN", "bad hs_token")
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
@ -19,7 +19,7 @@ func newTestAS(t *testing.T, dispatched *[][]Event) (*AppService, *Store) {
|
||||||
}
|
}
|
||||||
as := NewAppService(
|
as := NewAppService(
|
||||||
&Config{HSToken: "secret", BotMXID: "@ai:vojo.chat"},
|
&Config{HSToken: "secret", BotMXID: "@ai:vojo.chat"},
|
||||||
log.New(io.Discard, "", 0),
|
slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||||
st,
|
st,
|
||||||
func(_ context.Context, ev []Event) { *dispatched = append(*dispatched, ev) },
|
func(_ context.Context, ev []Event) { *dispatched = append(*dispatched, ev) },
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,17 +2,22 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"log"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// roomMeta caches per-room classification we need to handle a message: member
|
// roomMeta caches per-room classification we need to handle a message: member
|
||||||
// counts (for the 1:1 test, F3) and encryption state (F15). Rebuilt per process;
|
// counts (for the 1:1 test, F3), whether any member is outside ALLOWED_SERVERS,
|
||||||
// unknown fields are lazily fetched from the CS-API on first need — appservice
|
// and encryption state (F15). Lazily fetched from the CS-API on first need
|
||||||
// transactions carry no room summary.
|
// (appservice transactions carry no room summary) and INVALIDATED whenever a
|
||||||
|
// third party's membership changes, so a 1:1 that gains a member is reclassified
|
||||||
|
// out of DM mode (no DM-mode third-party leak) and a newly added foreign member
|
||||||
|
// is caught.
|
||||||
type roomMeta struct {
|
type roomMeta struct {
|
||||||
joined, invited int
|
joined, invited int
|
||||||
countsKnown bool
|
countsKnown bool
|
||||||
|
foreign bool // a joined/invited member is outside ALLOWED_SERVERS
|
||||||
encrypted, encKnown bool
|
encrypted, encKnown bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -20,7 +25,7 @@ func (m *roomMeta) isDM() bool { return m.countsKnown && m.joined+m.invited == 2
|
||||||
|
|
||||||
type Bot struct {
|
type Bot struct {
|
||||||
cfg *Config
|
cfg *Config
|
||||||
log *log.Logger
|
log *slog.Logger
|
||||||
mx *MatrixClient
|
mx *MatrixClient
|
||||||
xai *XAIClient
|
xai *XAIClient
|
||||||
st *Store
|
st *Store
|
||||||
|
|
@ -32,12 +37,12 @@ type Bot struct {
|
||||||
botSent *lruSet // event ids the bot itself sent (reply-parent detection)
|
botSent *lruSet // event ids the bot itself sent (reply-parent detection)
|
||||||
meta map[string]*roomMeta
|
meta map[string]*roomMeta
|
||||||
buf map[string][]bufferedMsg
|
buf map[string][]bufferedMsg
|
||||||
globalNote map[string]string // roomID → UTC date we last sent the daily-limit notice
|
globalNote map[string]string // roomID → UTC date we last sent the global daily-limit notice
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBot(ctx context.Context, cfg *Config, logger *log.Logger) (*Bot, error) {
|
func NewBot(ctx context.Context, cfg *Config, logger *slog.Logger) (*Bot, error) {
|
||||||
mx := NewMatrixClient(cfg.HomeserverURL, cfg.ASToken, cfg.BotMXID)
|
mx := NewMatrixClient(cfg.HomeserverURL, cfg.ASToken, cfg.BotMXID)
|
||||||
xai := NewXAIClient(cfg.XAIBaseURL, cfg.XAIAPIKey)
|
xai := NewXAIClient(cfg.XAIBaseURL, cfg.XAIAPIKey, logger)
|
||||||
|
|
||||||
st, err := OpenStore(cfg.statePath("ai-bot.db"))
|
st, err := OpenStore(cfg.statePath("ai-bot.db"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -64,7 +69,7 @@ func NewBot(ctx context.Context, cfg *Config, logger *log.Logger) (*Bot, error)
|
||||||
}
|
}
|
||||||
// F23: ensure the profile has a display name (best-effort, idempotent).
|
// F23: ensure the profile has a display name (best-effort, idempotent).
|
||||||
if err := mx.SetDisplayName(ctx, cfg.BotDisplayName); err != nil {
|
if err := mx.SetDisplayName(ctx, cfg.BotDisplayName); err != nil {
|
||||||
logger.Printf("set display name failed (non-fatal): %v", err)
|
logger.Warn("set display name failed (non-fatal)", "err", err)
|
||||||
}
|
}
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
@ -81,9 +86,9 @@ func (b *Bot) verifyIdentity(ctx context.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if who != b.cfg.BotMXID {
|
if who != b.cfg.BotMXID {
|
||||||
b.log.Fatalf("as_token resolves to %q but BOT_MXID is %q", who, b.cfg.BotMXID)
|
return fmt.Errorf("as_token resolves to %q but BOT_MXID is %q", who, b.cfg.BotMXID)
|
||||||
}
|
}
|
||||||
b.log.Printf("authenticated as %s", who)
|
b.log.Info("authenticated", "mxid", who)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -107,12 +112,27 @@ func (b *Bot) handleEvent(ctx context.Context, ev *Event) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !b.seen.Add(ev.EventID) {
|
if !b.seen.Add(ev.EventID) {
|
||||||
|
return // already handled this session (fast in-memory path)
|
||||||
|
}
|
||||||
|
// Durable dedup across restarts: if a previous run already handled this event
|
||||||
|
// but crashed before its transaction was acked, Synapse re-pushes it — don't
|
||||||
|
// reprocess (no dup answer / double-bill). On a DB error, fall through; the
|
||||||
|
// in-memory set still guards this session.
|
||||||
|
if isNew, err := b.st.SeenEvent(ev.EventID); err != nil {
|
||||||
|
b.log.Error("durable dedup check failed", "id", ev.EventID, "err", err)
|
||||||
|
} else if !isNew {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
b.log.Debug("event", "type", ev.Type, "room", ev.RoomID, "sender", ev.Sender, "id", ev.EventID)
|
||||||
switch ev.Type {
|
switch ev.Type {
|
||||||
case "m.room.member":
|
case "m.room.member":
|
||||||
if ev.StateKey != nil && *ev.StateKey == b.cfg.BotMXID {
|
if ev.StateKey != nil && *ev.StateKey == b.cfg.BotMXID {
|
||||||
b.handleSelfMembership(ctx, ev)
|
b.handleSelfMembership(ctx, ev)
|
||||||
|
} else if m := b.meta[ev.RoomID]; m != nil {
|
||||||
|
// A third party's membership changed: counts + foreign flag are now
|
||||||
|
// stale. Re-probe on the next message so a 1:1 that gains a member drops
|
||||||
|
// out of DM mode (no third-party leak) and a new foreign member is caught.
|
||||||
|
m.countsKnown = false
|
||||||
}
|
}
|
||||||
case "m.room.encryption":
|
case "m.room.encryption":
|
||||||
m := b.getMeta(ev.RoomID)
|
m := b.getMeta(ev.RoomID)
|
||||||
|
|
@ -127,16 +147,24 @@ func (b *Bot) handleEvent(ctx context.Context, ev *Event) {
|
||||||
func (b *Bot) handleSelfMembership(ctx context.Context, ev *Event) {
|
func (b *Bot) handleSelfMembership(ctx context.Context, ev *Event) {
|
||||||
switch ev.membershipOf() {
|
switch ev.membershipOf() {
|
||||||
case "invite":
|
case "invite":
|
||||||
if b.cfg.AllowedServers[serverOf(ev.Sender)] {
|
if !b.cfg.AllowedServers[serverOf(ev.Sender)] {
|
||||||
b.log.Printf("accepting invite to %s from %s", ev.RoomID, ev.Sender)
|
b.log.Warn("rejecting invite (server not allowed)", "room", ev.RoomID, "sender", ev.Sender)
|
||||||
if err := b.mx.JoinRoom(ctx, ev.RoomID); err != nil {
|
|
||||||
b.log.Printf("join %s failed: %v", ev.RoomID, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
b.log.Printf("rejecting invite to %s from %q (server not allowed)", ev.RoomID, ev.Sender)
|
|
||||||
if err := b.mx.LeaveRoom(ctx, ev.RoomID); err != nil {
|
if err := b.mx.LeaveRoom(ctx, ev.RoomID); err != nil {
|
||||||
b.log.Printf("leave (reject) %s failed: %v", ev.RoomID, err)
|
b.log.Error("leave (reject) failed", "room", ev.RoomID, "err", err)
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.log.Info("accepting invite", "room", ev.RoomID, "sender", ev.Sender)
|
||||||
|
if err := b.mx.JoinRoom(ctx, ev.RoomID); err != nil {
|
||||||
|
b.log.Error("join failed", "room", ev.RoomID, "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Fully-on-allowed-servers gate: a vojo.chat inviter can still pull the bot
|
||||||
|
// into a room that already holds federated third parties — leave at once.
|
||||||
|
m := b.getMeta(ev.RoomID)
|
||||||
|
b.ensureCounts(ctx, ev.RoomID, m)
|
||||||
|
if m.countsKnown && m.foreign {
|
||||||
|
b.leaveForeign(ctx, ev.RoomID)
|
||||||
}
|
}
|
||||||
case "leave", "ban":
|
case "leave", "ban":
|
||||||
delete(b.meta, ev.RoomID)
|
delete(b.meta, ev.RoomID)
|
||||||
|
|
@ -144,6 +172,15 @@ func (b *Bot) handleSelfMembership(ctx context.Context, ev *Event) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// leaveForeign leaves a room that contains a member outside ALLOWED_SERVERS, so
|
||||||
|
// the bot only ever operates in rooms hosted entirely on allowed homeservers.
|
||||||
|
func (b *Bot) leaveForeign(ctx context.Context, roomID string) {
|
||||||
|
b.log.Warn("leaving room — a member is outside ALLOWED_SERVERS", "room", roomID)
|
||||||
|
if err := b.mx.LeaveRoom(ctx, roomID); err != nil {
|
||||||
|
b.log.Error("leave (foreign) failed", "room", roomID, "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Bot) handleMessage(ctx context.Context, ev *Event) {
|
func (b *Bot) handleMessage(ctx context.Context, ev *Event) {
|
||||||
roomID := ev.RoomID
|
roomID := ev.RoomID
|
||||||
m := b.getMeta(roomID)
|
m := b.getMeta(roomID)
|
||||||
|
|
@ -152,6 +189,7 @@ func (b *Bot) handleMessage(ctx context.Context, ev *Event) {
|
||||||
// the bot can't read it.
|
// the bot can't read it.
|
||||||
b.ensureEncryption(ctx, roomID, m)
|
b.ensureEncryption(ctx, roomID, m)
|
||||||
if m.encrypted {
|
if m.encrypted {
|
||||||
|
b.log.Debug("skip: encrypted room", "room", roomID)
|
||||||
b.warnEncryptedOnce(ctx, roomID)
|
b.warnEncryptedOnce(ctx, roomID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -164,6 +202,14 @@ func (b *Bot) handleMessage(ctx context.Context, ev *Event) {
|
||||||
if mc.IsReplace() {
|
if mc.IsReplace() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Only plain text ever reaches xAI: drop media (m.image/m.file/m.audio/…) and
|
||||||
|
// other custom msgtypes outright — the bot doesn't fetch or forward media, and
|
||||||
|
// a caption/filename is third-party content we keep out of xAI. m.notice falls
|
||||||
|
// through to its existing anti-loop handling below.
|
||||||
|
if mc.MsgType != "m.text" && mc.MsgType != "m.emote" && mc.MsgType != "m.notice" {
|
||||||
|
b.log.Debug("skip: non-text msgtype", "room", roomID, "sender", ev.Sender, "msgtype", mc.MsgType)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Buffer prior context BEFORE classifying so buildContext sees history only.
|
// Buffer prior context BEFORE classifying so buildContext sees history only.
|
||||||
history := b.buf[roomID]
|
history := b.buf[roomID]
|
||||||
|
|
@ -177,25 +223,47 @@ func (b *Bot) handleMessage(ctx context.Context, ev *Event) {
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ensureCounts(ctx, roomID, m)
|
b.ensureCounts(ctx, roomID, m)
|
||||||
|
// Stay only in rooms hosted entirely on allowed servers — never operate in (or
|
||||||
|
// "leak" the bot into) a federated room with non-consenting third parties.
|
||||||
|
if m.countsKnown && m.foreign {
|
||||||
|
b.leaveForeign(ctx, roomID)
|
||||||
|
return
|
||||||
|
}
|
||||||
replyParentIsBot := mc.RelatesTo != nil && mc.RelatesTo.InReplyTo != nil &&
|
replyParentIsBot := mc.RelatesTo != nil && mc.RelatesTo.InReplyTo != nil &&
|
||||||
b.botSent.Has(mc.RelatesTo.InReplyTo.EventID)
|
b.botSent.Has(mc.RelatesTo.InReplyTo.EventID)
|
||||||
|
|
||||||
if !(m.isDM() || mentionsBot(mc, b.cfg.BotMXID, replyParentIsBot)) {
|
mentioned := mentionsBot(mc, b.cfg.BotMXID, replyParentIsBot)
|
||||||
|
if !(m.isDM() || mentioned) {
|
||||||
|
b.log.Debug("skip: not addressed (group without mention)", "room", roomID, "sender", ev.Sender,
|
||||||
|
"dm", m.isDM(), "joined", m.joined, "invited", m.invited, "countsKnown", m.countsKnown, "mentioned", mentioned)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
b.respond(ctx, roomID, m, ev, mc, history)
|
b.respond(ctx, roomID, m, ev, mc, history)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// unlimitedCap is the effective per-user cap for UNLIMITED_USERS — high enough to
|
||||||
|
// never trip the per-user gate, while the global DAILY_USD_CEILING still applies.
|
||||||
|
const unlimitedCap = 1 << 30
|
||||||
|
|
||||||
func (b *Bot) respond(ctx context.Context, roomID string, m *roomMeta, ev *Event, mc *MessageContent, history []bufferedMsg) {
|
func (b *Bot) respond(ctx context.Context, roomID string, m *roomMeta, ev *Event, mc *MessageContent, history []bufferedMsg) {
|
||||||
switch res, err := b.st.Reserve(ev.Sender, b.cfg.PerUserDailyCap, b.cfg.DailyUSDCeiling); {
|
perUserCap := b.cfg.PerUserDailyCap
|
||||||
|
if b.cfg.UnlimitedUsers[ev.Sender] {
|
||||||
|
perUserCap = unlimitedCap
|
||||||
|
}
|
||||||
|
switch res, err := b.st.Reserve(ev.Sender, perUserCap, b.cfg.DailyUSDCeiling); {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
b.log.Printf("limiter reserve failed: %v", err)
|
b.log.Error("limiter reserve failed", "sender", ev.Sender, "err", err)
|
||||||
return
|
return
|
||||||
case res == reserveDeniedUser:
|
case res == reserveDeniedUser:
|
||||||
// Silent drop — per-user cap is anti-abuse (F24).
|
// Per-user cap (anti-abuse, F24): stop answering, but always tell the user
|
||||||
|
// their request hit the limit — no message addressed to the bot is left
|
||||||
|
// silent. (m.notice → the anti-loop skip keeps this from re-triggering.)
|
||||||
|
b.log.Info("per-user daily cap reached; notifying", "sender", ev.Sender)
|
||||||
|
b.sendNotice(ctx, roomID, ev, mc, noticeUserLimit)
|
||||||
return
|
return
|
||||||
case res == reserveDeniedGlobal:
|
case res == reserveDeniedGlobal:
|
||||||
// Global USD ceiling — notice once per room per day, then stay quiet.
|
// Global USD ceiling — notice once per room per day, then stay quiet.
|
||||||
|
b.log.Warn("global daily USD ceiling reached", "room", roomID, "sender", ev.Sender)
|
||||||
if b.globalNote[roomID] != todayUTC() {
|
if b.globalNote[roomID] != todayUTC() {
|
||||||
b.globalNote[roomID] = todayUTC()
|
b.globalNote[roomID] = todayUTC()
|
||||||
b.sendNotice(ctx, roomID, ev, mc, noticeDailyLimit)
|
b.sendNotice(ctx, roomID, ev, mc, noticeDailyLimit)
|
||||||
|
|
@ -203,24 +271,41 @@ func (b *Bot) respond(ctx context.Context, roomID string, m *roomMeta, ev *Event
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Show "Vojo AI печатает…" while we build the answer; the deferred clear fires
|
||||||
|
// on every exit (success or failure). Pure UX — typing failures are best-effort.
|
||||||
|
b.setTyping(ctx, roomID, true)
|
||||||
|
defer b.setTyping(ctx, roomID, false)
|
||||||
|
|
||||||
msgs := buildContext(b.cfg.SystemPrompt, history, m.isDM(), mc.Body, b.cfg.MaxCtxEvent, 8000)
|
msgs := buildContext(b.cfg.SystemPrompt, history, m.isDM(), mc.Body, b.cfg.MaxCtxEvent, 8000)
|
||||||
resp, err := b.xai.Complete(ctx, b.cfg.XAIModel, msgs, b.cfg.MaxOutTok, b.cfg.XAITemp)
|
resp, err := b.xai.Complete(ctx, b.cfg.XAIModel, msgs, b.cfg.MaxOutTok, b.cfg.XAITemp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// at-most-once already retried transient failures inside Complete; refund
|
// at-most-once already retried transient failures inside Complete; refund
|
||||||
// the reserved request so an xAI outage doesn't burn the user's daily cap.
|
// the reserved request so an xAI outage doesn't burn the user's daily cap,
|
||||||
b.log.Printf("xai completion failed for %s: %v", ev.Sender, err)
|
// and tell the user we couldn't answer (notice → no anti-loop re-trigger).
|
||||||
|
b.log.Error("xai completion failed", "sender", ev.Sender, "err", err)
|
||||||
if rerr := b.st.RefundRequest(ev.Sender); rerr != nil {
|
if rerr := b.st.RefundRequest(ev.Sender); rerr != nil {
|
||||||
b.log.Printf("refund failed: %v", rerr)
|
b.log.Error("refund failed", "sender", ev.Sender, "err", rerr)
|
||||||
}
|
}
|
||||||
|
b.sendNotice(ctx, roomID, ev, mc, noticeError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A 2xx from xAI is billed even if the text came back empty — always book the
|
||||||
|
// real cost so both caps see it (the old refund-without-reconcile on an empty
|
||||||
|
// 200 let such calls bypass the per-user cap and the global ceiling).
|
||||||
usd := computeUSD(resp.Usage, b.cfg)
|
usd := computeUSD(resp.Usage, b.cfg)
|
||||||
if err := b.st.Reconcile(ev.Sender, usd); err != nil {
|
if err := b.st.Reconcile(ev.Sender, usd); err != nil {
|
||||||
b.log.Printf("reconcile spend failed: %v", err)
|
b.log.Error("reconcile spend failed", "sender", ev.Sender, "err", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.sendNotice(ctx, roomID, ev, mc, resp.Text())
|
text := resp.Text()
|
||||||
|
if text == "" {
|
||||||
|
b.log.Warn("xai returned empty completion (billed, nothing to send)", "sender", ev.Sender, "usd", usd)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.log.Info("answered", "room", roomID, "sender", ev.Sender, "dm", m.isDM(),
|
||||||
|
"usd", usd, "prompt_tokens", resp.Usage.PromptTokens, "completion_tokens", resp.Usage.CompletionTokens)
|
||||||
|
b.sendNotice(ctx, roomID, ev, mc, text)
|
||||||
}
|
}
|
||||||
|
|
||||||
// computeUSD prices the call from the API-returned token usage (authoritative
|
// computeUSD prices the call from the API-returned token usage (authoritative
|
||||||
|
|
@ -241,7 +326,7 @@ func (b *Bot) sendNotice(ctx context.Context, roomID string, trigger *Event, tri
|
||||||
content := buildNoticeContent(trigger.EventID, trigger.Sender, triggerMC.RelatesTo, body)
|
content := buildNoticeContent(trigger.EventID, trigger.Sender, triggerMC.RelatesTo, body)
|
||||||
id, err := b.mx.SendEvent(ctx, roomID, "m.room.message", content)
|
id, err := b.mx.SendEvent(ctx, roomID, "m.room.message", content)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.log.Printf("send notice to %s failed: %v", roomID, err)
|
b.log.Error("send notice failed", "room", roomID, "err", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Track our own reply so a future reply-to-it is recognised as addressing us,
|
// Track our own reply so a future reply-to-it is recognised as addressing us,
|
||||||
|
|
@ -250,10 +335,20 @@ func (b *Bot) sendNotice(ctx context.Context, roomID string, trigger *Event, tri
|
||||||
b.appendBuf(roomID, bufferedMsg{sender: b.cfg.BotMXID, body: body, isBot: true})
|
b.appendBuf(roomID, bufferedMsg{sender: b.cfg.BotMXID, body: body, isBot: true})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setTyping sets/clears the bot's typing indicator (best-effort UX; failures are
|
||||||
|
// non-fatal). The 30s timeout comfortably covers a normal completion, and
|
||||||
|
// respond() defers a clear so the indicator ends the moment the answer is sent
|
||||||
|
// or fails.
|
||||||
|
func (b *Bot) setTyping(ctx context.Context, roomID string, typing bool) {
|
||||||
|
if err := b.mx.SendTyping(ctx, roomID, typing, 30000); err != nil {
|
||||||
|
b.log.Debug("set typing failed", "room", roomID, "typing", typing, "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Bot) warnEncryptedOnce(ctx context.Context, roomID string) {
|
func (b *Bot) warnEncryptedOnce(ctx context.Context, roomID string) {
|
||||||
warned, err := b.st.HasWarnedEncrypted(roomID)
|
warned, err := b.st.HasWarnedEncrypted(roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.log.Printf("warned-flag read failed: %v", err)
|
b.log.Error("warned-flag read failed", "room", roomID, "err", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if warned {
|
if warned {
|
||||||
|
|
@ -261,11 +356,11 @@ func (b *Bot) warnEncryptedOnce(ctx context.Context, roomID string) {
|
||||||
}
|
}
|
||||||
content := map[string]any{"msgtype": "m.notice", "body": noticeEncryptedUnsupported}
|
content := map[string]any{"msgtype": "m.notice", "body": noticeEncryptedUnsupported}
|
||||||
if _, err := b.mx.SendEvent(ctx, roomID, "m.room.message", content); err != nil {
|
if _, err := b.mx.SendEvent(ctx, roomID, "m.room.message", content); err != nil {
|
||||||
b.log.Printf("encrypted-notice to %s failed: %v", roomID, err)
|
b.log.Error("encrypted-notice failed", "room", roomID, "err", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := b.st.SetWarnedEncrypted(roomID); err != nil {
|
if err := b.st.SetWarnedEncrypted(roomID); err != nil {
|
||||||
b.log.Printf("persist warned-flag failed: %v", err)
|
b.log.Error("persist warned-flag failed", "room", roomID, "err", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -282,12 +377,21 @@ func buildNoticeContent(replyTo, sender string, triggerRelates *RelatesTo, body
|
||||||
} else {
|
} else {
|
||||||
relates["m.in_reply_to"] = map[string]any{"event_id": replyTo}
|
relates["m.in_reply_to"] = map[string]any{"event_id": replyTo}
|
||||||
}
|
}
|
||||||
return map[string]any{
|
content := map[string]any{
|
||||||
"msgtype": "m.notice",
|
"msgtype": "m.notice",
|
||||||
"body": body,
|
"body": body,
|
||||||
"m.mentions": map[string]any{"user_ids": []string{sender}},
|
"m.mentions": map[string]any{"user_ids": []string{sender}},
|
||||||
"m.relates_to": relates,
|
"m.relates_to": relates,
|
||||||
}
|
}
|
||||||
|
// The model answers in markdown; render it to org.matrix.custom.html so clients
|
||||||
|
// show formatting instead of raw `**`, `#`, lists, code fences. Only attach
|
||||||
|
// formatted_body when there's actual formatting — a plain answer keeps rendering
|
||||||
|
// from `body` exactly as before.
|
||||||
|
if html, formatted := markdownToHTML(body); formatted {
|
||||||
|
content["format"] = matrixHTMLFormat
|
||||||
|
content["formatted_body"] = html
|
||||||
|
}
|
||||||
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- per-room metadata helpers -------------------------------------------------
|
// --- per-room metadata helpers -------------------------------------------------
|
||||||
|
|
@ -307,7 +411,7 @@ func (b *Bot) ensureEncryption(ctx context.Context, roomID string, m *roomMeta)
|
||||||
}
|
}
|
||||||
enc, err := b.mx.RoomEncrypted(ctx, roomID)
|
enc, err := b.mx.RoomEncrypted(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.log.Printf("encryption probe %s failed: %v", roomID, err)
|
b.log.Warn("encryption probe failed", "room", roomID, "err", err)
|
||||||
return // leave unknown; re-probed on the next message
|
return // leave unknown; re-probed on the next message
|
||||||
}
|
}
|
||||||
m.encrypted, m.encKnown = enc, true
|
m.encrypted, m.encKnown = enc, true
|
||||||
|
|
@ -317,12 +421,19 @@ func (b *Bot) ensureCounts(ctx context.Context, roomID string, m *roomMeta) {
|
||||||
if m.countsKnown {
|
if m.countsKnown {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
joined, invited, err := b.mx.MemberCounts(ctx, roomID)
|
joined, invited, servers, err := b.mx.RoomMembership(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.log.Printf("member-count probe %s failed: %v", roomID, err)
|
b.log.Warn("member probe failed", "room", roomID, "err", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.joined, m.invited, m.countsKnown = joined, invited, true
|
foreign := false
|
||||||
|
for s := range servers {
|
||||||
|
if !b.cfg.AllowedServers[s] {
|
||||||
|
foreign = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.joined, m.invited, m.foreign, m.countsKnown = joined, invited, foreign, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bot) appendBuf(roomID string, msg bufferedMsg) {
|
func (b *Bot) appendBuf(roomID string, msg bufferedMsg) {
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,10 @@ type Config struct {
|
||||||
DailyUSDCeiling float64
|
DailyUSDCeiling float64
|
||||||
PerUserDailyCap int
|
PerUserDailyCap int
|
||||||
|
|
||||||
|
// mxids exempt from PER_USER_DAILY_CAP (e.g. the owner/admins testing). Still
|
||||||
|
// subject to the global DAILY_USD_CEILING, so the wallet stays protected.
|
||||||
|
UnlimitedUsers map[string]bool
|
||||||
|
|
||||||
// USD-per-1M-token prices applied to the API-returned token usage so the
|
// USD-per-1M-token prices applied to the API-returned token usage so the
|
||||||
// hard ceiling tracks real usage even if the model/price changes.
|
// hard ceiling tracks real usage even if the model/price changes.
|
||||||
PriceInputPerM float64
|
PriceInputPerM float64
|
||||||
|
|
@ -127,6 +131,7 @@ func LoadConfig() (*Config, error) {
|
||||||
SystemPromptPath: getenv("SYSTEM_PROMPT_PATH", "prompts/system_ru.txt"),
|
SystemPromptPath: getenv("SYSTEM_PROMPT_PATH", "prompts/system_ru.txt"),
|
||||||
StateDir: strings.TrimRight(getenv("STATE_DIR", "/state"), "/"),
|
StateDir: strings.TrimRight(getenv("STATE_DIR", "/state"), "/"),
|
||||||
AllowedServers: parseServerSet(getenv("ALLOWED_SERVERS", "")),
|
AllowedServers: parseServerSet(getenv("ALLOWED_SERVERS", "")),
|
||||||
|
UnlimitedUsers: parseServerSet(getenv("UNLIMITED_USERS", "")),
|
||||||
}
|
}
|
||||||
|
|
||||||
var problems []string
|
var problems []string
|
||||||
|
|
@ -216,6 +221,10 @@ func (c *Config) Summary() string {
|
||||||
for s := range c.AllowedServers {
|
for s := range c.AllowedServers {
|
||||||
servers = append(servers, s)
|
servers = append(servers, s)
|
||||||
}
|
}
|
||||||
|
unlimited := make([]string, 0, len(c.UnlimitedUsers))
|
||||||
|
for u := range c.UnlimitedUsers {
|
||||||
|
unlimited = append(unlimited, u)
|
||||||
|
}
|
||||||
redact := func(s string) string {
|
redact := func(s string) string {
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return "(unset)"
|
return "(unset)"
|
||||||
|
|
@ -245,6 +254,7 @@ func (c *Config) Summary() string {
|
||||||
" ALLOWED_SERVERS = " + strings.Join(servers, ","),
|
" ALLOWED_SERVERS = " + strings.Join(servers, ","),
|
||||||
fmt.Sprintf(" DAILY_USD_CEILING = %g", c.DailyUSDCeiling),
|
fmt.Sprintf(" DAILY_USD_CEILING = %g", c.DailyUSDCeiling),
|
||||||
fmt.Sprintf(" PER_USER_DAILY_CAP = %d", c.PerUserDailyCap),
|
fmt.Sprintf(" PER_USER_DAILY_CAP = %d", c.PerUserDailyCap),
|
||||||
|
" UNLIMITED_USERS = " + strings.Join(unlimited, ","),
|
||||||
fmt.Sprintf(" PRICES /1M (in/cached/out) = %g / %g / %g",
|
fmt.Sprintf(" PRICES /1M (in/cached/out) = %g / %g / %g",
|
||||||
c.PriceInputPerM, c.PriceCachedPerM, c.PriceOutputPerM),
|
c.PriceInputPerM, c.PriceCachedPerM, c.PriceOutputPerM),
|
||||||
" SYSTEM_PROMPT_PATH = " + c.SystemPromptPath,
|
" SYSTEM_PROMPT_PATH = " + c.SystemPromptPath,
|
||||||
|
|
|
||||||
|
|
@ -3,16 +3,21 @@ module vojo.chat/ai-bot
|
||||||
go 1.25.0
|
go 1.25.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/microcosm-cc/bluemonday v1.0.27
|
||||||
|
github.com/yuin/goldmark v1.8.2
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
modernc.org/sqlite v1.51.0
|
modernc.org/sqlite v1.51.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/aymerick/douceur v0.2.0 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
github.com/gorilla/css v1.0.1 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
golang.org/x/net v0.26.0 // indirect
|
||||||
golang.org/x/sys v0.42.0 // indirect
|
golang.org/x/sys v0.42.0 // indirect
|
||||||
modernc.org/libc v1.72.3 // indirect
|
modernc.org/libc v1.72.3 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,29 @@
|
||||||
|
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||||
|
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||||
|
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||||
|
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
|
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||||
|
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||||
|
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||||
|
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
|
|
||||||
44
apps/ai-bot/logging.go
Normal file
44
apps/ai-bot/logging.go
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newLogger builds the process logger from the environment: LOG_LEVEL
|
||||||
|
// (debug|info|warn|error, default info) and LOG_FORMAT (text|json, default
|
||||||
|
// text). It writes to stderr with UTC timestamps (matching the previous
|
||||||
|
// log.LUTC behaviour). Built from getenv directly — not Config — so it exists
|
||||||
|
// before LoadConfig (the generate-registration path logs before config loads).
|
||||||
|
func newLogger() *slog.Logger {
|
||||||
|
opts := &slog.HandlerOptions{
|
||||||
|
Level: parseLogLevel(getenv("LOG_LEVEL", "info")),
|
||||||
|
ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr {
|
||||||
|
if a.Key == slog.TimeKey && a.Value.Kind() == slog.KindTime {
|
||||||
|
a.Value = slog.TimeValue(a.Value.Time().UTC())
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var h slog.Handler
|
||||||
|
if strings.EqualFold(strings.TrimSpace(getenv("LOG_FORMAT", "text")), "json") {
|
||||||
|
h = slog.NewJSONHandler(os.Stderr, opts)
|
||||||
|
} else {
|
||||||
|
h = slog.NewTextHandler(os.Stderr, opts)
|
||||||
|
}
|
||||||
|
return slog.New(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLogLevel(s string) slog.Level {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||||
|
case "debug":
|
||||||
|
return slog.LevelDebug
|
||||||
|
case "warn", "warning":
|
||||||
|
return slog.LevelWarn
|
||||||
|
case "error":
|
||||||
|
return slog.LevelError
|
||||||
|
default:
|
||||||
|
return slog.LevelInfo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -9,7 +9,6 @@ package main
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
@ -17,7 +16,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
logger := log.New(os.Stderr, "", log.LstdFlags|log.LUTC)
|
logger := newLogger()
|
||||||
|
|
||||||
// `ai-bot generate-registration` writes a fresh registration.yaml with random
|
// `ai-bot generate-registration` writes a fresh registration.yaml with random
|
||||||
// tokens (the mautrix bridge idiom), then exits. Runs BEFORE LoadConfig — the
|
// tokens (the mautrix bridge idiom), then exits. Runs BEFORE LoadConfig — the
|
||||||
|
|
@ -26,12 +25,14 @@ func main() {
|
||||||
if len(os.Args) > 1 && os.Args[1] == "generate-registration" {
|
if len(os.Args) > 1 && os.Args[1] == "generate-registration" {
|
||||||
mxid := getenv("BOT_MXID", "")
|
mxid := getenv("BOT_MXID", "")
|
||||||
if mxid == "" {
|
if mxid == "" {
|
||||||
logger.Fatalf("BOT_MXID is required to generate the registration")
|
logger.Error("BOT_MXID is required to generate the registration")
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
path := getenv("REGISTRATION_PATH", "/data/registration.yaml")
|
path := getenv("REGISTRATION_PATH", "/data/registration.yaml")
|
||||||
asURL := getenv("AS_URL", "http://ai-bot:8009")
|
asURL := getenv("AS_URL", "http://ai-bot:8009")
|
||||||
if err := GenerateRegistration(path, asURL, localpartOf(mxid), serverOf(mxid)); err != nil {
|
if err := GenerateRegistration(path, asURL, localpartOf(mxid), serverOf(mxid)); err != nil {
|
||||||
logger.Fatalf("generate-registration: %v", err)
|
logger.Error("generate-registration failed", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
fmt.Printf("wrote %s\n", path)
|
fmt.Printf("wrote %s\n", path)
|
||||||
fmt.Println("Next: mount this file into Synapse, add it to app_service_config_files,")
|
fmt.Println("Next: mount this file into Synapse, add it to app_service_config_files,")
|
||||||
|
|
@ -42,14 +43,16 @@ func main() {
|
||||||
|
|
||||||
cfg, err := LoadConfig()
|
cfg, err := LoadConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("config error: %v", err)
|
logger.Error("config error", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load the system prompt up front so a missing/unreadable file fails fast
|
// Load the system prompt up front so a missing/unreadable file fails fast
|
||||||
// at startup rather than on the first message.
|
// at startup rather than on the first message.
|
||||||
promptBytes, err := os.ReadFile(cfg.SystemPromptPath)
|
promptBytes, err := os.ReadFile(cfg.SystemPromptPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("cannot read SYSTEM_PROMPT_PATH (%s): %v", cfg.SystemPromptPath, err)
|
logger.Error("cannot read system prompt", "path", cfg.SystemPromptPath, "err", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
cfg.SystemPrompt = string(promptBytes)
|
cfg.SystemPrompt = string(promptBytes)
|
||||||
|
|
||||||
|
|
@ -64,10 +67,12 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil {
|
if err := os.MkdirAll(cfg.StateDir, 0o700); err != nil {
|
||||||
logger.Fatalf("cannot create STATE_DIR (%s): %v", cfg.StateDir, err)
|
logger.Error("cannot create state dir", "path", cfg.StateDir, "err", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("starting\n%s", cfg.Summary())
|
fmt.Fprintf(os.Stderr, "%s\n", cfg.Summary())
|
||||||
|
logger.Info("starting Vojo AI bot")
|
||||||
|
|
||||||
// Cancel on SIGINT/SIGTERM so the transaction server shuts down cleanly.
|
// Cancel on SIGINT/SIGTERM so the transaction server shuts down cleanly.
|
||||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
@ -75,14 +80,16 @@ func main() {
|
||||||
|
|
||||||
bot, err := NewBot(ctx, cfg, logger)
|
bot, err := NewBot(ctx, cfg, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatalf("startup failed: %v", err)
|
logger.Error("startup failed", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
defer bot.Close()
|
defer bot.Close()
|
||||||
|
|
||||||
if err := bot.Run(ctx); err != nil && ctx.Err() == nil {
|
if err := bot.Run(ctx); err != nil && ctx.Err() == nil {
|
||||||
logger.Fatalf("appservice server exited with error: %v", err)
|
logger.Error("appservice server exited", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
logger.Printf("shut down cleanly")
|
logger.Info("shut down cleanly")
|
||||||
}
|
}
|
||||||
|
|
||||||
// statePath joins a filename under the configured state directory.
|
// statePath joins a filename under the configured state directory.
|
||||||
|
|
|
||||||
128
apps/ai-bot/markdown.go
Normal file
128
apps/ai-bot/markdown.go
Normal file
|
|
@ -0,0 +1,128 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/microcosm-cc/bluemonday"
|
||||||
|
"github.com/yuin/goldmark"
|
||||||
|
"github.com/yuin/goldmark/ast"
|
||||||
|
"github.com/yuin/goldmark/extension"
|
||||||
|
"github.com/yuin/goldmark/renderer"
|
||||||
|
ghtml "github.com/yuin/goldmark/renderer/html"
|
||||||
|
"github.com/yuin/goldmark/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// matrixHTMLFormat is the `format` value that flags `formatted_body` as
|
||||||
|
// org.matrix.custom.html (the only rich format Matrix clients render).
|
||||||
|
const matrixHTMLFormat = "org.matrix.custom.html"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxInputBytes / maxFormattedBytes bound the model reply and the rendered
|
||||||
|
// HTML; beyond either we fall back to the plain body (no formatted_body).
|
||||||
|
maxInputBytes = 512 * 1024
|
||||||
|
maxFormattedBytes = 64 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
// mdParser converts the model's CommonMark + GFM (tables, strikethrough,
|
||||||
|
// autolink, task lists) answer to HTML. WithUnsafe stays OFF (goldmark's default)
|
||||||
|
// so raw HTML and dangerous URLs are escaped, never rendered; WithHardWraps keeps
|
||||||
|
// the answer's line breaks as <br>; images are rendered as links, not <img> (see
|
||||||
|
// imageLinkRenderer). goldmark depends only on the standard library, so the static
|
||||||
|
// (CGO-free) build is preserved.
|
||||||
|
var mdParser = goldmark.New(
|
||||||
|
goldmark.WithExtensions(extension.GFM),
|
||||||
|
goldmark.WithRendererOptions(
|
||||||
|
ghtml.WithHardWraps(),
|
||||||
|
// Priority < the default renderer's 1000 → registered last → overrides
|
||||||
|
// goldmark's <img> rendering with imageLinkRenderer.
|
||||||
|
renderer.WithNodeRenderers(util.Prioritized(imageLinkRenderer{}, 100)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
// imageLinkRenderer overrides goldmark's image rendering to emit a link instead of
|
||||||
|
// an <img>, so a markdown image stays functional (a clickable link to its source)
|
||||||
|
// without ever putting a remote <img> in the event — which a client could
|
||||||
|
// auto-load, leaking the viewer's IP to a URL a prompt-injected reply chose.
|
||||||
|
type imageLinkRenderer struct{}
|
||||||
|
|
||||||
|
func (imageLinkRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
|
||||||
|
reg.Register(ast.KindImage, renderImageAsLink)
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderImageAsLink renders  as <a href="src">alt</a>: the alt content
|
||||||
|
// (the node's children) becomes the link label. Mirrors goldmark's own URL escape
|
||||||
|
// + dangerous-URL guard; bluemonday re-checks the scheme afterwards.
|
||||||
|
func renderImageAsLink(w util.BufWriter, _ []byte, node ast.Node, entering bool) (ast.WalkStatus, error) {
|
||||||
|
n := node.(*ast.Image)
|
||||||
|
if entering {
|
||||||
|
_, _ = w.WriteString(`<a href="`)
|
||||||
|
dest := util.URLEscape(n.Destination, true)
|
||||||
|
if !ghtml.IsDangerousURL(dest) {
|
||||||
|
_, _ = w.Write(util.EscapeHTML(dest))
|
||||||
|
}
|
||||||
|
_, _ = w.WriteString(`">`)
|
||||||
|
} else {
|
||||||
|
_, _ = w.WriteString("</a>")
|
||||||
|
}
|
||||||
|
return ast.WalkContinue, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// htmlPolicy strips goldmark's output to the tags/attributes Cinny's renderer
|
||||||
|
// keeps (src/app/utils/sanitize.ts: permittedHtmlTags / urlSchemes) — defence in
|
||||||
|
// depth over goldmark's own escaping, and the single allowlist a crafted reply
|
||||||
|
// can't get around. Anything else (script/style/img/on*-handlers/unknown URL
|
||||||
|
// schemes) is removed.
|
||||||
|
var htmlPolicy = buildHTMLPolicy()
|
||||||
|
|
||||||
|
func buildHTMLPolicy() *bluemonday.Policy {
|
||||||
|
p := bluemonday.NewPolicy()
|
||||||
|
p.AllowElements(
|
||||||
|
"p", "br", "hr",
|
||||||
|
"h1", "h2", "h3", "h4", "h5", "h6",
|
||||||
|
"strong", "em", "del", "s", "code", "pre",
|
||||||
|
"blockquote", "ul", "ol", "li",
|
||||||
|
"table", "thead", "tbody", "tr", "th", "td",
|
||||||
|
)
|
||||||
|
p.AllowAttrs("href").OnElements("a")
|
||||||
|
p.AllowURLSchemes("https", "http", "ftp", "mailto", "magnet")
|
||||||
|
p.RequireParseableURLs(true)
|
||||||
|
p.AllowAttrs("class").OnElements("code", "pre") // language-xxx on code blocks
|
||||||
|
p.AllowAttrs("start").OnElements("ol")
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// markdownToHTML converts the model's markdown answer to sanitized
|
||||||
|
// org.matrix.custom.html and reports whether any rich formatting was emitted.
|
||||||
|
// When false the caller MUST omit formatted_body so a plain answer renders from
|
||||||
|
// the bare `body` exactly as before (Matrix convention: only attach
|
||||||
|
// formatted_body when it adds formatting the plain body can't carry).
|
||||||
|
func markdownToHTML(md string) (string, bool) {
|
||||||
|
if len(md) > maxInputBytes {
|
||||||
|
return "", false // implausibly large; just send the plain body
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := mdParser.Convert([]byte(md), &buf); err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
html := strings.TrimSpace(string(htmlPolicy.SanitizeBytes(buf.Bytes())))
|
||||||
|
if len(html) > maxFormattedBytes {
|
||||||
|
return "", false // too large to be worth sending as a Matrix event
|
||||||
|
}
|
||||||
|
if !hasRichMarkup(html) {
|
||||||
|
return "", false // just a paragraph of text — the plain body is enough
|
||||||
|
}
|
||||||
|
return html, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasRichMarkup reports whether the HTML carries formatting beyond the paragraph
|
||||||
|
// wrapper and soft line breaks goldmark emits for plain text, so a plain reply
|
||||||
|
// keeps rendering from the bare body. Model text is HTML-escaped (a literal '<'
|
||||||
|
// becomes "<"), so any remaining raw '<' is a tag the converter emitted.
|
||||||
|
func hasRichMarkup(html string) bool {
|
||||||
|
stripped := html
|
||||||
|
for _, t := range []string{"<p>", "</p>", "<br>", "<br/>", "<br />"} {
|
||||||
|
stripped = strings.ReplaceAll(stripped, t, "")
|
||||||
|
}
|
||||||
|
return strings.Contains(stripped, "<")
|
||||||
|
}
|
||||||
169
apps/ai-bot/markdown_test.go
Normal file
169
apps/ai-bot/markdown_test.go
Normal file
|
|
@ -0,0 +1,169 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestMarkdownToHTML asserts the rich constructs render and plain text stays
|
||||||
|
// plain. It checks for the meaningful tags/escaping (Contains), not goldmark's
|
||||||
|
// exact byte output — the converter's precise formatting is its own contract, not
|
||||||
|
// ours to pin.
|
||||||
|
func TestMarkdownToHTML(t *testing.T) {
|
||||||
|
rich := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
contains []string
|
||||||
|
}{
|
||||||
|
{"bold", "a **bold** b", []string{"<strong>bold</strong>"}},
|
||||||
|
{"italic star", "a *it* b", []string{"<em>it</em>"}},
|
||||||
|
{"italic underscore", "a _it_ b", []string{"<em>it</em>"}},
|
||||||
|
{"bold italic", "***x***", []string{"<strong>", "<em>", "x"}},
|
||||||
|
{"strikethrough", "~~gone~~", []string{"gone"}}, // <del> or <s>; both rich
|
||||||
|
{"inline code", "use `npm i`", []string{"<code>npm i</code>"}},
|
||||||
|
{"inline code keeps stars literal", "`a*b*c`", []string{"<code>a*b*c</code>"}},
|
||||||
|
{"heading h1", "# Title", []string{"<h1>", "Title", "</h1>"}},
|
||||||
|
{"hr", "---", []string{"<hr"}},
|
||||||
|
{"unordered list", "- one\n- two", []string{"<ul>", "<li>", "one", "two"}},
|
||||||
|
{"ordered list", "1. one\n2. two", []string{"<ol>", "<li>", "one"}},
|
||||||
|
{"blockquote", "> quoted", []string{"<blockquote>", "quoted"}},
|
||||||
|
{"link", "see [xAI](https://x.ai)", []string{`href="https://x.ai"`, "xAI"}},
|
||||||
|
{"fenced code", "```go\nfmt.Println()\n```", []string{"<pre>", "<code", "fmt.Println"}},
|
||||||
|
{"table", "| a | b |\n| - | - |\n| 1 | 2 |", []string{"<table>", "<th>", "a", "<td>", "1"}},
|
||||||
|
{"image as link", "", []string{`href="https://x.ai/logo.png"`, "logo"}},
|
||||||
|
{"autolink bare url", "visit https://x.ai now", []string{`href="https://x.ai"`}},
|
||||||
|
}
|
||||||
|
for _, c := range rich {
|
||||||
|
t.Run("rich/"+c.name, func(t *testing.T) {
|
||||||
|
got, formatted := markdownToHTML(c.in)
|
||||||
|
if !formatted {
|
||||||
|
t.Fatalf("markdownToHTML(%q) formatted=false, want true (got %q)", c.in, got)
|
||||||
|
}
|
||||||
|
for _, sub := range c.contains {
|
||||||
|
if !strings.Contains(got, sub) {
|
||||||
|
t.Fatalf("markdownToHTML(%q) = %q, missing %q", c.in, got, sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plain text (even multi-line or with stray punctuation) carries no
|
||||||
|
// formatting, so the bot sends only the bare body.
|
||||||
|
plain := []string{
|
||||||
|
"just a sentence",
|
||||||
|
"line one\nline two",
|
||||||
|
"a < b & c > d",
|
||||||
|
"2 * 3 * 4",
|
||||||
|
"snake_case_name",
|
||||||
|
"файл_имя_тут",
|
||||||
|
"text with ! bang",
|
||||||
|
`path c:\users`,
|
||||||
|
"",
|
||||||
|
}
|
||||||
|
for _, in := range plain {
|
||||||
|
t.Run("plain", func(t *testing.T) {
|
||||||
|
if got, formatted := markdownToHTML(in); formatted {
|
||||||
|
t.Fatalf("markdownToHTML(%q) formatted=true, want false (got %q)", in, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownNeverEmitsUnsafeScheme(t *testing.T) {
|
||||||
|
for _, bad := range []string{
|
||||||
|
"[a](javascript:x)", "[a](data:text/html,x)", "[a](vbscript:x)", "[a](file:///etc)",
|
||||||
|
"[a](JaVaScRiPt:x)", "[a](java\tscript:x)",
|
||||||
|
} {
|
||||||
|
if html, _ := markdownToHTML(bad); strings.Contains(html, "href=") {
|
||||||
|
t.Fatalf("emitted a link for unsafe scheme: %q -> %q", bad, html)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownOversizeFallsBackToPlain(t *testing.T) {
|
||||||
|
// A formatted reply whose HTML exceeds maxFormattedBytes must return ("", false)
|
||||||
|
// so the bot sends only the plain body.
|
||||||
|
big := strings.Repeat("- item\n", 8000)
|
||||||
|
if html, formatted := markdownToHTML(big); formatted || html != "" {
|
||||||
|
t.Fatalf("oversize formatted output should fall back to plain: formatted=%v len=%d", formatted, len(html))
|
||||||
|
}
|
||||||
|
// Implausibly large input is rejected outright.
|
||||||
|
huge := strings.Repeat("a", maxInputBytes+1)
|
||||||
|
if html, formatted := markdownToHTML(huge); formatted || html != "" {
|
||||||
|
t.Fatalf("oversize input should fall back to plain: formatted=%v len=%d", formatted, len(html))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkdownAdversarialNoPanicNoInjection(t *testing.T) {
|
||||||
|
inputs := []string{
|
||||||
|
strings.Repeat("[", 20000) + "x",
|
||||||
|
"x" + strings.Repeat("](https://a)", 20000),
|
||||||
|
strings.Repeat("*", 5000) + "x" + strings.Repeat("*", 5000),
|
||||||
|
strings.Repeat("> ", 5000) + "deep",
|
||||||
|
strings.Repeat(" ", 50) + "- nested",
|
||||||
|
strings.Repeat("`", 4000) + "code",
|
||||||
|
"| " + strings.Repeat("a |", 2000) + "\n| " + strings.Repeat("- |", 2000) + "\n| x |",
|
||||||
|
"<script>alert(1)</script>\n**`<b>`**\n[x](\"><svg onload=alert(1)>)",
|
||||||
|
strings.Repeat("***nest ", 200) + "x" + strings.Repeat(" nest***", 200),
|
||||||
|
}
|
||||||
|
// Every model '<' is escaped to <, so a dangerous element can only exist if
|
||||||
|
// the converter emitted it — and it emits none of these tag names. (Attribute
|
||||||
|
// vectors like onload= can appear only as escaped literal text, which is safe;
|
||||||
|
// the safe-href guarantee is covered by the unit + scheme tests.)
|
||||||
|
for i, in := range inputs {
|
||||||
|
html, _ := markdownToHTML(in) // must not panic
|
||||||
|
for _, tag := range []string{"<script", "<svg", "<img", "<iframe", "<style", "<object", "<embed"} {
|
||||||
|
if strings.Contains(strings.ToLower(html), tag) {
|
||||||
|
t.Fatalf("case %d emitted a dangerous tag %q: %.160q", i, tag, html)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildNoticeContentAttachesFormatted(t *testing.T) {
|
||||||
|
c := buildNoticeContent("$evt", "@u:vojo.chat", nil, "Here is **bold**.")
|
||||||
|
if c["format"] != matrixHTMLFormat {
|
||||||
|
t.Fatalf("format = %v, want %v", c["format"], matrixHTMLFormat)
|
||||||
|
}
|
||||||
|
fb, _ := c["formatted_body"].(string)
|
||||||
|
if !strings.Contains(fb, "<strong>bold</strong>") {
|
||||||
|
t.Fatalf("formatted_body missing bold: %q", fb)
|
||||||
|
}
|
||||||
|
if c["body"] != "Here is **bold**." {
|
||||||
|
t.Fatalf("plain body must be preserved, got %v", c["body"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildNoticeContentSkipsFormattedForPlain(t *testing.T) {
|
||||||
|
c := buildNoticeContent("$evt", "@u:vojo.chat", nil, "no markdown here")
|
||||||
|
if _, ok := c["format"]; ok {
|
||||||
|
t.Fatalf("format must be absent for plain text")
|
||||||
|
}
|
||||||
|
if _, ok := c["formatted_body"]; ok {
|
||||||
|
t.Fatalf("formatted_body must be absent for plain text")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMarkdownNoHangOnBangAndBackslash guards the inline-parser infinite loop: a
|
||||||
|
// '!' not starting an image, or a backslash not before ASCII punctuation
|
||||||
|
// (trailing, or before a letter/space/Cyrillic), used to fall through to a
|
||||||
|
// non-advancing default branch and spin forever — freezing the whole bot under
|
||||||
|
// the transaction mutex. These must all RETURN; if the bug returns this test
|
||||||
|
// hangs and `go test` times out instead of passing.
|
||||||
|
func TestMarkdownNoHangOnBangAndBackslash(t *testing.T) {
|
||||||
|
for _, in := range []string{
|
||||||
|
"Привет!",
|
||||||
|
"Hello! How are you?",
|
||||||
|
`path c:\users`,
|
||||||
|
`trailing backslash \`,
|
||||||
|
`что-то \ или вот это`,
|
||||||
|
`\` + "д",
|
||||||
|
"!",
|
||||||
|
"!!! wow",
|
||||||
|
"text with ! bang",
|
||||||
|
strings.Repeat("a! ", 2000),
|
||||||
|
strings.Repeat(`\`, 2000),
|
||||||
|
} {
|
||||||
|
_, _ = markdownToHTML(in) // a hang fails via test timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -142,6 +142,18 @@ func (c *MatrixClient) SetDisplayName(ctx context.Context, name string) error {
|
||||||
return c.do(ctx, http.MethodPut, path, nil, map[string]any{"displayname": name}, nil)
|
return c.do(ctx, http.MethodPut, path, nil, map[string]any{"displayname": name}, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SendTyping sets or clears the bot user's typing indicator in a room. The
|
||||||
|
// homeserver broadcasts m.typing, which clients render as "… is typing"; the
|
||||||
|
// timeout (ms) applies only when starting and is omitted when clearing.
|
||||||
|
func (c *MatrixClient) SendTyping(ctx context.Context, roomID string, typing bool, timeoutMs int) error {
|
||||||
|
path := "/_matrix/client/v3/rooms/" + url.PathEscape(roomID) + "/typing/" + url.PathEscape(c.asUserID)
|
||||||
|
body := map[string]any{"typing": typing}
|
||||||
|
if typing {
|
||||||
|
body["timeout"] = timeoutMs
|
||||||
|
}
|
||||||
|
return c.do(ctx, http.MethodPut, path, nil, body, nil)
|
||||||
|
}
|
||||||
|
|
||||||
// RoomEncrypted checks live encryption state (F15 — never a join-time snapshot).
|
// RoomEncrypted checks live encryption state (F15 — never a join-time snapshot).
|
||||||
// A 404/M_NOT_FOUND means no m.room.encryption state → unencrypted.
|
// A 404/M_NOT_FOUND means no m.room.encryption state → unencrypted.
|
||||||
func (c *MatrixClient) RoomEncrypted(ctx context.Context, roomID string) (bool, error) {
|
func (c *MatrixClient) RoomEncrypted(ctx context.Context, roomID string) (bool, error) {
|
||||||
|
|
@ -156,23 +168,34 @@ func (c *MatrixClient) RoomEncrypted(ctx context.Context, roomID string) (bool,
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// MemberCounts returns joined+invited counts via /members, used to classify a
|
// RoomMembership returns joined+invited counts and the set of homeservers that
|
||||||
// room as a 1:1 (F3) — appservice transactions carry no room summary.
|
// have a member present (joined or invited). Used both to classify a room as a
|
||||||
func (c *MatrixClient) MemberCounts(ctx context.Context, roomID string) (joined, invited int, err error) {
|
// 1:1 (F3) and to enforce that the bot only stays in rooms hosted entirely on
|
||||||
|
// allowed servers — appservice transactions carry no room summary, so this reads
|
||||||
|
// /members. The member is identified by the event's state_key (the sender is
|
||||||
|
// whoever *set* the membership, which may differ).
|
||||||
|
func (c *MatrixClient) RoomMembership(ctx context.Context, roomID string) (joined, invited int, servers map[string]bool, err error) {
|
||||||
path := "/_matrix/client/v3/rooms/" + url.PathEscape(roomID) + "/members"
|
path := "/_matrix/client/v3/rooms/" + url.PathEscape(roomID) + "/members"
|
||||||
var out struct {
|
var out struct {
|
||||||
Chunk []Event `json:"chunk"`
|
Chunk []Event `json:"chunk"`
|
||||||
}
|
}
|
||||||
if err = c.do(ctx, http.MethodGet, path, nil, nil, &out); err != nil {
|
if err = c.do(ctx, http.MethodGet, path, nil, nil, &out); err != nil {
|
||||||
return 0, 0, err
|
return 0, 0, nil, err
|
||||||
}
|
}
|
||||||
|
servers = make(map[string]bool)
|
||||||
for i := range out.Chunk {
|
for i := range out.Chunk {
|
||||||
switch out.Chunk[i].membershipOf() {
|
e := &out.Chunk[i]
|
||||||
|
if e.StateKey == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch e.membershipOf() {
|
||||||
case "join":
|
case "join":
|
||||||
joined++
|
joined++
|
||||||
|
servers[serverOf(*e.StateKey)] = true
|
||||||
case "invite":
|
case "invite":
|
||||||
invited++
|
invited++
|
||||||
|
servers[serverOf(*e.StateKey)] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return joined, invited, nil
|
return joined, invited, servers, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,4 +8,8 @@ const (
|
||||||
"Напишите мне в обычном (незашифрованном) чате."
|
"Напишите мне в обычном (незашифрованном) чате."
|
||||||
|
|
||||||
noticeDailyLimit = "Достигнут дневной лимит обращений к ИИ в этом сервисе. Попробуйте позже."
|
noticeDailyLimit = "Достигнут дневной лимит обращений к ИИ в этом сервисе. Попробуйте позже."
|
||||||
|
|
||||||
|
noticeUserLimit = "Вы исчерпали свой дневной лимит обращений к ИИ. Попробуйте позже."
|
||||||
|
|
||||||
|
noticeError = "⚠️ Не удалось получить ответ от ИИ. Попробуйте ещё раз чуть позже."
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,8 @@ func OpenStore(path string) (*Store, error) {
|
||||||
requests INTEGER NOT NULL DEFAULT 0, usd REAL NOT NULL DEFAULT 0,
|
requests INTEGER NOT NULL DEFAULT 0, usd REAL NOT NULL DEFAULT 0,
|
||||||
PRIMARY KEY (date, mxid)
|
PRIMARY KEY (date, mxid)
|
||||||
);
|
);
|
||||||
CREATE TABLE IF NOT EXISTS warned_encrypted (room_id TEXT PRIMARY KEY);`
|
CREATE TABLE IF NOT EXISTS warned_encrypted (room_id TEXT PRIMARY KEY);
|
||||||
|
CREATE TABLE IF NOT EXISTS processed_event (event_id TEXT PRIMARY KEY);`
|
||||||
if _, err := db.Exec(schema); err != nil {
|
if _, err := db.Exec(schema); err != nil {
|
||||||
db.Close()
|
db.Close()
|
||||||
return nil, fmt.Errorf("init schema: %w", err)
|
return nil, fmt.Errorf("init schema: %w", err)
|
||||||
|
|
@ -72,6 +73,24 @@ func (s *Store) MarkTxn(txnID string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SeenEvent records an event id as handled and reports whether it was NEW (true)
|
||||||
|
// or already seen (false) — the DURABLE equivalent of the in-memory dedup set, so
|
||||||
|
// a crash/restart between handling an event and acking its transaction can't make
|
||||||
|
// the bot reprocess it (dup answer + double-bill + cap inflation). Bounded to the
|
||||||
|
// most recent ids.
|
||||||
|
func (s *Store) SeenEvent(eventID string) (bool, error) {
|
||||||
|
res, err := s.db.Exec(`INSERT OR IGNORE INTO processed_event (event_id) VALUES (?)`, eventID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if n, _ := res.RowsAffected(); n == 0 {
|
||||||
|
return false, nil // already recorded → not new
|
||||||
|
}
|
||||||
|
_, err = s.db.Exec(`DELETE FROM processed_event WHERE rowid NOT IN
|
||||||
|
(SELECT rowid FROM processed_event ORDER BY rowid DESC LIMIT 20000)`)
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
|
||||||
// SpentTodayUSD sums all spend for the current UTC day.
|
// SpentTodayUSD sums all spend for the current UTC day.
|
||||||
func (s *Store) SpentTodayUSD() (float64, error) {
|
func (s *Store) SpentTodayUSD() (float64, error) {
|
||||||
var v sql.NullFloat64
|
var v sql.NullFloat64
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log/slog"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -18,14 +19,16 @@ type XAIClient struct {
|
||||||
key string
|
key string
|
||||||
http *http.Client
|
http *http.Client
|
||||||
maxTry int
|
maxTry int
|
||||||
|
log *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewXAIClient(base, key string) *XAIClient {
|
func NewXAIClient(base, key string, logger *slog.Logger) *XAIClient {
|
||||||
return &XAIClient{
|
return &XAIClient{
|
||||||
base: base,
|
base: base,
|
||||||
key: key,
|
key: key,
|
||||||
http: &http.Client{},
|
http: &http.Client{},
|
||||||
maxTry: 3,
|
maxTry: 3,
|
||||||
|
log: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -111,6 +114,9 @@ func (x *XAIClient) Complete(ctx context.Context, model string, msgs []xaiMessag
|
||||||
if !retryable {
|
if !retryable {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if x.log != nil {
|
||||||
|
x.log.Warn("xai attempt failed, will retry", "attempt", attempt+1, "max", x.maxTry, "err", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("xai: exhausted %d attempts: %w", x.maxTry, lastErr)
|
return nil, fmt.Errorf("xai: exhausted %d attempts: %w", x.maxTry, lastErr)
|
||||||
}
|
}
|
||||||
|
|
@ -148,9 +154,11 @@ func (x *XAIClient) attempt(ctx context.Context, payload []byte) (*xaiResponse,
|
||||||
if err := json.Unmarshal(data, &out); err != nil {
|
if err := json.Unmarshal(data, &out); err != nil {
|
||||||
return nil, false, fmt.Errorf("xai decode: %w", err)
|
return nil, false, fmt.Errorf("xai decode: %w", err)
|
||||||
}
|
}
|
||||||
if out.Text() == "" {
|
// A 2xx is a billed call even when the model returns empty content (content
|
||||||
return nil, false, fmt.Errorf("xai returned no choices")
|
// filter, finish_reason=length with no text, or no choices). Return it as a
|
||||||
}
|
// success so the caller books the real cost via Reconcile instead of refunding
|
||||||
|
// the slot and losing the spend — which would let empty replies bypass BOTH the
|
||||||
|
// per-user cap and the global ceiling. The caller just won't send an empty body.
|
||||||
return &out, false, nil
|
return &out, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue