Convert DB to interface

This commit is contained in:
Tim Van Baak 2025-01-31 08:44:09 -08:00
parent ced1de05e8
commit d89f85e141
8 changed files with 48 additions and 41 deletions

View File

@ -51,7 +51,7 @@ func getDbPath() string {
}
// Attempt to open the specified database and exit with an error if it fails.
func openDb() *core.DB {
func openDb() core.DB {
dbPath := getDbPath()
db, err := core.OpenDb(dbPath)
if err != nil {
@ -61,7 +61,7 @@ func openDb() *core.DB {
}
// Attempt to open and migrate the specified database and exit with an error if it fails.
func openAndMigrateDb() *core.DB {
func openAndMigrateDb() core.DB {
db := openDb()
if err := core.InitDatabase(db); err != nil {
log.Fatalf("error: failed to init database: %v", err)

View File

@ -26,7 +26,7 @@ func (a *argList) Scan(value interface{}) error {
return json.Unmarshal([]byte(value.(string)), a)
}
func AddAction(db *DB, source string, name string, argv []string) error {
func AddAction(db DB, source string, name string, argv []string) error {
_, err := db.Exec(`
insert into actions (source, name, argv)
values (?, ?, jsonb(?))
@ -34,7 +34,7 @@ func AddAction(db *DB, source string, name string, argv []string) error {
return err
}
func UpdateAction(db *DB, source string, name string, argv []string) error {
func UpdateAction(db DB, source string, name string, argv []string) error {
_, err := db.Exec(`
update actions
set argv = jsonb(?)
@ -43,7 +43,7 @@ func UpdateAction(db *DB, source string, name string, argv []string) error {
return err
}
func GetActionsForSource(db *DB, source string) ([]string, error) {
func GetActionsForSource(db DB, source string) ([]string, error) {
rows, err := db.Query(`
select name
from actions
@ -64,7 +64,7 @@ func GetActionsForSource(db *DB, source string) ([]string, error) {
return names, nil
}
func GetArgvForAction(db *DB, source string, name string) ([]string, error) {
func GetArgvForAction(db DB, source string, name string) ([]string, error) {
rows := db.QueryRow(`
select json(argv)
from actions
@ -78,7 +78,7 @@ func GetArgvForAction(db *DB, source string, name string) ([]string, error) {
return argv, nil
}
func DeleteAction(db *DB, source string, name string) error {
func DeleteAction(db DB, source string, name string) error {
_, err := db.Exec(`
delete from actions
where source = ? and name = ?

View File

@ -7,24 +7,31 @@ import (
_ "github.com/mattn/go-sqlite3"
)
type DB struct {
type DB interface {
Query(query string, args ...any) (*sql.Rows, error)
QueryRow(query string, args ...any) *sql.Row
Exec(query string, args ...any) (sql.Result, error)
Transact(func(*sql.Tx) error) error
}
type RoRwDb struct {
ro *sql.DB
rw *sql.DB
}
func (db *DB) Query(query string, args ...any) (*sql.Rows, error) {
func (db *RoRwDb) Query(query string, args ...any) (*sql.Rows, error) {
return db.ro.Query(query, args...)
}
func (db *DB) QueryRow(query string, args ...any) *sql.Row {
func (db *RoRwDb) QueryRow(query string, args ...any) *sql.Row {
return db.ro.QueryRow(query, args...)
}
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
func (db *RoRwDb) Exec(query string, args ...any) (sql.Result, error) {
return db.rw.Exec(query, args...)
}
func (db *DB) Transact(transaction func(*sql.Tx) error) error {
func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error {
tx, err := db.rw.Begin()
if err != nil {
return err
@ -55,7 +62,7 @@ func defaultPragma(db *sql.DB) (sql.Result, error) {
`)
}
func OpenDb(dataSourceName string) (*DB, error) {
func OpenDb(dataSourceName string) (DB, error) {
ro, err := sql.Open("sqlite3", dataSourceName)
if err != nil {
defer ro.Close()
@ -82,7 +89,7 @@ func OpenDb(dataSourceName string) (*DB, error) {
return nil, err
}
wrapper := new(DB)
wrapper := new(RoRwDb)
wrapper.ro = ro
wrapper.rw = rw
return wrapper, nil

View File

@ -6,7 +6,7 @@ import (
"strings"
)
func GetEnvs(db *DB, source string) ([]string, error) {
func GetEnvs(db DB, source string) ([]string, error) {
rows, err := db.Query(`
select name, value
from envs
@ -27,7 +27,7 @@ func GetEnvs(db *DB, source string) ([]string, error) {
return envs, nil
}
func SetEnvs(db *DB, source string, envs []string) error {
func SetEnvs(db DB, source string, envs []string) error {
return db.Transact(func(tx *sql.Tx) error {
for _, env := range envs {
parts := strings.SplitN(env, "=", 2)

View File

@ -11,7 +11,7 @@ import (
var migrations embed.FS
// Idempotently initialize the database. Safe to call unconditionally.
func InitDatabase(db *DB) error {
func InitDatabase(db DB) error {
rows, err := db.Query(`
select exists (
select 1
@ -41,7 +41,7 @@ func InitDatabase(db *DB) error {
}
// Get a map of migration names to whether the migration has been applied.
func GetPendingMigrations(db *DB) (map[string]bool, error) {
func GetPendingMigrations(db DB) (map[string]bool, error) {
allMigrations, err := migrations.ReadDir("sql")
if err != nil {
return nil, err
@ -69,7 +69,7 @@ func GetPendingMigrations(db *DB) (map[string]bool, error) {
}
// Apply a migration by name.
func ApplyMigration(db *DB, name string) error {
func ApplyMigration(db DB, name string) error {
data, err := migrations.ReadFile("sql/" + name)
if err != nil {
log.Fatalf("Missing migration %s", name)
@ -84,7 +84,7 @@ func ApplyMigration(db *DB, name string) error {
}
// Apply all pending migrations.
func MigrateDatabase(db *DB) error {
func MigrateDatabase(db DB) error {
pending, err := GetPendingMigrations(db)
if err != nil {
return err

View File

@ -7,7 +7,7 @@ import (
_ "github.com/mattn/go-sqlite3"
)
func EphemeralDb(t *testing.T) *DB {
func EphemeralDb(t *testing.T) DB {
// We don't use OpenDb here because you can't open two connections to the same memory mem
mem, err := sql.Open("sqlite3", ":memory:")
if err != nil {
@ -16,7 +16,7 @@ func EphemeralDb(t *testing.T) *DB {
if _, err = defaultPragma(mem); err != nil {
t.Fatal(err)
}
db := new(DB)
db := new(RoRwDb)
db.ro = mem
db.rw = mem
if err = InitDatabase(db); err != nil {
@ -33,7 +33,7 @@ func TestInitIdempotency(t *testing.T) {
if err != nil {
t.Fatal(err)
}
db := new(DB)
db := new(RoRwDb)
db.ro = mem
db.rw = mem
if err = InitDatabase(db); err != nil {

View File

@ -11,7 +11,7 @@ import (
_ "github.com/mattn/go-sqlite3"
)
func AddSource(db *DB, name string) error {
func AddSource(db DB, name string) error {
_, err := db.Exec(`
insert into sources (name)
values (?)
@ -20,7 +20,7 @@ func AddSource(db *DB, name string) error {
return err
}
func GetSources(db *DB) ([]string, error) {
func GetSources(db DB) ([]string, error) {
rows, err := db.Query(`
select name
from sources
@ -39,7 +39,7 @@ func GetSources(db *DB) ([]string, error) {
return names, nil
}
func DeleteSource(db *DB, name string) error {
func DeleteSource(db DB, name string) error {
_, err := db.Exec(`
delete from sources
where name = ?
@ -48,7 +48,7 @@ func DeleteSource(db *DB, name string) error {
return err
}
func AddItems(db *DB, items []Item) error {
func AddItems(db DB, items []Item) error {
return db.Transact(func(tx *sql.Tx) error {
stmt, err := tx.Prepare(`
insert into items (source, id, active, title, author, body, link, time, action)
@ -95,7 +95,7 @@ func BackfillItem(new *Item, old *Item) {
}
}
func UpdateItems(db *DB, items []Item) error {
func UpdateItems(db DB, items []Item) error {
return db.Transact(func(tx *sql.Tx) error {
stmt, err := tx.Prepare(`
update items
@ -127,7 +127,7 @@ func UpdateItems(db *DB, items []Item) error {
}
// Deactivate an item, returning its previous active state.
func DeactivateItem(db *DB, source string, id string) (bool, error) {
func DeactivateItem(db DB, source string, id string) (bool, error) {
row := db.QueryRow(`
select active
from items
@ -150,7 +150,7 @@ func DeactivateItem(db *DB, source string, id string) (bool, error) {
return active, nil
}
func DeleteItem(db *DB, source string, id string) (int64, error) {
func DeleteItem(db DB, source string, id string) (int64, error) {
res, err := db.Exec(`
delete from items
where source = ?
@ -162,7 +162,7 @@ func DeleteItem(db *DB, source string, id string) (int64, error) {
return res.RowsAffected()
}
func getItems(db *DB, query string, args ...any) ([]Item, error) {
func getItems(db DB, query string, args ...any) ([]Item, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
@ -179,7 +179,7 @@ func getItems(db *DB, query string, args ...any) ([]Item, error) {
return items, nil
}
func GetItem(db *DB, source string, id string) (Item, error) {
func GetItem(db DB, source string, id string) (Item, error) {
items, err := getItems(db, `
select source, id, created, active, title, author, body, link, time, json(action)
from items
@ -196,7 +196,7 @@ func GetItem(db *DB, source string, id string) (Item, error) {
return items[0], nil
}
func GetAllActiveItems(db *DB) ([]Item, error) {
func GetAllActiveItems(db DB) ([]Item, error) {
return getItems(db, `
select
source, id, created, active, title, author, body, link, time, json(action)
@ -206,7 +206,7 @@ func GetAllActiveItems(db *DB) ([]Item, error) {
`)
}
func GetAllItems(db *DB) ([]Item, error) {
func GetAllItems(db DB) ([]Item, error) {
return getItems(db, `
select
source, id, created, active, title, author, body, link, time, json(action)
@ -215,7 +215,7 @@ func GetAllItems(db *DB) ([]Item, error) {
`)
}
func GetActiveItemsForSource(db *DB, source string) ([]Item, error) {
func GetActiveItemsForSource(db DB, source string) ([]Item, error) {
return getItems(db, `
select
source, id, created, active, title, author, body, link, time, json(action)
@ -227,7 +227,7 @@ func GetActiveItemsForSource(db *DB, source string) ([]Item, error) {
`, source)
}
func GetAllItemsForSource(db *DB, source string) ([]Item, error) {
func GetAllItemsForSource(db DB, source string) ([]Item, error) {
return getItems(db, `
select
source, id, created, active, title, author, body, link, time, json(action)
@ -238,14 +238,14 @@ func GetAllItemsForSource(db *DB, source string) ([]Item, error) {
`, source)
}
func GetState(db *DB, source string) ([]byte, error) {
func GetState(db DB, source string) ([]byte, error) {
row := db.QueryRow("select state from sources where name = ?", source)
var state []byte
err := row.Scan(&state)
return state, err
}
func SetState(db *DB, source string, state []byte) error {
func SetState(db DB, source string, state []byte) error {
_, err := db.Exec("update sources set state = ? where name = ?", state, source)
return err
}
@ -253,7 +253,7 @@ func SetState(db *DB, source string, state []byte) error {
// Given the results of a fetch, add new items, update existing items, and delete expired items.
//
// Returns the number of new and deleted items on success.
func UpdateWithFetchedItems(db *DB, source string, state []byte, items []Item) (int, int, error) {
func UpdateWithFetchedItems(db DB, source string, state []byte, items []Item) (int, int, error) {
// Get the existing items
existingItems, err := GetAllItemsForSource(db, source)
if err != nil {

View File

@ -9,7 +9,7 @@ import (
)
type Env struct {
db *core.DB
db core.DB
}
func logged(handler http.HandlerFunc) http.HandlerFunc {
@ -23,7 +23,7 @@ func handleFunc(pattern string, handler http.HandlerFunc) {
http.HandleFunc(pattern, logged(handler))
}
func RunServer(db *core.DB, addr string, port string) {
func RunServer(db core.DB, addr string, port string) {
env := &Env{db}
bind := net.JoinHostPort(addr, port)