Add transaction DB implementation

This and the previous commit allow passing transactions to query functions
This commit is contained in:
Tim Van Baak 2025-01-31 08:53:11 -08:00
parent d89f85e141
commit 9c42847ee2
4 changed files with 37 additions and 10 deletions

View File

@ -11,7 +11,8 @@ type DB interface {
Query(query string, args ...any) (*sql.Rows, error) Query(query string, args ...any) (*sql.Rows, error)
QueryRow(query string, args ...any) *sql.Row QueryRow(query string, args ...any) *sql.Row
Exec(query string, args ...any) (sql.Result, error) Exec(query string, args ...any) (sql.Result, error)
Transact(func(*sql.Tx) error) error Prepare(query string) (*sql.Stmt, error)
Transact(func(DB) error) error
} }
type RoRwDb struct { type RoRwDb struct {
@ -31,7 +32,11 @@ func (db *RoRwDb) Exec(query string, args ...any) (sql.Result, error) {
return db.rw.Exec(query, args...) return db.rw.Exec(query, args...)
} }
func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error { func (db *RoRwDb) Prepare(query string) (*sql.Stmt, error) {
return db.rw.Prepare(query)
}
func (db *RoRwDb) Transact(transaction func(DB) error) error {
tx, err := db.rw.Begin() tx, err := db.rw.Begin()
if err != nil { if err != nil {
return err return err
@ -41,7 +46,7 @@ func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error {
if err != nil { if err != nil {
return err return err
} }
if err = transaction(tx); err != nil { if err = transaction(&TxDb{tx}); err != nil {
return err return err
} }
if err = tx.Commit(); err != nil { if err = tx.Commit(); err != nil {
@ -50,6 +55,30 @@ func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error {
return nil return nil
} }
type TxDb struct {
*sql.Tx
}
func (tx *TxDb) Query(query string, args ...any) (*sql.Rows, error) {
return tx.Tx.Query(query, args...)
}
func (tx *TxDb) QueryRow(query string, args ...any) *sql.Row {
return tx.Tx.QueryRow(query, args...)
}
func (tx *TxDb) Exec(query string, args ...any) (sql.Result, error) {
return tx.Tx.Exec(query, args...)
}
func (tx *TxDb) Prepare(query string) (*sql.Stmt, error) {
return tx.Tx.Prepare(query)
}
func (tx *TxDb) Transact(transaction func(DB) error) error {
return transaction(tx)
}
func defaultPragma(db *sql.DB) (sql.Result, error) { func defaultPragma(db *sql.DB) (sql.Result, error) {
return db.Exec(` return db.Exec(`
pragma journal_mode = WAL; pragma journal_mode = WAL;

View File

@ -1,7 +1,6 @@
package core package core
import ( import (
"database/sql"
"testing" "testing"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -55,7 +54,7 @@ func TestTransaction(t *testing.T) {
} }
// A transaction that should succeed // A transaction that should succeed
err := db.Transact(func(tx *sql.Tx) error { err := db.Transact(func(tx DB) error {
if _, err := tx.Exec("insert into planets (name) values (?)", "mercury"); err != nil { if _, err := tx.Exec("insert into planets (name) values (?)", "mercury"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -86,7 +85,7 @@ func TestTransaction(t *testing.T) {
} }
// A transaction that should fail // A transaction that should fail
err = db.Transact(func(tx *sql.Tx) error { err = db.Transact(func(tx DB) error {
if _, err := tx.Exec("insert into planets (name) values (?)", "earth"); err != nil { if _, err := tx.Exec("insert into planets (name) values (?)", "earth"); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -1,7 +1,6 @@
package core package core
import ( import (
"database/sql"
"fmt" "fmt"
"strings" "strings"
) )
@ -28,7 +27,7 @@ func GetEnvs(db DB, source string) ([]string, error) {
} }
func SetEnvs(db DB, source string, envs []string) error { func SetEnvs(db DB, source string, envs []string) error {
return db.Transact(func(tx *sql.Tx) error { return db.Transact(func(tx DB) error {
for _, env := range envs { for _, env := range envs {
parts := strings.SplitN(env, "=", 2) parts := strings.SplitN(env, "=", 2)
if len(parts) != 2 { if len(parts) != 2 {

View File

@ -49,7 +49,7 @@ func DeleteSource(db DB, name string) error {
} }
func AddItems(db DB, items []Item) error { func AddItems(db DB, items []Item) error {
return db.Transact(func(tx *sql.Tx) error { return db.Transact(func(tx DB) error {
stmt, err := tx.Prepare(` stmt, err := tx.Prepare(`
insert into items (source, id, active, title, author, body, link, time, action) insert into items (source, id, active, title, author, body, link, time, action)
values (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?)) values (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))
@ -96,7 +96,7 @@ func BackfillItem(new *Item, old *Item) {
} }
func UpdateItems(db DB, items []Item) error { func UpdateItems(db DB, items []Item) error {
return db.Transact(func(tx *sql.Tx) error { return db.Transact(func(tx DB) error {
stmt, err := tx.Prepare(` stmt, err := tx.Prepare(`
update items update items
set set