package main import ( "context" "io" "log/slog" "net/http" "net/http/httptest" "os" "strings" "testing" "time" ) // testDSN is the throwaway Postgres the store-backed tests run against. When unset, // those tests skip rather than fail, so `go test ./...` stays green on a machine // without a Postgres (the build/vet gates still cover the package). func testDSN() string { return os.Getenv("AI_BOT_TEST_DATABASE_URL") } // openTestStore opens the store against the test database with a clean slate, so a // shared/persistent test database doesn't leak rows between tests or runs. Skips the // test when AI_BOT_TEST_DATABASE_URL is unset. func openTestStore(t *testing.T) *Store { t.Helper() dsn := testDSN() if dsn == "" { t.Skip("set AI_BOT_TEST_DATABASE_URL (a throwaway Postgres) to run store-backed tests") } st, err := OpenStore(dsn) if err != nil { t.Fatalf("open store: %v", err) } ctx, cancel := opContext() defer cancel() if _, err := st.pool.Exec(ctx, `TRUNCATE processed_txn, processed_event, spend, warned_encrypted`); err != nil { st.Close() t.Fatalf("truncate test tables: %v", err) } return st } // newTestAS wires an AppService whose handler pushes each dispatched batch onto a // channel. Transactions are now processed asynchronously (the 200 is returned before // the handler runs), so tests read from the channel with a timeout instead of // inspecting a slice immediately after the call. func newTestAS(t *testing.T) (*AppService, *Store, chan []Event) { t.Helper() st := openTestStore(t) dispatched := make(chan []Event, 8) as := NewAppService( &Config{HSToken: "secret", BotMXID: "@ai:vojo.chat"}, slog.New(slog.NewTextHandler(io.Discard, nil)), st, func(_ context.Context, ev []Event) { dispatched <- ev }, ) as.baseCtx = context.Background() return as, st, dispatched } // waitDispatch returns the next dispatched batch, or (nil,false) if none arrives // within the timeout. func waitDispatch(ch chan []Event, timeout time.Duration) ([]Event, bool) { select { case ev := <-ch: return ev, true case <-time.After(timeout): return nil, false } } func txnReq(txnID, auth, body string) *http.Request { r := httptest.NewRequest(http.MethodPut, "/_matrix/app/v1/transactions/"+txnID, strings.NewReader(body)) r.SetPathValue("txnId", txnID) if auth != "" { r.Header.Set("Authorization", "Bearer "+auth) } return r } func TestTransactionAuthAndIdempotency(t *testing.T) { as, st, dispatched := newTestAS(t) defer st.Close() body := `{"events":[{"type":"m.room.message","room_id":"!r:vojo.chat","event_id":"$1","sender":"@u:vojo.chat"}]}` // Bad hs_token → 403, nothing dispatched. w := httptest.NewRecorder() as.handleTransaction(w, txnReq("txn1", "wrong", body)) if w.Code != http.StatusForbidden { t.Fatalf("bad token: got %d, want 403", w.Code) } if _, ok := waitDispatch(dispatched, 100*time.Millisecond); ok { t.Fatalf("bad token must not dispatch") } // Good hs_token → 200, one batch dispatched (asynchronously). w = httptest.NewRecorder() as.handleTransaction(w, txnReq("txn1", "secret", body)) if w.Code != http.StatusOK { t.Fatalf("good token: got %d, want 200", w.Code) } batch, ok := waitDispatch(dispatched, time.Second) if !ok || len(batch) != 1 { t.Fatalf("expected one dispatched batch of one event, got %v ok=%v", batch, ok) } // Same txnId again → idempotent no-op (still 200, no re-dispatch). w = httptest.NewRecorder() as.handleTransaction(w, txnReq("txn1", "secret", body)) if w.Code != http.StatusOK { t.Fatalf("retry: got %d, want 200", w.Code) } if _, ok := waitDispatch(dispatched, 100*time.Millisecond); ok { t.Fatalf("retried transaction must not re-dispatch") } } func TestTransactionLegacyQueryTokenAccepted(t *testing.T) { as, st, _ := newTestAS(t) defer st.Close() r := httptest.NewRequest(http.MethodPut, "/transactions/txnX?access_token=secret", strings.NewReader(`{"events":[]}`)) r.SetPathValue("txnId", "txnX") w := httptest.NewRecorder() as.handleTransaction(w, r) if w.Code != http.StatusOK { t.Fatalf("legacy access_token query: got %d, want 200", w.Code) } } func TestUserQuery(t *testing.T) { as, st, _ := newTestAS(t) defer st.Close() mk := func(uid string) *http.Request { r := httptest.NewRequest(http.MethodGet, "/_matrix/app/v1/users/"+uid, nil) r.SetPathValue("userId", uid) r.Header.Set("Authorization", "Bearer secret") return r } w := httptest.NewRecorder() as.handleUserQuery(w, mk("@ai:vojo.chat")) if w.Code != http.StatusOK { t.Fatalf("own user: got %d, want 200", w.Code) } w = httptest.NewRecorder() as.handleUserQuery(w, mk("@someone:vojo.chat")) if w.Code != http.StatusNotFound { t.Fatalf("foreign user: got %d, want 404", w.Code) } }