151 lines
4.7 KiB
Go
151 lines
4.7 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// testDSN reports the connection string of the disposable Postgres database that
// the store-backed tests use, taken from AI_BOT_TEST_DATABASE_URL. An empty result
// means no test database is configured. In that case the store-backed tests skip
// instead of failing, so `go test ./...` still passes on a machine without
// Postgres, while the build/vet gates continue to cover the package.
func testDSN() string {
	const envVar = "AI_BOT_TEST_DATABASE_URL"
	return os.Getenv(envVar)
}
|
|
|
|
// openTestStore opens the store against the test database with a clean slate, so a
|
|
// shared/persistent test database doesn't leak rows between tests or runs. Skips the
|
|
// test when AI_BOT_TEST_DATABASE_URL is unset.
|
|
func openTestStore(t *testing.T) *Store {
|
|
t.Helper()
|
|
dsn := testDSN()
|
|
if dsn == "" {
|
|
t.Skip("set AI_BOT_TEST_DATABASE_URL (a throwaway Postgres) to run store-backed tests")
|
|
}
|
|
st, err := OpenStore(dsn)
|
|
if err != nil {
|
|
t.Fatalf("open store: %v", err)
|
|
}
|
|
ctx, cancel := opContext()
|
|
defer cancel()
|
|
if _, err := st.pool.Exec(ctx, `TRUNCATE processed_txn, processed_event, spend, warned_encrypted, request_log, grounding_count`); err != nil {
|
|
st.Close()
|
|
t.Fatalf("truncate test tables: %v", err)
|
|
}
|
|
return st
|
|
}
|
|
|
|
// newTestAS wires an AppService whose handler pushes each dispatched batch onto a
|
|
// channel. Transactions are now processed asynchronously (the 200 is returned before
|
|
// the handler runs), so tests read from the channel with a timeout instead of
|
|
// inspecting a slice immediately after the call.
|
|
func newTestAS(t *testing.T) (*AppService, *Store, chan []Event) {
|
|
t.Helper()
|
|
st := openTestStore(t)
|
|
dispatched := make(chan []Event, 8)
|
|
as := NewAppService(
|
|
&Config{HSToken: "secret", BotMXID: "@ai:vojo.chat"},
|
|
slog.New(slog.NewTextHandler(io.Discard, nil)),
|
|
st,
|
|
func(_ context.Context, ev []Event) { dispatched <- ev },
|
|
)
|
|
as.baseCtx = context.Background()
|
|
return as, st, dispatched
|
|
}
|
|
|
|
// waitDispatch returns the next dispatched batch, or (nil,false) if none arrives
|
|
// within the timeout.
|
|
func waitDispatch(ch chan []Event, timeout time.Duration) ([]Event, bool) {
|
|
select {
|
|
case ev := <-ch:
|
|
return ev, true
|
|
case <-time.After(timeout):
|
|
return nil, false
|
|
}
|
|
}
|
|
|
|
func txnReq(txnID, auth, body string) *http.Request {
|
|
r := httptest.NewRequest(http.MethodPut, "/_matrix/app/v1/transactions/"+txnID, strings.NewReader(body))
|
|
r.SetPathValue("txnId", txnID)
|
|
if auth != "" {
|
|
r.Header.Set("Authorization", "Bearer "+auth)
|
|
}
|
|
return r
|
|
}
|
|
|
|
func TestTransactionAuthAndIdempotency(t *testing.T) {
|
|
as, st, dispatched := newTestAS(t)
|
|
defer st.Close()
|
|
body := `{"events":[{"type":"m.room.message","room_id":"!r:vojo.chat","event_id":"$1","sender":"@u:vojo.chat"}]}`
|
|
|
|
// Bad hs_token → 403, nothing dispatched.
|
|
w := httptest.NewRecorder()
|
|
as.handleTransaction(w, txnReq("txn1", "wrong", body))
|
|
if w.Code != http.StatusForbidden {
|
|
t.Fatalf("bad token: got %d, want 403", w.Code)
|
|
}
|
|
if _, ok := waitDispatch(dispatched, 100*time.Millisecond); ok {
|
|
t.Fatalf("bad token must not dispatch")
|
|
}
|
|
|
|
// Good hs_token → 200, one batch dispatched (asynchronously).
|
|
w = httptest.NewRecorder()
|
|
as.handleTransaction(w, txnReq("txn1", "secret", body))
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("good token: got %d, want 200", w.Code)
|
|
}
|
|
batch, ok := waitDispatch(dispatched, time.Second)
|
|
if !ok || len(batch) != 1 {
|
|
t.Fatalf("expected one dispatched batch of one event, got %v ok=%v", batch, ok)
|
|
}
|
|
|
|
// Same txnId again → idempotent no-op (still 200, no re-dispatch).
|
|
w = httptest.NewRecorder()
|
|
as.handleTransaction(w, txnReq("txn1", "secret", body))
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("retry: got %d, want 200", w.Code)
|
|
}
|
|
if _, ok := waitDispatch(dispatched, 100*time.Millisecond); ok {
|
|
t.Fatalf("retried transaction must not re-dispatch")
|
|
}
|
|
}
|
|
|
|
func TestTransactionLegacyQueryTokenAccepted(t *testing.T) {
|
|
as, st, _ := newTestAS(t)
|
|
defer st.Close()
|
|
|
|
r := httptest.NewRequest(http.MethodPut, "/transactions/txnX?access_token=secret", strings.NewReader(`{"events":[]}`))
|
|
r.SetPathValue("txnId", "txnX")
|
|
w := httptest.NewRecorder()
|
|
as.handleTransaction(w, r)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("legacy access_token query: got %d, want 200", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestUserQuery(t *testing.T) {
|
|
as, st, _ := newTestAS(t)
|
|
defer st.Close()
|
|
|
|
mk := func(uid string) *http.Request {
|
|
r := httptest.NewRequest(http.MethodGet, "/_matrix/app/v1/users/"+uid, nil)
|
|
r.SetPathValue("userId", uid)
|
|
r.Header.Set("Authorization", "Bearer secret")
|
|
return r
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
as.handleUserQuery(w, mk("@ai:vojo.chat"))
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("own user: got %d, want 200", w.Code)
|
|
}
|
|
w = httptest.NewRecorder()
|
|
as.handleUserQuery(w, mk("@someone:vojo.chat"))
|
|
if w.Code != http.StatusNotFound {
|
|
t.Fatalf("foreign user: got %d, want 404", w.Code)
|
|
}
|
|
}
|