338 lines
10 KiB
Go
338 lines
10 KiB
Go
package main
|
|
|
|
import (
	"context"
	"fmt"
	"log"
	"sync"
)
|
|
|
|
// roomMeta is the per-room classification needed to handle a message: member
// counts (for the 1:1 test, F3) and encryption state (F15). It lives only in
// memory and is rebuilt per process. Each half is fetched lazily from the
// CS-API the first time it is needed, because appservice transactions carry
// no room summary.
type roomMeta struct {
	// Member counts; meaningful only once countsKnown is set.
	joined      int
	invited     int
	countsKnown bool

	// Encryption state; meaningful only once encKnown is set.
	encrypted bool
	encKnown  bool
}

// isDM reports whether the room is known to be 1:1, i.e. joined plus invited
// members total exactly two. A room whose counts have not been fetched is
// never treated as a DM.
func (m *roomMeta) isDM() bool {
	if !m.countsKnown {
		return false
	}
	return m.joined+m.invited == 2
}
|
|
|
|
type Bot struct {
|
|
cfg *Config
|
|
log *log.Logger
|
|
mx *MatrixClient
|
|
xai *XAIClient
|
|
st *Store
|
|
|
|
// Transactions are delivered one at a time by Synapse, but guard the shared
|
|
// maps/sets anyway so an unexpected concurrent call can't corrupt them.
|
|
mu sync.Mutex
|
|
seen *lruSet // event ids already handled (dedup within a session)
|
|
botSent *lruSet // event ids the bot itself sent (reply-parent detection)
|
|
meta map[string]*roomMeta
|
|
buf map[string][]bufferedMsg
|
|
globalNote map[string]string // roomID → UTC date we last sent the daily-limit notice
|
|
}
|
|
|
|
// NewBot wires up the Matrix and xAI clients, opens the store at
// cfg.statePath("ai-bot.db"), and confirms the appservice token authenticates
// as cfg.BotMXID before returning the Bot. If identity verification fails, the
// store is closed and that error is returned. Setting the display name is
// best-effort: a failure is only logged.
func NewBot(ctx context.Context, cfg *Config, logger *log.Logger) (*Bot, error) {
	mx := NewMatrixClient(cfg.HomeserverURL, cfg.ASToken, cfg.BotMXID)
	xai := NewXAIClient(cfg.XAIBaseURL, cfg.XAIAPIKey)

	st, err := OpenStore(cfg.statePath("ai-bot.db"))
	if err != nil {
		return nil, err
	}

	b := &Bot{
		cfg:        cfg,
		log:        logger,
		mx:         mx,
		xai:        xai,
		st:         st,
		seen:       newLRUSet(5000),
		botSent:    newLRUSet(5000),
		meta:       make(map[string]*roomMeta),
		buf:        make(map[string][]bufferedMsg),
		globalNote: make(map[string]string),
	}

	// Confirm the as_token + user_id resolves to BOT_MXID before serving.
	if err := b.verifyIdentity(ctx); err != nil {
		// The verification error is the one worth reporting; a close failure
		// here would only mask it.
		_ = st.Close()
		return nil, err
	}

	// F23: ensure the profile has a display name (best-effort, idempotent).
	// Skip this when no name is configured: sending an empty displayname would
	// likely clear the profile's name rather than ensure one is set.
	if cfg.BotDisplayName != "" {
		if err := mx.SetDisplayName(ctx, cfg.BotDisplayName); err != nil {
			logger.Printf("set display name failed (non-fatal): %v", err)
		}
	}
	return b, nil
}
|
|
|
|
func (b *Bot) Close() {
|
|
if b.st != nil {
|
|
_ = b.st.Close()
|
|
}
|
|
}
|
|
|
|
// verifyIdentity confirms via whoami that the appservice token authenticates
// as the configured BOT_MXID. On a failed lookup or a mismatch it returns an
// error instead of exiting the process, so NewBot can close the store and the
// caller decides how to shut down.
func (b *Bot) verifyIdentity(ctx context.Context) error {
	who, err := b.mx.Whoami(ctx)
	if err != nil {
		return fmt.Errorf("verifying appservice identity: %w", err)
	}
	if who != b.cfg.BotMXID {
		return fmt.Errorf("as_token resolves to %q but BOT_MXID is %q", who, b.cfg.BotMXID)
	}
	b.log.Printf("authenticated as %s", who)
	return nil
}
|
|
|
|
// Run serves appservice transactions, dispatching each one to
// handleTransaction, and blocks until ctx is cancelled. It returns whatever
// the server's Serve returns.
func (b *Bot) Run(ctx context.Context) error {
	server := NewAppService(b.cfg, b.log, b.st, b.handleTransaction)
	return server.Serve(ctx)
}
|
|
|
|
// handleTransaction handles the events of one pushed transaction in order. It
// holds b.mu for the whole batch, so per-room state is never touched by two
// transactions at once.
func (b *Bot) handleTransaction(ctx context.Context, events []Event) {
	b.mu.Lock()
	defer b.mu.Unlock()

	for idx := 0; idx < len(events); idx++ {
		b.handleEvent(ctx, &events[idx])
	}
}
|
|
|
|
// handleEvent dispatches one event from a transaction. It ignores events that
// lack an event or room ID, and events already handled in this session.
//
//   - m.room.member: cached member counts for the room are invalidated, and
//     changes to the bot's own membership go to handleSelfMembership.
//   - m.room.encryption: the room is marked as encrypted.
//   - m.room.message: handled by handleMessage.
func (b *Bot) handleEvent(ctx context.Context, ev *Event) {
	if ev.EventID == "" || ev.RoomID == "" {
		return
	}
	if !b.seen.Add(ev.EventID) {
		return // already handled in this session
	}

	switch ev.Type {
	case "m.room.member":
		// isDM (F3) uses cached joined/invited counts, and ensureCounts only
		// fetches them while unknown. Without this, a 1:1 room that later
		// gained members would keep being treated as a DM, getting a reply to
		// every message. Forget the counts so the next message re-probes them.
		if m := b.meta[ev.RoomID]; m != nil {
			m.countsKnown = false
		}
		if ev.StateKey != nil && *ev.StateKey == b.cfg.BotMXID {
			b.handleSelfMembership(ctx, ev)
		}
	case "m.room.encryption":
		m := b.getMeta(ev.RoomID)
		m.encrypted, m.encKnown = true, true
	case "m.room.message":
		b.handleMessage(ctx, ev)
	}
}
|
|
|
|
// handleSelfMembership reacts to membership changes of the bot user itself.
// Invites from allowed servers (F11) are joined. Other invites are rejected by
// leaving the room. On leave/ban, every cached per-room entry for the room is
// dropped; this is local state only, not a server-side forget.
func (b *Bot) handleSelfMembership(ctx context.Context, ev *Event) {
	switch ev.membershipOf() {
	case "invite":
		if !b.cfg.AllowedServers[serverOf(ev.Sender)] {
			b.log.Printf("rejecting invite to %s from %q (server not allowed)", ev.RoomID, ev.Sender)
			if err := b.mx.LeaveRoom(ctx, ev.RoomID); err != nil {
				b.log.Printf("leave (reject) %s failed: %v", ev.RoomID, err)
			}
			return
		}
		b.log.Printf("accepting invite to %s from %s", ev.RoomID, ev.Sender)
		if err := b.mx.JoinRoom(ctx, ev.RoomID); err != nil {
			b.log.Printf("join %s failed: %v", ev.RoomID, err)
		}
	case "leave", "ban":
		// Clear every per-room map, including the daily-limit notice date, so
		// nothing lingers for a room the bot has left; a rejoin starts fresh.
		delete(b.meta, ev.RoomID)
		delete(b.buf, ev.RoomID)
		delete(b.globalNote, ev.RoomID)
	}
}
|
|
|
|
// handleMessage processes one m.room.message event.
//
// It skips messages in encrypted rooms after a one-time warning, ignores edits,
// and buffers messages as room context. It responds only to non-notice messages
// from other users that address the bot: the room is 1:1, or mentionsBot
// reports a mention or a reply to one of the bot's messages.
func (b *Bot) handleMessage(ctx context.Context, ev *Event) {
	roomID := ev.RoomID
	m := b.getMeta(roomID)

	// A9/F15: re-check encryption; if (or once) encrypted, warn once and skip —
	// the bot can't read it.
	b.ensureEncryption(ctx, roomID, m)
	if m.encrypted {
		b.warnEncryptedOnce(ctx, roomID)
		return
	}

	mc, ok := ev.DecodeMessage()
	if !ok {
		return
	}
	// Edits re-carry m.mentions; never re-trigger or replay them (F16).
	if mc.IsReplace() {
		return
	}

	if ev.Sender == b.cfg.BotMXID {
		// Our own message. Replies sent through sendNotice were already
		// appended to the buffer (and recorded in botSent) at send time. If
		// the homeserver pushes that event back to us, buffering it again
		// would duplicate the assistant turn in later context. So only buffer
		// bot messages not already tracked.
		if !b.botSent.Has(ev.EventID) {
			b.appendBuf(roomID, bufferedMsg{sender: ev.Sender, body: mc.Body, isBot: true})
		}
		return
	}

	// Snapshot prior context BEFORE buffering this message so buildContext
	// sees history only.
	history := b.buf[roomID]
	b.appendBuf(roomID, bufferedMsg{sender: ev.Sender, body: mc.Body, isBot: false})

	if mc.MsgType == "m.notice" {
		return // anti-loop: ignore notices from other bots
	}

	b.ensureCounts(ctx, roomID, m)
	replyParentIsBot := mc.RelatesTo != nil && mc.RelatesTo.InReplyTo != nil &&
		b.botSent.Has(mc.RelatesTo.InReplyTo.EventID)

	if !(m.isDM() || mentionsBot(mc, b.cfg.BotMXID, replyParentIsBot)) {
		return
	}
	b.respond(ctx, roomID, m, ev, mc, history)
}
|
|
|
|
// respond runs one paid completion for a message that addressed the bot.
//
// It first reserves against the sender's daily cap and the global daily USD
// ceiling. A per-user denial is dropped silently (F24). A global denial posts
// the daily-limit notice at most once per room per day. Otherwise the reply is
// generated from the buffered history, its actual cost is reconciled from the
// reported token usage, and it is posted as a notice. If the completion fails,
// the reserved request is refunded.
func (b *Bot) respond(ctx context.Context, roomID string, m *roomMeta, ev *Event, mc *MessageContent, history []bufferedMsg) {
	res, err := b.st.Reserve(ev.Sender, b.cfg.PerUserDailyCap, b.cfg.DailyUSDCeiling)
	if err != nil {
		b.log.Printf("limiter reserve failed: %v", err)
		return
	}
	if res == reserveDeniedUser {
		// The per-user cap is anti-abuse (F24): drop without any notice.
		return
	}
	if res == reserveDeniedGlobal {
		// Global USD ceiling: tell the room once per day, then stay quiet.
		if b.globalNote[roomID] != todayUTC() {
			b.globalNote[roomID] = todayUTC()
			b.sendNotice(ctx, roomID, ev, mc, noticeDailyLimit)
		}
		return
	}

	prompt := buildContext(b.cfg.SystemPrompt, history, m.isDM(), mc.Body, b.cfg.MaxCtxEvent, 8000)
	resp, err := b.xai.Complete(ctx, b.cfg.XAIModel, prompt, b.cfg.MaxOutTok, b.cfg.XAITemp)
	if err != nil {
		// Complete already retried transient failures (at-most-once); give the
		// request back so an xAI outage doesn't burn the user's daily cap.
		b.log.Printf("xai completion failed for %s: %v", ev.Sender, err)
		if refundErr := b.st.RefundRequest(ev.Sender); refundErr != nil {
			b.log.Printf("refund failed: %v", refundErr)
		}
		return
	}

	if err := b.st.Reconcile(ev.Sender, computeUSD(resp.Usage, b.cfg)); err != nil {
		b.log.Printf("reconcile spend failed: %v", err)
	}
	b.sendNotice(ctx, roomID, ev, mc, resp.Text())
}
|
|
|
|
// computeUSD prices the call from the API-returned token usage (authoritative
|
|
// counts) and the configured per-1M prices — so the hard ceiling tracks real
|
|
// usage even if the model/price changes (only the constants need updating).
|
|
func computeUSD(u xaiUsage, cfg *Config) float64 {
|
|
cached := u.PromptTokensDetails.CachedTokens
|
|
nonCached := u.PromptTokens - cached
|
|
if nonCached < 0 {
|
|
nonCached = 0
|
|
}
|
|
return float64(nonCached)/1e6*cfg.PriceInputPerM +
|
|
float64(cached)/1e6*cfg.PriceCachedPerM +
|
|
float64(u.CompletionTokens)/1e6*cfg.PriceOutputPerM
|
|
}
|
|
|
|
// sendNotice posts body into roomID as an m.notice reply to trigger, threaded
// when the trigger was in a thread.
//
// On success, the new event ID is recorded in botSent, so a later reply to it
// counts as addressing the bot. body is also appended to the room buffer as an
// assistant turn. A send failure is logged and nothing is recorded.
func (b *Bot) sendNotice(ctx context.Context, roomID string, trigger *Event, triggerMC *MessageContent, body string) {
	payload := buildNoticeContent(trigger.EventID, trigger.Sender, triggerMC.RelatesTo, body)
	sentID, err := b.mx.SendEvent(ctx, roomID, "m.room.message", payload)
	if err != nil {
		b.log.Printf("send notice to %s failed: %v", roomID, err)
		return
	}
	b.botSent.Add(sentID)
	b.appendBuf(roomID, bufferedMsg{sender: b.cfg.BotMXID, body: body, isBot: true})
}
|
|
|
|
// warnEncryptedOnce posts an m.notice saying encrypted rooms are unsupported,
// unless the store already records that this room was warned.
//
// The flag is persisted only after the notice was sent successfully, so a
// failed send is retried on a later message. Store and send errors are logged.
func (b *Bot) warnEncryptedOnce(ctx context.Context, roomID string) {
	alreadyWarned, err := b.st.HasWarnedEncrypted(roomID)
	switch {
	case err != nil:
		b.log.Printf("warned-flag read failed: %v", err)
		return
	case alreadyWarned:
		return
	}

	notice := map[string]any{"msgtype": "m.notice", "body": noticeEncryptedUnsupported}
	if _, err := b.mx.SendEvent(ctx, roomID, "m.room.message", notice); err != nil {
		b.log.Printf("encrypted-notice to %s failed: %v", roomID, err)
		return
	}
	if err := b.st.SetWarnedEncrypted(roomID); err != nil {
		b.log.Printf("persist warned-flag failed: %v", err)
	}
}
|
|
|
|
// buildNoticeContent builds the reply. m.notice (not m.text) so the anti-loop
|
|
// skip catches our own output. Thread-aware (F27): a trigger from a thread gets a
|
|
// thread relation so the answer lands in the thread, not the main timeline.
|
|
func buildNoticeContent(replyTo, sender string, triggerRelates *RelatesTo, body string) map[string]any {
|
|
relates := map[string]any{}
|
|
if triggerRelates != nil && triggerRelates.RelType == "m.thread" && triggerRelates.EventID != "" {
|
|
relates["rel_type"] = "m.thread"
|
|
relates["event_id"] = triggerRelates.EventID
|
|
relates["is_falling_back"] = true
|
|
relates["m.in_reply_to"] = map[string]any{"event_id": replyTo}
|
|
} else {
|
|
relates["m.in_reply_to"] = map[string]any{"event_id": replyTo}
|
|
}
|
|
return map[string]any{
|
|
"msgtype": "m.notice",
|
|
"body": body,
|
|
"m.mentions": map[string]any{"user_ids": []string{sender}},
|
|
"m.relates_to": relates,
|
|
}
|
|
}
|
|
|
|
// --- per-room metadata helpers -------------------------------------------------
|
|
|
|
func (b *Bot) getMeta(roomID string) *roomMeta {
|
|
m := b.meta[roomID]
|
|
if m == nil {
|
|
m = &roomMeta{}
|
|
b.meta[roomID] = m
|
|
}
|
|
return m
|
|
}
|
|
|
|
// ensureEncryption fills m's encryption state from the homeserver if it is not
// yet known. A failed probe is logged and leaves the state unknown, so the next
// message probes again.
func (b *Bot) ensureEncryption(ctx context.Context, roomID string, m *roomMeta) {
	if m.encKnown {
		return
	}
	isEnc, err := b.mx.RoomEncrypted(ctx, roomID)
	if err != nil {
		b.log.Printf("encryption probe %s failed: %v", roomID, err)
		return
	}
	m.encrypted = isEnc
	m.encKnown = true
}
|
|
|
|
// ensureCounts fills m's joined/invited member counts from the homeserver if
// they are not yet known. A failed probe is logged and leaves the counts
// unknown, so isDM reports false and the next message probes again.
func (b *Bot) ensureCounts(ctx context.Context, roomID string, m *roomMeta) {
	if m.countsKnown {
		return
	}
	nJoined, nInvited, err := b.mx.MemberCounts(ctx, roomID)
	if err != nil {
		b.log.Printf("member-count probe %s failed: %v", roomID, err)
		return
	}
	m.joined = nJoined
	m.invited = nInvited
	m.countsKnown = true
}
|
|
|
|
func (b *Bot) appendBuf(roomID string, msg bufferedMsg) {
|
|
limit := b.cfg.MaxCtxEvent * 2
|
|
if limit < 8 {
|
|
limit = 8
|
|
}
|
|
buf := append(b.buf[roomID], msg)
|
|
if len(buf) > limit {
|
|
buf = buf[len(buf)-limit:]
|
|
}
|
|
b.buf[roomID] = buf
|
|
}
|