Convert DB to interface
This commit is contained in:
parent
ced1de05e8
commit
d89f85e141
@ -51,7 +51,7 @@ func getDbPath() string {
|
||||
}
|
||||
|
||||
// Attempt to open the specified database and exit with an error if it fails.
|
||||
func openDb() *core.DB {
|
||||
func openDb() core.DB {
|
||||
dbPath := getDbPath()
|
||||
db, err := core.OpenDb(dbPath)
|
||||
if err != nil {
|
||||
@ -61,7 +61,7 @@ func openDb() *core.DB {
|
||||
}
|
||||
|
||||
// Attempt to open and migrate the specified database and exit with an error if it fails.
|
||||
func openAndMigrateDb() *core.DB {
|
||||
func openAndMigrateDb() core.DB {
|
||||
db := openDb()
|
||||
if err := core.InitDatabase(db); err != nil {
|
||||
log.Fatalf("error: failed to init database: %v", err)
|
||||
|
@ -26,7 +26,7 @@ func (a *argList) Scan(value interface{}) error {
|
||||
return json.Unmarshal([]byte(value.(string)), a)
|
||||
}
|
||||
|
||||
func AddAction(db *DB, source string, name string, argv []string) error {
|
||||
func AddAction(db DB, source string, name string, argv []string) error {
|
||||
_, err := db.Exec(`
|
||||
insert into actions (source, name, argv)
|
||||
values (?, ?, jsonb(?))
|
||||
@ -34,7 +34,7 @@ func AddAction(db *DB, source string, name string, argv []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func UpdateAction(db *DB, source string, name string, argv []string) error {
|
||||
func UpdateAction(db DB, source string, name string, argv []string) error {
|
||||
_, err := db.Exec(`
|
||||
update actions
|
||||
set argv = jsonb(?)
|
||||
@ -43,7 +43,7 @@ func UpdateAction(db *DB, source string, name string, argv []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func GetActionsForSource(db *DB, source string) ([]string, error) {
|
||||
func GetActionsForSource(db DB, source string) ([]string, error) {
|
||||
rows, err := db.Query(`
|
||||
select name
|
||||
from actions
|
||||
@ -64,7 +64,7 @@ func GetActionsForSource(db *DB, source string) ([]string, error) {
|
||||
return names, nil
|
||||
}
|
||||
|
||||
func GetArgvForAction(db *DB, source string, name string) ([]string, error) {
|
||||
func GetArgvForAction(db DB, source string, name string) ([]string, error) {
|
||||
rows := db.QueryRow(`
|
||||
select json(argv)
|
||||
from actions
|
||||
@ -78,7 +78,7 @@ func GetArgvForAction(db *DB, source string, name string) ([]string, error) {
|
||||
return argv, nil
|
||||
}
|
||||
|
||||
func DeleteAction(db *DB, source string, name string) error {
|
||||
func DeleteAction(db DB, source string, name string) error {
|
||||
_, err := db.Exec(`
|
||||
delete from actions
|
||||
where source = ? and name = ?
|
||||
|
21
core/db.go
21
core/db.go
@ -7,24 +7,31 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
type DB interface {
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...any) *sql.Row
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
Transact(func(*sql.Tx) error) error
|
||||
}
|
||||
|
||||
type RoRwDb struct {
|
||||
ro *sql.DB
|
||||
rw *sql.DB
|
||||
}
|
||||
|
||||
func (db *DB) Query(query string, args ...any) (*sql.Rows, error) {
|
||||
func (db *RoRwDb) Query(query string, args ...any) (*sql.Rows, error) {
|
||||
return db.ro.Query(query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) QueryRow(query string, args ...any) *sql.Row {
|
||||
func (db *RoRwDb) QueryRow(query string, args ...any) *sql.Row {
|
||||
return db.ro.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
|
||||
func (db *RoRwDb) Exec(query string, args ...any) (sql.Result, error) {
|
||||
return db.rw.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) Transact(transaction func(*sql.Tx) error) error {
|
||||
func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error {
|
||||
tx, err := db.rw.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -55,7 +62,7 @@ func defaultPragma(db *sql.DB) (sql.Result, error) {
|
||||
`)
|
||||
}
|
||||
|
||||
func OpenDb(dataSourceName string) (*DB, error) {
|
||||
func OpenDb(dataSourceName string) (DB, error) {
|
||||
ro, err := sql.Open("sqlite3", dataSourceName)
|
||||
if err != nil {
|
||||
defer ro.Close()
|
||||
@ -82,7 +89,7 @@ func OpenDb(dataSourceName string) (*DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wrapper := new(DB)
|
||||
wrapper := new(RoRwDb)
|
||||
wrapper.ro = ro
|
||||
wrapper.rw = rw
|
||||
return wrapper, nil
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetEnvs(db *DB, source string) ([]string, error) {
|
||||
func GetEnvs(db DB, source string) ([]string, error) {
|
||||
rows, err := db.Query(`
|
||||
select name, value
|
||||
from envs
|
||||
@ -27,7 +27,7 @@ func GetEnvs(db *DB, source string) ([]string, error) {
|
||||
return envs, nil
|
||||
}
|
||||
|
||||
func SetEnvs(db *DB, source string, envs []string) error {
|
||||
func SetEnvs(db DB, source string, envs []string) error {
|
||||
return db.Transact(func(tx *sql.Tx) error {
|
||||
for _, env := range envs {
|
||||
parts := strings.SplitN(env, "=", 2)
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
var migrations embed.FS
|
||||
|
||||
// Idempotently initialize the database. Safe to call unconditionally.
|
||||
func InitDatabase(db *DB) error {
|
||||
func InitDatabase(db DB) error {
|
||||
rows, err := db.Query(`
|
||||
select exists (
|
||||
select 1
|
||||
@ -41,7 +41,7 @@ func InitDatabase(db *DB) error {
|
||||
}
|
||||
|
||||
// Get a map of migration names to whether the migration has been applied.
|
||||
func GetPendingMigrations(db *DB) (map[string]bool, error) {
|
||||
func GetPendingMigrations(db DB) (map[string]bool, error) {
|
||||
allMigrations, err := migrations.ReadDir("sql")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -69,7 +69,7 @@ func GetPendingMigrations(db *DB) (map[string]bool, error) {
|
||||
}
|
||||
|
||||
// Apply a migration by name.
|
||||
func ApplyMigration(db *DB, name string) error {
|
||||
func ApplyMigration(db DB, name string) error {
|
||||
data, err := migrations.ReadFile("sql/" + name)
|
||||
if err != nil {
|
||||
log.Fatalf("Missing migration %s", name)
|
||||
@ -84,7 +84,7 @@ func ApplyMigration(db *DB, name string) error {
|
||||
}
|
||||
|
||||
// Apply all pending migrations.
|
||||
func MigrateDatabase(db *DB) error {
|
||||
func MigrateDatabase(db DB) error {
|
||||
pending, err := GetPendingMigrations(db)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func EphemeralDb(t *testing.T) *DB {
|
||||
func EphemeralDb(t *testing.T) DB {
|
||||
// We don't use OpenDb here because you can't open two connections to the same memory mem
|
||||
mem, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
@ -16,7 +16,7 @@ func EphemeralDb(t *testing.T) *DB {
|
||||
if _, err = defaultPragma(mem); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db := new(DB)
|
||||
db := new(RoRwDb)
|
||||
db.ro = mem
|
||||
db.rw = mem
|
||||
if err = InitDatabase(db); err != nil {
|
||||
@ -33,7 +33,7 @@ func TestInitIdempotency(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db := new(DB)
|
||||
db := new(RoRwDb)
|
||||
db.ro = mem
|
||||
db.rw = mem
|
||||
if err = InitDatabase(db); err != nil {
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func AddSource(db *DB, name string) error {
|
||||
func AddSource(db DB, name string) error {
|
||||
_, err := db.Exec(`
|
||||
insert into sources (name)
|
||||
values (?)
|
||||
@ -20,7 +20,7 @@ func AddSource(db *DB, name string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func GetSources(db *DB) ([]string, error) {
|
||||
func GetSources(db DB) ([]string, error) {
|
||||
rows, err := db.Query(`
|
||||
select name
|
||||
from sources
|
||||
@ -39,7 +39,7 @@ func GetSources(db *DB) ([]string, error) {
|
||||
return names, nil
|
||||
}
|
||||
|
||||
func DeleteSource(db *DB, name string) error {
|
||||
func DeleteSource(db DB, name string) error {
|
||||
_, err := db.Exec(`
|
||||
delete from sources
|
||||
where name = ?
|
||||
@ -48,7 +48,7 @@ func DeleteSource(db *DB, name string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func AddItems(db *DB, items []Item) error {
|
||||
func AddItems(db DB, items []Item) error {
|
||||
return db.Transact(func(tx *sql.Tx) error {
|
||||
stmt, err := tx.Prepare(`
|
||||
insert into items (source, id, active, title, author, body, link, time, action)
|
||||
@ -95,7 +95,7 @@ func BackfillItem(new *Item, old *Item) {
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateItems(db *DB, items []Item) error {
|
||||
func UpdateItems(db DB, items []Item) error {
|
||||
return db.Transact(func(tx *sql.Tx) error {
|
||||
stmt, err := tx.Prepare(`
|
||||
update items
|
||||
@ -127,7 +127,7 @@ func UpdateItems(db *DB, items []Item) error {
|
||||
}
|
||||
|
||||
// Deactivate an item, returning its previous active state.
|
||||
func DeactivateItem(db *DB, source string, id string) (bool, error) {
|
||||
func DeactivateItem(db DB, source string, id string) (bool, error) {
|
||||
row := db.QueryRow(`
|
||||
select active
|
||||
from items
|
||||
@ -150,7 +150,7 @@ func DeactivateItem(db *DB, source string, id string) (bool, error) {
|
||||
return active, nil
|
||||
}
|
||||
|
||||
func DeleteItem(db *DB, source string, id string) (int64, error) {
|
||||
func DeleteItem(db DB, source string, id string) (int64, error) {
|
||||
res, err := db.Exec(`
|
||||
delete from items
|
||||
where source = ?
|
||||
@ -162,7 +162,7 @@ func DeleteItem(db *DB, source string, id string) (int64, error) {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func getItems(db *DB, query string, args ...any) ([]Item, error) {
|
||||
func getItems(db DB, query string, args ...any) ([]Item, error) {
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -179,7 +179,7 @@ func getItems(db *DB, query string, args ...any) ([]Item, error) {
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func GetItem(db *DB, source string, id string) (Item, error) {
|
||||
func GetItem(db DB, source string, id string) (Item, error) {
|
||||
items, err := getItems(db, `
|
||||
select source, id, created, active, title, author, body, link, time, json(action)
|
||||
from items
|
||||
@ -196,7 +196,7 @@ func GetItem(db *DB, source string, id string) (Item, error) {
|
||||
return items[0], nil
|
||||
}
|
||||
|
||||
func GetAllActiveItems(db *DB) ([]Item, error) {
|
||||
func GetAllActiveItems(db DB) ([]Item, error) {
|
||||
return getItems(db, `
|
||||
select
|
||||
source, id, created, active, title, author, body, link, time, json(action)
|
||||
@ -206,7 +206,7 @@ func GetAllActiveItems(db *DB) ([]Item, error) {
|
||||
`)
|
||||
}
|
||||
|
||||
func GetAllItems(db *DB) ([]Item, error) {
|
||||
func GetAllItems(db DB) ([]Item, error) {
|
||||
return getItems(db, `
|
||||
select
|
||||
source, id, created, active, title, author, body, link, time, json(action)
|
||||
@ -215,7 +215,7 @@ func GetAllItems(db *DB) ([]Item, error) {
|
||||
`)
|
||||
}
|
||||
|
||||
func GetActiveItemsForSource(db *DB, source string) ([]Item, error) {
|
||||
func GetActiveItemsForSource(db DB, source string) ([]Item, error) {
|
||||
return getItems(db, `
|
||||
select
|
||||
source, id, created, active, title, author, body, link, time, json(action)
|
||||
@ -227,7 +227,7 @@ func GetActiveItemsForSource(db *DB, source string) ([]Item, error) {
|
||||
`, source)
|
||||
}
|
||||
|
||||
func GetAllItemsForSource(db *DB, source string) ([]Item, error) {
|
||||
func GetAllItemsForSource(db DB, source string) ([]Item, error) {
|
||||
return getItems(db, `
|
||||
select
|
||||
source, id, created, active, title, author, body, link, time, json(action)
|
||||
@ -238,14 +238,14 @@ func GetAllItemsForSource(db *DB, source string) ([]Item, error) {
|
||||
`, source)
|
||||
}
|
||||
|
||||
func GetState(db *DB, source string) ([]byte, error) {
|
||||
func GetState(db DB, source string) ([]byte, error) {
|
||||
row := db.QueryRow("select state from sources where name = ?", source)
|
||||
var state []byte
|
||||
err := row.Scan(&state)
|
||||
return state, err
|
||||
}
|
||||
|
||||
func SetState(db *DB, source string, state []byte) error {
|
||||
func SetState(db DB, source string, state []byte) error {
|
||||
_, err := db.Exec("update sources set state = ? where name = ?", state, source)
|
||||
return err
|
||||
}
|
||||
@ -253,7 +253,7 @@ func SetState(db *DB, source string, state []byte) error {
|
||||
// Given the results of a fetch, add new items, update existing items, and delete expired items.
|
||||
//
|
||||
// Returns the number of new and deleted items on success.
|
||||
func UpdateWithFetchedItems(db *DB, source string, state []byte, items []Item) (int, int, error) {
|
||||
func UpdateWithFetchedItems(db DB, source string, state []byte, items []Item) (int, int, error) {
|
||||
// Get the existing items
|
||||
existingItems, err := GetAllItemsForSource(db, source)
|
||||
if err != nil {
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
type Env struct {
|
||||
db *core.DB
|
||||
db core.DB
|
||||
}
|
||||
|
||||
func logged(handler http.HandlerFunc) http.HandlerFunc {
|
||||
@ -23,7 +23,7 @@ func handleFunc(pattern string, handler http.HandlerFunc) {
|
||||
http.HandleFunc(pattern, logged(handler))
|
||||
}
|
||||
|
||||
func RunServer(db *core.DB, addr string, port string) {
|
||||
func RunServer(db core.DB, addr string, port string) {
|
||||
env := &Env{db}
|
||||
bind := net.JoinHostPort(addr, port)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user