From 2d7d48846dcb0ddd6da6d6e6ae3338e357be8ed0 Mon Sep 17 00:00:00 2001 From: Tim Van Baak Date: Thu, 23 Jan 2025 10:03:46 -0800 Subject: [PATCH] Add transaction utility --- core/db.go | 19 ++++++++++++++ core/db_test.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/core/db.go b/core/db.go index ccc9e33..b5a1aa6 100644 --- a/core/db.go +++ b/core/db.go @@ -24,6 +24,25 @@ func (db *DB) Exec(query string, args ...any) (sql.Result, error) { return db.rw.Exec(query, args...) } +func (db *DB) Transact(transaction func(*sql.Tx) error) error { + tx, err := db.rw.Begin() + if err != nil { + return err + } + defer tx.Rollback() + _, err = tx.Exec("rollback; begin immediate") + if err != nil { + return err + } + if err = transaction(tx); err != nil { + return err + } + if err = tx.Commit(); err != nil { + return err + } + return nil +} + func defaultPragma(db *sql.DB) (sql.Result, error) { return db.Exec(` pragma journal_mode = WAL; diff --git a/core/db_test.go b/core/db_test.go index 9d376a8..d49302a 100644 --- a/core/db_test.go +++ b/core/db_test.go @@ -1,6 +1,7 @@ package core import ( + "database/sql" "testing" _ "github.com/mattn/go-sqlite3" @@ -46,3 +47,70 @@ func TestDeleteSourceCascade(t *testing.T) { t.Fatal("Unexpected success adding item for nonexistent source") } } + +func TestTransaction(t *testing.T) { + db := EphemeralDb(t) + if _, err := db.Exec("create table planets (name text) strict"); err != nil { + t.Fatal(err) + } + + // A transaction that should succeed + err := db.Transact(func(tx *sql.Tx) error { + if _, err := tx.Exec("insert into planets (name) values (?)", "mercury"); err != nil { + t.Fatal(err) + } + if _, err := tx.Exec("insert into planets (name) values (?)", "venus"); err != nil { + t.Fatal(err) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + + // Check both rows were inserted + rows, err := db.Query("select name from planets") + if err != nil { + t.Fatal(err) + } + found := map[string]bool{} + for rows.Next() { + var name string + if err = rows.Scan(&name); err != nil { + t.Fatal(err) + } + found[name] = true + } + if !found["mercury"] || !found["venus"] { + t.Fatal("transaction failed to insert rows") + } + + // A transaction that should fail + err = db.Transact(func(tx *sql.Tx) error { + if _, err := tx.Exec("insert into planets (name) values (?)", "earth"); err != nil { + t.Fatal(err) + } + _, err := tx.Exec("insert into planets (name) values (?, ?)", "moon", "surprise asteroid!") + return err + }) + if err == nil { + t.Fatal("expected error") + } + + // Check the third insert was rolled back by the error + rows, err = db.Query("select name from planets") + if err != nil { + t.Fatal(err) + } + found = map[string]bool{} + for rows.Next() { + var name string + if err = rows.Scan(&name); err != nil { + t.Fatal(err) + } + found[name] = true + } + if found["earth"] { + t.Fatal("transaction failed to roll back insert") + } +}