package core

import (
	"database/sql"
	"errors"
	"strings"
	"testing"

	_ "github.com/mattn/go-sqlite3"
)

// FailureDb wraps a DB and lets tests inject errors into Query and Exec.
//
// Before delegating, Query calls queryError and Exec calls execError (when
// set) with the statement and its arguments; a non-nil result is returned to
// the caller and the wrapped DB is not touched. QueryRow and Prepare are
// passed straight through without injection. Transactions started through
// Transact are wrapped so the same hooks apply inside them.
type FailureDb struct {
	db         DB
	queryError func(string, ...any) error
	execError  func(string, ...any) error
}

// Compile-time check that *FailureDb can be used wherever a DB is expected.
var _ DB = (*FailureDb)(nil)

// Query returns the error produced by queryError, if any; otherwise it
// delegates to the wrapped DB.
func (f *FailureDb) Query(query string, args ...any) (*sql.Rows, error) {
	if f.queryError != nil {
		if err := f.queryError(query, args...); err != nil {
			return nil, err
		}
	}
	return f.db.Query(query, args...)
}

// QueryRow delegates to the wrapped DB without error injection; a *sql.Row
// carrying an error cannot be constructed outside database/sql.
func (f *FailureDb) QueryRow(query string, args ...any) *sql.Row {
	return f.db.QueryRow(query, args...)
}

// Exec returns the error produced by execError, if any; otherwise it
// delegates to the wrapped DB.
func (f *FailureDb) Exec(query string, args ...any) (sql.Result, error) {
	if f.execError != nil {
		if err := f.execError(query, args...); err != nil {
			return nil, err
		}
	}
	return f.db.Exec(query, args...)
}

// Prepare delegates to the wrapped DB without error injection.
func (f *FailureDb) Prepare(query string) (*sql.Stmt, error) {
	return f.db.Prepare(query)
}

// Transact runs txFunc in a transaction of the wrapped DB, handing it a
// FailureDb around the transaction handle so the same injection hooks apply.
func (f *FailureDb) Transact(txFunc func(DB) error) error {
	return f.db.Transact(func(tx DB) error {
		return txFunc(&FailureDb{
			db:         tx,
			queryError: f.queryError,
			execError:  f.execError,
		})
	})
}

// TestFailureDb checks that FailureDb fails only the statements its hooks
// select, both directly and inside a transaction.
func TestFailureDb(t *testing.T) {
	db := EphemeralDb(t)
	injected := errors.New("oopsie")
	failOn2 := func(q string, _ ...any) error {
		if strings.Contains(q, "2") {
			return injected
		}
		return nil
	}
	fdb := FailureDb{
		db:         db,
		queryError: failOn2,
		execError:  failOn2,
	}

	rows, err := fdb.Query("select 1")
	if err != nil {
		t.Fatal(err)
	}
	// Release the result set so its connection is not held for the rest of
	// the test.
	rows.Close()

	// FailureDb returns the hook's error unwrapped, so it can be matched exactly.
	if _, err := fdb.Query("select 2"); !errors.Is(err, injected) {
		t.Fatalf("expected injected query error, got %v", err)
	}
	if _, err := fdb.Exec("select 2"); !errors.Is(err, injected) {
		t.Fatalf("expected injected exec error, got %v", err)
	}

	err = fdb.Transact(func(tx DB) error {
		okRows, qerr := tx.Query("select 1")
		if qerr != nil {
			t.Fatal(qerr)
		}
		okRows.Close()

		badRows, qerr := tx.Query("select 2")
		if qerr == nil {
			// Injection did not fire; close the rows and let the outer
			// check report the failure.
			badRows.Close()
		}
		return qerr
	})
	if err == nil {
		t.Fatal("expected error from inside transaction")
	}
}

// TestDeleteSourceCascade checks that deleting a source removes its items and
// that items can no longer be added for the deleted source.
func TestDeleteSourceCascade(t *testing.T) {
	db := EphemeralDb(t)
	if err := AddSource(db, "source1"); err != nil {
		t.Fatalf("failed to add source1: %v", err)
	}
	if err := AddSource(db, "source2"); err != nil {
		t.Fatalf("failed to add source2: %v", err)
	}
	if err := AddItems(db, []Item{
		{Source: "source1", Id: "item1"},
		{Source: "source2", Id: "item2"},
	}); err != nil {
		t.Fatalf("failed to add items: %v", err)
	}

	items, err := GetAllActiveItems(db, 0, -1)
	if err != nil {
		t.Fatalf("failed to get active items: %v", err)
	}
	if len(items) != 2 {
		t.Fatalf("Expected 2 items, got %d", len(items))
	}

	if err := DeleteSource(db, "source1"); err != nil {
		t.Fatal(err)
	}

	items, err = GetAllActiveItems(db, 0, -1)
	if err != nil {
		t.Fatal(err)
	}
	if len(items) != 1 {
		t.Fatalf("Expected only 1 item after source delete, got %d", len(items))
	}

	if err := AddItems(db, []Item{{Source: "source1", Id: "item3"}}); err == nil {
		t.Fatal("Unexpected success adding item for nonexistent source")
	}
}

// TestTransaction checks that Transact commits every statement when the
// callback succeeds and rolls them all back when it returns an error.
func TestTransaction(t *testing.T) {
	db := EphemeralDb(t)
	if _, err := db.Exec("create table planets (name text) strict"); err != nil {
		t.Fatal(err)
	}

	// A transaction that should succeed.
	err := db.Transact(func(tx DB) error {
		if _, err := tx.Exec("insert into planets (name) values (?)", "mercury"); err != nil {
			t.Fatal(err)
		}
		if _, err := tx.Exec("insert into planets (name) values (?)", "venus"); err != nil {
			t.Fatal(err)
		}
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}

	// Check both rows were committed.
	found := planetNames(t, db)
	if !found["mercury"] || !found["venus"] {
		t.Fatal("transaction failed to insert rows")
	}

	// A transaction that should fail: the second insert supplies two values
	// for the one-column planets table.
	err = db.Transact(func(tx DB) error {
		if _, err := tx.Exec("insert into planets (name) values (?)", "earth"); err != nil {
			t.Fatal(err)
		}
		_, err := tx.Exec("insert into planets (name) values (?, ?)", "moon", "surprise asteroid!")
		return err
	})
	if err == nil {
		t.Fatal("expected error")
	}

	// The earth insert must have been rolled back, while the rows committed
	// by the first transaction must still be present.
	found = planetNames(t, db)
	if found["earth"] {
		t.Fatal("transaction failed to roll back insert")
	}
	if !found["mercury"] || !found["venus"] {
		t.Fatal("failed transaction removed previously committed rows")
	}
}

// planetNames returns the set of names currently stored in the planets table,
// failing the test on any query or scan error.
func planetNames(t *testing.T, db DB) map[string]bool {
	t.Helper()
	rows, err := db.Query("select name from planets")
	if err != nil {
		t.Fatal(err)
	}
	defer rows.Close()

	found := map[string]bool{}
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			t.Fatal(err)
		}
		found[name] = true
	}
	// Next returning false can mean an iteration error rather than the end.
	if err := rows.Err(); err != nil {
		t.Fatal(err)
	}
	return found
}