From 9c42847ee21e3d56c46d85272348c0b61b094804 Mon Sep 17 00:00:00 2001 From: Tim Van Baak Date: Fri, 31 Jan 2025 08:53:11 -0800 Subject: [PATCH] Add transaction DB implementation This and the previous commit allow passing transactions to query functions --- core/db.go | 35 ++++++++++++++++++++++++++++++++--- core/db_test.go | 5 ++--- core/env.go | 3 +-- core/source.go | 4 ++-- 4 files changed, 37 insertions(+), 10 deletions(-) diff --git a/core/db.go b/core/db.go index b3650b1..40d6dce 100644 --- a/core/db.go +++ b/core/db.go @@ -11,7 +11,8 @@ type DB interface { Query(query string, args ...any) (*sql.Rows, error) QueryRow(query string, args ...any) *sql.Row Exec(query string, args ...any) (sql.Result, error) - Transact(func(*sql.Tx) error) error + Prepare(query string) (*sql.Stmt, error) + Transact(func(DB) error) error } type RoRwDb struct { @@ -31,7 +32,11 @@ func (db *RoRwDb) Exec(query string, args ...any) (sql.Result, error) { return db.rw.Exec(query, args...) } -func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error { +func (db *RoRwDb) Prepare(query string) (*sql.Stmt, error) { + return db.rw.Prepare(query) +} + +func (db *RoRwDb) Transact(transaction func(DB) error) error { tx, err := db.rw.Begin() if err != nil { return err @@ -41,7 +46,7 @@ func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error { if err != nil { return err } - if err = transaction(tx); err != nil { + if err = transaction(&TxDb{tx}); err != nil { return err } if err = tx.Commit(); err != nil { @@ -50,6 +55,30 @@ func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error { return nil } +type TxDb struct { + *sql.Tx +} + +func (tx *TxDb) Query(query string, args ...any) (*sql.Rows, error) { + return tx.Tx.Query(query, args...) +} + +func (tx *TxDb) QueryRow(query string, args ...any) *sql.Row { + return tx.Tx.QueryRow(query, args...) +} + +func (tx *TxDb) Exec(query string, args ...any) (sql.Result, error) { + return tx.Tx.Exec(query, args...) +} + +func (tx *TxDb) Prepare(query string) (*sql.Stmt, error) { + return tx.Tx.Prepare(query) +} + +func (tx *TxDb) Transact(transaction func(DB) error) error { + return transaction(tx) +} + func defaultPragma(db *sql.DB) (sql.Result, error) { return db.Exec(` pragma journal_mode = WAL; diff --git a/core/db_test.go b/core/db_test.go index 555b8a2..93823fc 100644 --- a/core/db_test.go +++ b/core/db_test.go @@ -1,7 +1,6 @@ package core import ( - "database/sql" "testing" _ "github.com/mattn/go-sqlite3" @@ -55,7 +54,7 @@ func TestTransaction(t *testing.T) { } // A transaction that should succeed - err := db.Transact(func(tx *sql.Tx) error { + err := db.Transact(func(tx DB) error { if _, err := tx.Exec("insert into planets (name) values (?)", "mercury"); err != nil { t.Fatal(err) } @@ -86,7 +85,7 @@ func TestTransaction(t *testing.T) { } // A transaction that should fail - err = db.Transact(func(tx *sql.Tx) error { + err = db.Transact(func(tx DB) error { if _, err := tx.Exec("insert into planets (name) values (?)", "earth"); err != nil { t.Fatal(err) } diff --git a/core/env.go b/core/env.go index c44e109..3330daa 100644 --- a/core/env.go +++ b/core/env.go @@ -1,7 +1,6 @@ package core import ( - "database/sql" "fmt" "strings" ) @@ -28,7 +27,7 @@ func GetEnvs(db DB, source string) ([]string, error) { } func SetEnvs(db DB, source string, envs []string) error { - return db.Transact(func(tx *sql.Tx) error { + return db.Transact(func(tx DB) error { for _, env := range envs { parts := strings.SplitN(env, "=", 2) if len(parts) != 2 { diff --git a/core/source.go b/core/source.go index 34672e1..d7a4051 100644 --- a/core/source.go +++ b/core/source.go @@ -49,7 +49,7 @@ func DeleteSource(db DB, name string) error { } func AddItems(db DB, items []Item) error { - return db.Transact(func(tx *sql.Tx) error { + return db.Transact(func(tx DB) error { stmt, err := tx.Prepare(` insert into items (source, id, active, title, author, body, link, time, action) values (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?)) @@ -96,7 +96,7 @@ func BackfillItem(new *Item, old *Item) { } func UpdateItems(db DB, items []Item) error { - return db.Transact(func(tx *sql.Tx) error { + return db.Transact(func(tx DB) error { stmt, err := tx.Prepare(` update items set