Add transaction DB implementation

This and the previous commit allow passing transactions to query functions
This commit is contained in:
Tim Van Baak 2025-01-31 08:53:11 -08:00
parent d89f85e141
commit 9c42847ee2
4 changed files with 37 additions and 10 deletions

View File

@ -11,7 +11,8 @@ type DB interface {
Query(query string, args ...any) (*sql.Rows, error)
QueryRow(query string, args ...any) *sql.Row
Exec(query string, args ...any) (sql.Result, error)
Transact(func(*sql.Tx) error) error
Prepare(query string) (*sql.Stmt, error)
Transact(func(DB) error) error
}
type RoRwDb struct {
@ -31,7 +32,11 @@ func (db *RoRwDb) Exec(query string, args ...any) (sql.Result, error) {
return db.rw.Exec(query, args...)
}
func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error {
func (db *RoRwDb) Prepare(query string) (*sql.Stmt, error) {
return db.rw.Prepare(query)
}
func (db *RoRwDb) Transact(transaction func(DB) error) error {
tx, err := db.rw.Begin()
if err != nil {
return err
@ -41,7 +46,7 @@ func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error {
if err != nil {
return err
}
if err = transaction(tx); err != nil {
if err = transaction(&TxDb{tx}); err != nil {
return err
}
if err = tx.Commit(); err != nil {
@ -50,6 +55,30 @@ func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error {
return nil
}
type TxDb struct {
*sql.Tx
}
func (tx *TxDb) Query(query string, args ...any) (*sql.Rows, error) {
return tx.Tx.Query(query, args...)
}
func (tx *TxDb) QueryRow(query string, args ...any) *sql.Row {
return tx.Tx.QueryRow(query, args...)
}
func (tx *TxDb) Exec(query string, args ...any) (sql.Result, error) {
return tx.Tx.Exec(query, args...)
}
func (tx *TxDb) Prepare(query string) (*sql.Stmt, error) {
return tx.Tx.Prepare(query)
}
func (tx *TxDb) Transact(transaction func(DB) error) error {
return transaction(tx)
}
func defaultPragma(db *sql.DB) (sql.Result, error) {
return db.Exec(`
pragma journal_mode = WAL;

View File

@ -1,7 +1,6 @@
package core
import (
"database/sql"
"testing"
_ "github.com/mattn/go-sqlite3"
@ -55,7 +54,7 @@ func TestTransaction(t *testing.T) {
}
// A transaction that should succeed
err := db.Transact(func(tx *sql.Tx) error {
err := db.Transact(func(tx DB) error {
if _, err := tx.Exec("insert into planets (name) values (?)", "mercury"); err != nil {
t.Fatal(err)
}
@ -86,7 +85,7 @@ func TestTransaction(t *testing.T) {
}
// A transaction that should fail
err = db.Transact(func(tx *sql.Tx) error {
err = db.Transact(func(tx DB) error {
if _, err := tx.Exec("insert into planets (name) values (?)", "earth"); err != nil {
t.Fatal(err)
}

View File

@ -1,7 +1,6 @@
package core
import (
"database/sql"
"fmt"
"strings"
)
@ -28,7 +27,7 @@ func GetEnvs(db DB, source string) ([]string, error) {
}
func SetEnvs(db DB, source string, envs []string) error {
return db.Transact(func(tx *sql.Tx) error {
return db.Transact(func(tx DB) error {
for _, env := range envs {
parts := strings.SplitN(env, "=", 2)
if len(parts) != 2 {

View File

@ -49,7 +49,7 @@ func DeleteSource(db DB, name string) error {
}
func AddItems(db DB, items []Item) error {
return db.Transact(func(tx *sql.Tx) error {
return db.Transact(func(tx DB) error {
stmt, err := tx.Prepare(`
insert into items (source, id, active, title, author, body, link, time, action)
values (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))
@ -96,7 +96,7 @@ func BackfillItem(new *Item, old *Item) {
}
func UpdateItems(db DB, items []Item) error {
return db.Transact(func(tx *sql.Tx) error {
return db.Transact(func(tx DB) error {
stmt, err := tx.Prepare(`
update items
set