Convert DB to interface
This commit is contained in:
parent
ced1de05e8
commit
d89f85e141
@ -51,7 +51,7 @@ func getDbPath() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to open the specified database and exit with an error if it fails.
|
// Attempt to open the specified database and exit with an error if it fails.
|
||||||
func openDb() *core.DB {
|
func openDb() core.DB {
|
||||||
dbPath := getDbPath()
|
dbPath := getDbPath()
|
||||||
db, err := core.OpenDb(dbPath)
|
db, err := core.OpenDb(dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -61,7 +61,7 @@ func openDb() *core.DB {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to open and migrate the specified database and exit with an error if it fails.
|
// Attempt to open and migrate the specified database and exit with an error if it fails.
|
||||||
func openAndMigrateDb() *core.DB {
|
func openAndMigrateDb() core.DB {
|
||||||
db := openDb()
|
db := openDb()
|
||||||
if err := core.InitDatabase(db); err != nil {
|
if err := core.InitDatabase(db); err != nil {
|
||||||
log.Fatalf("error: failed to init database: %v", err)
|
log.Fatalf("error: failed to init database: %v", err)
|
||||||
|
@ -26,7 +26,7 @@ func (a *argList) Scan(value interface{}) error {
|
|||||||
return json.Unmarshal([]byte(value.(string)), a)
|
return json.Unmarshal([]byte(value.(string)), a)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddAction(db *DB, source string, name string, argv []string) error {
|
func AddAction(db DB, source string, name string, argv []string) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
insert into actions (source, name, argv)
|
insert into actions (source, name, argv)
|
||||||
values (?, ?, jsonb(?))
|
values (?, ?, jsonb(?))
|
||||||
@ -34,7 +34,7 @@ func AddAction(db *DB, source string, name string, argv []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateAction(db *DB, source string, name string, argv []string) error {
|
func UpdateAction(db DB, source string, name string, argv []string) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
update actions
|
update actions
|
||||||
set argv = jsonb(?)
|
set argv = jsonb(?)
|
||||||
@ -43,7 +43,7 @@ func UpdateAction(db *DB, source string, name string, argv []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetActionsForSource(db *DB, source string) ([]string, error) {
|
func GetActionsForSource(db DB, source string) ([]string, error) {
|
||||||
rows, err := db.Query(`
|
rows, err := db.Query(`
|
||||||
select name
|
select name
|
||||||
from actions
|
from actions
|
||||||
@ -64,7 +64,7 @@ func GetActionsForSource(db *DB, source string) ([]string, error) {
|
|||||||
return names, nil
|
return names, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetArgvForAction(db *DB, source string, name string) ([]string, error) {
|
func GetArgvForAction(db DB, source string, name string) ([]string, error) {
|
||||||
rows := db.QueryRow(`
|
rows := db.QueryRow(`
|
||||||
select json(argv)
|
select json(argv)
|
||||||
from actions
|
from actions
|
||||||
@ -78,7 +78,7 @@ func GetArgvForAction(db *DB, source string, name string) ([]string, error) {
|
|||||||
return argv, nil
|
return argv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteAction(db *DB, source string, name string) error {
|
func DeleteAction(db DB, source string, name string) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
delete from actions
|
delete from actions
|
||||||
where source = ? and name = ?
|
where source = ? and name = ?
|
||||||
|
21
core/db.go
21
core/db.go
@ -7,24 +7,31 @@ import (
|
|||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DB struct {
|
type DB interface {
|
||||||
|
Query(query string, args ...any) (*sql.Rows, error)
|
||||||
|
QueryRow(query string, args ...any) *sql.Row
|
||||||
|
Exec(query string, args ...any) (sql.Result, error)
|
||||||
|
Transact(func(*sql.Tx) error) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type RoRwDb struct {
|
||||||
ro *sql.DB
|
ro *sql.DB
|
||||||
rw *sql.DB
|
rw *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Query(query string, args ...any) (*sql.Rows, error) {
|
func (db *RoRwDb) Query(query string, args ...any) (*sql.Rows, error) {
|
||||||
return db.ro.Query(query, args...)
|
return db.ro.Query(query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) QueryRow(query string, args ...any) *sql.Row {
|
func (db *RoRwDb) QueryRow(query string, args ...any) *sql.Row {
|
||||||
return db.ro.QueryRow(query, args...)
|
return db.ro.QueryRow(query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
|
func (db *RoRwDb) Exec(query string, args ...any) (sql.Result, error) {
|
||||||
return db.rw.Exec(query, args...)
|
return db.rw.Exec(query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Transact(transaction func(*sql.Tx) error) error {
|
func (db *RoRwDb) Transact(transaction func(*sql.Tx) error) error {
|
||||||
tx, err := db.rw.Begin()
|
tx, err := db.rw.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -55,7 +62,7 @@ func defaultPragma(db *sql.DB) (sql.Result, error) {
|
|||||||
`)
|
`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenDb(dataSourceName string) (*DB, error) {
|
func OpenDb(dataSourceName string) (DB, error) {
|
||||||
ro, err := sql.Open("sqlite3", dataSourceName)
|
ro, err := sql.Open("sqlite3", dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
defer ro.Close()
|
defer ro.Close()
|
||||||
@ -82,7 +89,7 @@ func OpenDb(dataSourceName string) (*DB, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
wrapper := new(DB)
|
wrapper := new(RoRwDb)
|
||||||
wrapper.ro = ro
|
wrapper.ro = ro
|
||||||
wrapper.rw = rw
|
wrapper.rw = rw
|
||||||
return wrapper, nil
|
return wrapper, nil
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetEnvs(db *DB, source string) ([]string, error) {
|
func GetEnvs(db DB, source string) ([]string, error) {
|
||||||
rows, err := db.Query(`
|
rows, err := db.Query(`
|
||||||
select name, value
|
select name, value
|
||||||
from envs
|
from envs
|
||||||
@ -27,7 +27,7 @@ func GetEnvs(db *DB, source string) ([]string, error) {
|
|||||||
return envs, nil
|
return envs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetEnvs(db *DB, source string, envs []string) error {
|
func SetEnvs(db DB, source string, envs []string) error {
|
||||||
return db.Transact(func(tx *sql.Tx) error {
|
return db.Transact(func(tx *sql.Tx) error {
|
||||||
for _, env := range envs {
|
for _, env := range envs {
|
||||||
parts := strings.SplitN(env, "=", 2)
|
parts := strings.SplitN(env, "=", 2)
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
var migrations embed.FS
|
var migrations embed.FS
|
||||||
|
|
||||||
// Idempotently initialize the database. Safe to call unconditionally.
|
// Idempotently initialize the database. Safe to call unconditionally.
|
||||||
func InitDatabase(db *DB) error {
|
func InitDatabase(db DB) error {
|
||||||
rows, err := db.Query(`
|
rows, err := db.Query(`
|
||||||
select exists (
|
select exists (
|
||||||
select 1
|
select 1
|
||||||
@ -41,7 +41,7 @@ func InitDatabase(db *DB) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get a map of migration names to whether the migration has been applied.
|
// Get a map of migration names to whether the migration has been applied.
|
||||||
func GetPendingMigrations(db *DB) (map[string]bool, error) {
|
func GetPendingMigrations(db DB) (map[string]bool, error) {
|
||||||
allMigrations, err := migrations.ReadDir("sql")
|
allMigrations, err := migrations.ReadDir("sql")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -69,7 +69,7 @@ func GetPendingMigrations(db *DB) (map[string]bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply a migration by name.
|
// Apply a migration by name.
|
||||||
func ApplyMigration(db *DB, name string) error {
|
func ApplyMigration(db DB, name string) error {
|
||||||
data, err := migrations.ReadFile("sql/" + name)
|
data, err := migrations.ReadFile("sql/" + name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Missing migration %s", name)
|
log.Fatalf("Missing migration %s", name)
|
||||||
@ -84,7 +84,7 @@ func ApplyMigration(db *DB, name string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply all pending migrations.
|
// Apply all pending migrations.
|
||||||
func MigrateDatabase(db *DB) error {
|
func MigrateDatabase(db DB) error {
|
||||||
pending, err := GetPendingMigrations(db)
|
pending, err := GetPendingMigrations(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func EphemeralDb(t *testing.T) *DB {
|
func EphemeralDb(t *testing.T) DB {
|
||||||
// We don't use OpenDb here because you can't open two connections to the same memory mem
|
// We don't use OpenDb here because you can't open two connections to the same memory mem
|
||||||
mem, err := sql.Open("sqlite3", ":memory:")
|
mem, err := sql.Open("sqlite3", ":memory:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -16,7 +16,7 @@ func EphemeralDb(t *testing.T) *DB {
|
|||||||
if _, err = defaultPragma(mem); err != nil {
|
if _, err = defaultPragma(mem); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
db := new(DB)
|
db := new(RoRwDb)
|
||||||
db.ro = mem
|
db.ro = mem
|
||||||
db.rw = mem
|
db.rw = mem
|
||||||
if err = InitDatabase(db); err != nil {
|
if err = InitDatabase(db); err != nil {
|
||||||
@ -33,7 +33,7 @@ func TestInitIdempotency(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
db := new(DB)
|
db := new(RoRwDb)
|
||||||
db.ro = mem
|
db.ro = mem
|
||||||
db.rw = mem
|
db.rw = mem
|
||||||
if err = InitDatabase(db); err != nil {
|
if err = InitDatabase(db); err != nil {
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func AddSource(db *DB, name string) error {
|
func AddSource(db DB, name string) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
insert into sources (name)
|
insert into sources (name)
|
||||||
values (?)
|
values (?)
|
||||||
@ -20,7 +20,7 @@ func AddSource(db *DB, name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSources(db *DB) ([]string, error) {
|
func GetSources(db DB) ([]string, error) {
|
||||||
rows, err := db.Query(`
|
rows, err := db.Query(`
|
||||||
select name
|
select name
|
||||||
from sources
|
from sources
|
||||||
@ -39,7 +39,7 @@ func GetSources(db *DB) ([]string, error) {
|
|||||||
return names, nil
|
return names, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSource(db *DB, name string) error {
|
func DeleteSource(db DB, name string) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
delete from sources
|
delete from sources
|
||||||
where name = ?
|
where name = ?
|
||||||
@ -48,7 +48,7 @@ func DeleteSource(db *DB, name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddItems(db *DB, items []Item) error {
|
func AddItems(db DB, items []Item) error {
|
||||||
return db.Transact(func(tx *sql.Tx) error {
|
return db.Transact(func(tx *sql.Tx) error {
|
||||||
stmt, err := tx.Prepare(`
|
stmt, err := tx.Prepare(`
|
||||||
insert into items (source, id, active, title, author, body, link, time, action)
|
insert into items (source, id, active, title, author, body, link, time, action)
|
||||||
@ -95,7 +95,7 @@ func BackfillItem(new *Item, old *Item) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateItems(db *DB, items []Item) error {
|
func UpdateItems(db DB, items []Item) error {
|
||||||
return db.Transact(func(tx *sql.Tx) error {
|
return db.Transact(func(tx *sql.Tx) error {
|
||||||
stmt, err := tx.Prepare(`
|
stmt, err := tx.Prepare(`
|
||||||
update items
|
update items
|
||||||
@ -127,7 +127,7 @@ func UpdateItems(db *DB, items []Item) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Deactivate an item, returning its previous active state.
|
// Deactivate an item, returning its previous active state.
|
||||||
func DeactivateItem(db *DB, source string, id string) (bool, error) {
|
func DeactivateItem(db DB, source string, id string) (bool, error) {
|
||||||
row := db.QueryRow(`
|
row := db.QueryRow(`
|
||||||
select active
|
select active
|
||||||
from items
|
from items
|
||||||
@ -150,7 +150,7 @@ func DeactivateItem(db *DB, source string, id string) (bool, error) {
|
|||||||
return active, nil
|
return active, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteItem(db *DB, source string, id string) (int64, error) {
|
func DeleteItem(db DB, source string, id string) (int64, error) {
|
||||||
res, err := db.Exec(`
|
res, err := db.Exec(`
|
||||||
delete from items
|
delete from items
|
||||||
where source = ?
|
where source = ?
|
||||||
@ -162,7 +162,7 @@ func DeleteItem(db *DB, source string, id string) (int64, error) {
|
|||||||
return res.RowsAffected()
|
return res.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
func getItems(db *DB, query string, args ...any) ([]Item, error) {
|
func getItems(db DB, query string, args ...any) ([]Item, error) {
|
||||||
rows, err := db.Query(query, args...)
|
rows, err := db.Query(query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -179,7 +179,7 @@ func getItems(db *DB, query string, args ...any) ([]Item, error) {
|
|||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetItem(db *DB, source string, id string) (Item, error) {
|
func GetItem(db DB, source string, id string) (Item, error) {
|
||||||
items, err := getItems(db, `
|
items, err := getItems(db, `
|
||||||
select source, id, created, active, title, author, body, link, time, json(action)
|
select source, id, created, active, title, author, body, link, time, json(action)
|
||||||
from items
|
from items
|
||||||
@ -196,7 +196,7 @@ func GetItem(db *DB, source string, id string) (Item, error) {
|
|||||||
return items[0], nil
|
return items[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllActiveItems(db *DB) ([]Item, error) {
|
func GetAllActiveItems(db DB) ([]Item, error) {
|
||||||
return getItems(db, `
|
return getItems(db, `
|
||||||
select
|
select
|
||||||
source, id, created, active, title, author, body, link, time, json(action)
|
source, id, created, active, title, author, body, link, time, json(action)
|
||||||
@ -206,7 +206,7 @@ func GetAllActiveItems(db *DB) ([]Item, error) {
|
|||||||
`)
|
`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllItems(db *DB) ([]Item, error) {
|
func GetAllItems(db DB) ([]Item, error) {
|
||||||
return getItems(db, `
|
return getItems(db, `
|
||||||
select
|
select
|
||||||
source, id, created, active, title, author, body, link, time, json(action)
|
source, id, created, active, title, author, body, link, time, json(action)
|
||||||
@ -215,7 +215,7 @@ func GetAllItems(db *DB) ([]Item, error) {
|
|||||||
`)
|
`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetActiveItemsForSource(db *DB, source string) ([]Item, error) {
|
func GetActiveItemsForSource(db DB, source string) ([]Item, error) {
|
||||||
return getItems(db, `
|
return getItems(db, `
|
||||||
select
|
select
|
||||||
source, id, created, active, title, author, body, link, time, json(action)
|
source, id, created, active, title, author, body, link, time, json(action)
|
||||||
@ -227,7 +227,7 @@ func GetActiveItemsForSource(db *DB, source string) ([]Item, error) {
|
|||||||
`, source)
|
`, source)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllItemsForSource(db *DB, source string) ([]Item, error) {
|
func GetAllItemsForSource(db DB, source string) ([]Item, error) {
|
||||||
return getItems(db, `
|
return getItems(db, `
|
||||||
select
|
select
|
||||||
source, id, created, active, title, author, body, link, time, json(action)
|
source, id, created, active, title, author, body, link, time, json(action)
|
||||||
@ -238,14 +238,14 @@ func GetAllItemsForSource(db *DB, source string) ([]Item, error) {
|
|||||||
`, source)
|
`, source)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetState(db *DB, source string) ([]byte, error) {
|
func GetState(db DB, source string) ([]byte, error) {
|
||||||
row := db.QueryRow("select state from sources where name = ?", source)
|
row := db.QueryRow("select state from sources where name = ?", source)
|
||||||
var state []byte
|
var state []byte
|
||||||
err := row.Scan(&state)
|
err := row.Scan(&state)
|
||||||
return state, err
|
return state, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetState(db *DB, source string, state []byte) error {
|
func SetState(db DB, source string, state []byte) error {
|
||||||
_, err := db.Exec("update sources set state = ? where name = ?", state, source)
|
_, err := db.Exec("update sources set state = ? where name = ?", state, source)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -253,7 +253,7 @@ func SetState(db *DB, source string, state []byte) error {
|
|||||||
// Given the results of a fetch, add new items, update existing items, and delete expired items.
|
// Given the results of a fetch, add new items, update existing items, and delete expired items.
|
||||||
//
|
//
|
||||||
// Returns the number of new and deleted items on success.
|
// Returns the number of new and deleted items on success.
|
||||||
func UpdateWithFetchedItems(db *DB, source string, state []byte, items []Item) (int, int, error) {
|
func UpdateWithFetchedItems(db DB, source string, state []byte, items []Item) (int, int, error) {
|
||||||
// Get the existing items
|
// Get the existing items
|
||||||
existingItems, err := GetAllItemsForSource(db, source)
|
existingItems, err := GetAllItemsForSource(db, source)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Env struct {
|
type Env struct {
|
||||||
db *core.DB
|
db core.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func logged(handler http.HandlerFunc) http.HandlerFunc {
|
func logged(handler http.HandlerFunc) http.HandlerFunc {
|
||||||
@ -23,7 +23,7 @@ func handleFunc(pattern string, handler http.HandlerFunc) {
|
|||||||
http.HandleFunc(pattern, logged(handler))
|
http.HandleFunc(pattern, logged(handler))
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunServer(db *core.DB, addr string, port string) {
|
func RunServer(db core.DB, addr string, port string) {
|
||||||
env := &Env{db}
|
env := &Env{db}
|
||||||
bind := net.JoinHostPort(addr, port)
|
bind := net.JoinHostPort(addr, port)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user