From d89f85e14164f547c513946dc5ffc3b2997bd4bd Mon Sep 17 00:00:00 2001 From: Tim Van Baak Date: Fri, 31 Jan 2025 08:44:09 -0800 Subject: [PATCH] Convert DB to interface --- cmd/root.go | 4 ++-- core/action.go | 10 +++++----- core/db.go | 21 ++++++++++++++------- core/env.go | 4 ++-- core/migrations.go | 8 ++++---- core/migrations_test.go | 6 +++--- core/source.go | 32 ++++++++++++++++---------------- web/main.go | 4 ++-- 8 files changed, 48 insertions(+), 41 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 94415e7..d340f50 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -51,7 +51,7 @@ func getDbPath() string { } // Attempt to open the specified database and exit with an error if it fails. -func openDb() *core.DB { +func openDb() core.DB { dbPath := getDbPath() db, err := core.OpenDb(dbPath) if err != nil { @@ -61,7 +61,7 @@ func openDb() *core.DB { } // Attempt to open and migrate the specified database and exit with an error if it fails. -func openAndMigrateDb() *core.DB { +func openAndMigrateDb() core.DB { db := openDb() if err := core.InitDatabase(db); err != nil { log.Fatalf("error: failed to init database: %v", err) diff --git a/core/action.go b/core/action.go index c33a5ba..f097626 100644 --- a/core/action.go +++ b/core/action.go @@ -26,7 +26,7 @@ func (a *argList) Scan(value interface{}) error { return json.Unmarshal([]byte(value.(string)), a) } -func AddAction(db *DB, source string, name string, argv []string) error { +func AddAction(db DB, source string, name string, argv []string) error { _, err := db.Exec(` insert into actions (source, name, argv) values (?, ?, jsonb(?)) @@ -34,7 +34,7 @@ func AddAction(db *DB, source string, name string, argv []string) error { return err } -func UpdateAction(db *DB, source string, name string, argv []string) error { +func UpdateAction(db DB, source string, name string, argv []string) error { _, err := db.Exec(` update actions set argv = jsonb(?) @@ -43,7 +43,7 @@ func UpdateAction(db *DB, source string, name string, argv []string) error { return err } -func GetActionsForSource(db *DB, source string) ([]string, error) { +func GetActionsForSource(db DB, source string) ([]string, error) { rows, err := db.Query(` select name from actions @@ -64,7 +64,7 @@ func GetActionsForSource(db *DB, source string) ([]string, error) { return names, nil } -func GetArgvForAction(db *DB, source string, name string) ([]string, error) { +func GetArgvForAction(db DB, source string, name string) ([]string, error) { rows := db.QueryRow(` select json(argv) from actions @@ -78,7 +78,7 @@ func GetArgvForAction(db *DB, source string, name string) ([]string, error) { return argv, nil } -func DeleteAction(db *DB, source string, name string) error { +func DeleteAction(db DB, source string, name string) error { _, err := db.Exec(` delete from actions where source = ? and name = ? diff --git a/core/db.go b/core/db.go index b5a1aa6..b3650b1 100644 --- a/core/db.go +++ b/core/db.go @@ -7,24 +7,31 @@ import ( _ "github.com/mattn/go-sqlite3" ) -type DB struct { +type DB interface { + Query(query string, args ...any) (*sql.Rows, error) + QueryRow(query string, args ...any) *sql.Row + Exec(query string, args ...any) (sql.Result, error) + Transact(func(*sql.Tx) error) error +} + +type RoRwDb struct { ro *sql.DB rw *sql.DB } -func (db *DB) Query(query string, args ...any) (*sql.Rows, error) { +func (db *RoRwDb) Query(query string, args ...any) (*sql.Rows, error) { return db.ro.Query(query, args...) } -func (db *DB) QueryRow(query string, args ...any) *sql.Row { +func (db *RoRwDb) QueryRow(query string, args ...any) *sql.Row { return db.ro.QueryRow(query, args...) } -func (db *DB) Exec(query string, args ...any) (sql.Result, error) { +func (db *RoRwDb) Exec(query string, args ...any) (sql.Result, error) { return db.rw.Exec(query, args...) } -func (db *DB) Transact(transaction func(*sql.Tx) error) error { +func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error { tx, err := db.rw.Begin() if err != nil { return err @@ -55,7 +62,7 @@ func defaultPragma(db *sql.DB) (sql.Result, error) { `) } -func OpenDb(dataSourceName string) (*DB, error) { +func OpenDb(dataSourceName string) (DB, error) { ro, err := sql.Open("sqlite3", dataSourceName) if err != nil { defer ro.Close() @@ -82,7 +89,7 @@ func OpenDb(dataSourceName string) (*DB, error) { return nil, err } - wrapper := new(DB) + wrapper := new(RoRwDb) wrapper.ro = ro wrapper.rw = rw return wrapper, nil diff --git a/core/env.go b/core/env.go index 6868f57..c44e109 100644 --- a/core/env.go +++ b/core/env.go @@ -6,7 +6,7 @@ import ( "strings" ) -func GetEnvs(db *DB, source string) ([]string, error) { +func GetEnvs(db DB, source string) ([]string, error) { rows, err := db.Query(` select name, value from envs @@ -27,7 +27,7 @@ func GetEnvs(db *DB, source string) ([]string, error) { return envs, nil } -func SetEnvs(db *DB, source string, envs []string) error { +func SetEnvs(db DB, source string, envs []string) error { return db.Transact(func(tx *sql.Tx) error { for _, env := range envs { parts := strings.SplitN(env, "=", 2) diff --git a/core/migrations.go b/core/migrations.go index 02f01c0..c5a1d4a 100644 --- a/core/migrations.go +++ b/core/migrations.go @@ -11,7 +11,7 @@ import ( var migrations embed.FS // Idempotently initialize the database. Safe to call unconditionally. -func InitDatabase(db *DB) error { +func InitDatabase(db DB) error { rows, err := db.Query(` select exists ( select 1 @@ -41,7 +41,7 @@ func InitDatabase(db *DB) error { } // Get a map of migration names to whether the migration has been applied. -func GetPendingMigrations(db *DB) (map[string]bool, error) { +func GetPendingMigrations(db DB) (map[string]bool, error) { allMigrations, err := migrations.ReadDir("sql") if err != nil { return nil, err @@ -69,7 +69,7 @@ func GetPendingMigrations(db *DB) (map[string]bool, error) { } // Apply a migration by name. -func ApplyMigration(db *DB, name string) error { +func ApplyMigration(db DB, name string) error { data, err := migrations.ReadFile("sql/" + name) if err != nil { log.Fatalf("Missing migration %s", name) @@ -84,7 +84,7 @@ func ApplyMigration(db *DB, name string) error { } // Apply all pending migrations. -func MigrateDatabase(db *DB) error { +func MigrateDatabase(db DB) error { pending, err := GetPendingMigrations(db) if err != nil { return err diff --git a/core/migrations_test.go b/core/migrations_test.go index 1386f50..7cd1b3c 100644 --- a/core/migrations_test.go +++ b/core/migrations_test.go @@ -7,7 +7,7 @@ import ( _ "github.com/mattn/go-sqlite3" ) -func EphemeralDb(t *testing.T) *DB { +func EphemeralDb(t *testing.T) DB { // We don't use OpenDb here because you can't open two connections to the same memory mem mem, err := sql.Open("sqlite3", ":memory:") if err != nil { @@ -16,7 +16,7 @@ func EphemeralDb(t *testing.T) *DB { if _, err = defaultPragma(mem); err != nil { t.Fatal(err) } - db := new(DB) + db := new(RoRwDb) db.ro = mem db.rw = mem if err = InitDatabase(db); err != nil { @@ -33,7 +33,7 @@ func TestInitIdempotency(t *testing.T) { if err != nil { t.Fatal(err) } - db := new(DB) + db := new(RoRwDb) db.ro = mem db.rw = mem if err = InitDatabase(db); err != nil { diff --git a/core/source.go b/core/source.go index b212a31..34672e1 100644 --- a/core/source.go +++ b/core/source.go @@ -11,7 +11,7 @@ import ( _ "github.com/mattn/go-sqlite3" ) -func AddSource(db *DB, name string) error { +func AddSource(db DB, name string) error { _, err := db.Exec(` insert into sources (name) values (?) @@ -20,7 +20,7 @@ func AddSource(db *DB, name string) error { return err } -func GetSources(db *DB) ([]string, error) { +func GetSources(db DB) ([]string, error) { rows, err := db.Query(` select name from sources @@ -39,7 +39,7 @@ func GetSources(db *DB) ([]string, error) { return names, nil } -func DeleteSource(db *DB, name string) error { +func DeleteSource(db DB, name string) error { _, err := db.Exec(` delete from sources where name = ? @@ -48,7 +48,7 @@ func DeleteSource(db *DB, name string) error { return err } -func AddItems(db *DB, items []Item) error { +func AddItems(db DB, items []Item) error { return db.Transact(func(tx *sql.Tx) error { stmt, err := tx.Prepare(` insert into items (source, id, active, title, author, body, link, time, action) @@ -95,7 +95,7 @@ func BackfillItem(new *Item, old *Item) { } } -func UpdateItems(db *DB, items []Item) error { +func UpdateItems(db DB, items []Item) error { return db.Transact(func(tx *sql.Tx) error { stmt, err := tx.Prepare(` update items @@ -127,7 +127,7 @@ func UpdateItems(db *DB, items []Item) error { } // Deactivate an item, returning its previous active state. -func DeactivateItem(db *DB, source string, id string) (bool, error) { +func DeactivateItem(db DB, source string, id string) (bool, error) { row := db.QueryRow(` select active from items @@ -150,7 +150,7 @@ func DeactivateItem(db *DB, source string, id string) (bool, error) { return active, nil } -func DeleteItem(db *DB, source string, id string) (int64, error) { +func DeleteItem(db DB, source string, id string) (int64, error) { res, err := db.Exec(` delete from items where source = ? @@ -162,7 +162,7 @@ func DeleteItem(db *DB, source string, id string) (int64, error) { return res.RowsAffected() } -func getItems(db *DB, query string, args ...any) ([]Item, error) { +func getItems(db DB, query string, args ...any) ([]Item, error) { rows, err := db.Query(query, args...) if err != nil { return nil, err @@ -179,7 +179,7 @@ func getItems(db *DB, query string, args ...any) ([]Item, error) { return items, nil } -func GetItem(db *DB, source string, id string) (Item, error) { +func GetItem(db DB, source string, id string) (Item, error) { items, err := getItems(db, ` select source, id, created, active, title, author, body, link, time, json(action) from items @@ -196,7 +196,7 @@ func GetItem(db *DB, source string, id string) (Item, error) { return items[0], nil } -func GetAllActiveItems(db *DB) ([]Item, error) { +func GetAllActiveItems(db DB) ([]Item, error) { return getItems(db, ` select source, id, created, active, title, author, body, link, time, json(action) @@ -206,7 +206,7 @@ func GetAllActiveItems(db *DB) ([]Item, error) { `) } -func GetAllItems(db *DB) ([]Item, error) { +func GetAllItems(db DB) ([]Item, error) { return getItems(db, ` select source, id, created, active, title, author, body, link, time, json(action) @@ -215,7 +215,7 @@ func GetAllItems(db *DB) ([]Item, error) { `) } -func GetActiveItemsForSource(db *DB, source string) ([]Item, error) { +func GetActiveItemsForSource(db DB, source string) ([]Item, error) { return getItems(db, ` select source, id, created, active, title, author, body, link, time, json(action) @@ -227,7 +227,7 @@ func GetActiveItemsForSource(db *DB, source string) ([]Item, error) { `, source) } -func GetAllItemsForSource(db *DB, source string) ([]Item, error) { +func GetAllItemsForSource(db DB, source string) ([]Item, error) { return getItems(db, ` select source, id, created, active, title, author, body, link, time, json(action) @@ -238,14 +238,14 @@ func GetAllItemsForSource(db *DB, source string) ([]Item, error) { `, source) } -func GetState(db *DB, source string) ([]byte, error) { +func GetState(db DB, source string) ([]byte, error) { row := db.QueryRow("select state from sources where name = ?", source) var state []byte err := row.Scan(&state) return state, err } -func SetState(db *DB, source string, state []byte) error { +func SetState(db DB, source string, state []byte) error { _, err := db.Exec("update sources set state = ? where name = ?", state, source) return err } @@ -253,7 +253,7 @@ func SetState(db *DB, source string, state []byte) error { // Given the results of a fetch, add new items, update existing items, and delete expired items. // // Returns the number of new and deleted items on success. -func UpdateWithFetchedItems(db *DB, source string, state []byte, items []Item) (int, int, error) { +func UpdateWithFetchedItems(db DB, source string, state []byte, items []Item) (int, int, error) { // Get the existing items existingItems, err := GetAllItemsForSource(db, source) if err != nil { diff --git a/web/main.go b/web/main.go index 675e9dd..b02a9c0 100644 --- a/web/main.go +++ b/web/main.go @@ -9,7 +9,7 @@ import ( ) type Env struct { - db *core.DB + db core.DB } func logged(handler http.HandlerFunc) http.HandlerFunc { @@ -23,7 +23,7 @@ func handleFunc(pattern string, handler http.HandlerFunc) { http.HandleFunc(pattern, logged(handler)) } -func RunServer(db *core.DB, addr string, port string) { +func RunServer(db core.DB, addr string, port string) { env := &Env{db} bind := net.JoinHostPort(addr, port)