Refactor db access to ensure pragmas are set, fix foreign keys

This commit is contained in:
Tim Van Baak 2025-01-19 21:33:49 -08:00
parent 1468c3adc4
commit 2a58c01319
11 changed files with 149 additions and 27 deletions

View File

@ -2,7 +2,6 @@ package cmd
import (
"crypto/rand"
"database/sql"
"encoding/hex"
"fmt"
"log"
@ -58,7 +57,7 @@ func add() {
addId = hex.EncodeToString(bytes)
}
db, err := sql.Open("sqlite3", getDbPath())
db, err := core.OpenDb(getDbPath())
if err != nil {
log.Fatalf("Failed to open %s", dbPath)
}

View File

@ -1,7 +1,6 @@
package cmd
import (
"database/sql"
"fmt"
"log"
@ -35,7 +34,7 @@ func init() {
}
func deactivate() {
db, err := sql.Open("sqlite3", getDbPath())
db, err := core.OpenDb(getDbPath())
if err != nil {
log.Fatalf("Failed to open %s", dbPath)
}

View File

@ -1,7 +1,6 @@
package cmd
import (
"database/sql"
"fmt"
"log"
@ -44,7 +43,7 @@ func feed() {
log.Fatal(err)
}
db, err := sql.Open("sqlite3", getDbPath())
db, err := core.OpenDb(getDbPath())
if err != nil {
log.Fatalf("error: failed to open %s", dbPath)
}

View File

@ -1,7 +1,6 @@
package cmd
import (
"database/sql"
"fmt"
"log"
@ -30,7 +29,7 @@ func init() {
}
func migrate() {
db, err := sql.Open("sqlite3", getDbPath())
db, err := core.OpenDb(getDbPath())
if err != nil {
log.Fatal(err)
}

70
core/db.go Normal file
View File

@ -0,0 +1,70 @@
package core
import (
"database/sql"
"runtime"
_ "github.com/mattn/go-sqlite3"
)
type DB struct {
ro *sql.DB
rw *sql.DB
}
func (db *DB) Query(query string, args ...any) (*sql.Rows, error) {
return db.ro.Query(query, args...)
}
func (db *DB) QueryRow(query string, args ...any) *sql.Row {
return db.ro.QueryRow(query, args...)
}
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
return db.rw.Exec(query, args...)
}
func defaultPragma(db *sql.DB) (sql.Result, error) {
return db.Exec(`
pragma journal_mode = WAL;
pragma busy_timeout = 5000;
pragma synchronous = NORMAL;
pragma cache_size = 1000000000;
pragma foreign_keys = true;
pragma temp_store = memory;
pragma mmap_size = 3000000000;
`)
}
func OpenDb(dataSourceName string) (*DB, error) {
ro, err := sql.Open("sqlite3", dataSourceName)
if err != nil {
defer ro.Close()
return nil, err
}
ro.SetMaxOpenConns(max(4, runtime.NumCPU()))
_, err = defaultPragma(ro)
if err != nil {
defer ro.Close()
return nil, err
}
rw, err := sql.Open("sqlite3", dataSourceName)
if err != nil {
defer ro.Close()
defer rw.Close()
return nil, err
}
rw.SetMaxOpenConns(1)
_, err = defaultPragma(rw)
if err != nil {
defer ro.Close()
defer rw.Close()
return nil, err
}
wrapper := new(DB)
wrapper.ro = ro
wrapper.rw = rw
return wrapper, nil
}

48
core/db_test.go Normal file
View File

@ -0,0 +1,48 @@
package core
import (
"testing"
_ "github.com/mattn/go-sqlite3"
)
func TestDeleteSourceCascade(t *testing.T) {
db := EphemeralDb(t)
if err := AddSource(db, "source1"); err != nil {
t.Fatal(err)
}
if err := AddSource(db, "source2"); err != nil {
t.Fatal(err)
}
if err := AddItem(db, "source1", "item1", "", "", "", "", 0); err != nil {
t.Fatal(err)
}
if err := AddItem(db, "source2", "item2", "", "", "", "", 0); err != nil {
t.Fatal(err)
}
items, err := GetAllActiveItems(db)
if err != nil {
t.Fatal(err)
}
if len(items) != 2 {
t.Fatal("Expected 2 items")
}
if err := DeleteSource(db, "source1"); err != nil {
t.Fatal(err)
}
items, err = GetAllActiveItems(db)
if err != nil {
t.Fatal(err)
}
if len(items) != 1 {
t.Fatal("Expected only 1 item after source delete")
}
err = AddItem(db, "source1", "item3", "", "", "", "", 0)
if err == nil {
t.Fatal("Unexpected success adding item for nonexistent source")
}
}

View File

@ -1,7 +1,6 @@
package core
import (
"database/sql"
"embed"
"log"
@ -12,7 +11,7 @@ import (
var migrations embed.FS
// Idempotently initialize the database. Safe to call unconditionally.
func InitDatabase(db *sql.DB) error {
func InitDatabase(db *DB) error {
rows, err := db.Query(`
select exists (
select 1
@ -39,7 +38,7 @@ func InitDatabase(db *sql.DB) error {
}
// Get a map of migration names to whether the migration has been applied.
func GetPendingMigrations(db *sql.DB) (map[string]bool, error) {
func GetPendingMigrations(db *DB) (map[string]bool, error) {
allMigrations, err := migrations.ReadDir("sql")
if err != nil {
return nil, err
@ -64,7 +63,7 @@ func GetPendingMigrations(db *sql.DB) (map[string]bool, error) {
}
// Apply a migration by name.
func ApplyMigration(db *sql.DB, name string) error {
func ApplyMigration(db *DB, name string) error {
data, err := migrations.ReadFile("sql/" + name)
if err != nil {
log.Fatalf("Missing migration %s", name)
@ -79,7 +78,7 @@ func ApplyMigration(db *sql.DB, name string) error {
}
// Apply all pending migrations.
func MigrateDatabase(db *sql.DB) error {
func MigrateDatabase(db *DB) error {
pending, err := GetPendingMigrations(db)
if err != nil {
return err

View File

@ -7,11 +7,18 @@ import (
_ "github.com/mattn/go-sqlite3"
)
func EphemeralDb(t *testing.T) *sql.DB {
db, err := sql.Open("sqlite3", ":memory:")
func EphemeralDb(t *testing.T) *DB {
// We don't use OpenDb here because you can't open two connections to the same memory mem
mem, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
if _, err = defaultPragma(mem); err != nil {
t.Fatal(err)
}
db := new(DB)
db.ro = mem
db.rw = mem
if err = InitDatabase(db); err != nil {
t.Fatal(err)
}
@ -22,11 +29,13 @@ func EphemeralDb(t *testing.T) *sql.DB {
}
func TestInitIdempotency(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
mem, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
db := new(DB)
db.ro = mem
db.rw = mem
if err = InitDatabase(db); err != nil {
t.Fatal(err)
}
@ -37,7 +46,6 @@ func TestInitIdempotency(t *testing.T) {
func TestMigrations(t *testing.T) {
db := EphemeralDb(t)
defer db.Close()
allMigrations, err := migrations.ReadDir("sql")
if err != nil {

View File

@ -8,7 +8,7 @@ import (
_ "github.com/mattn/go-sqlite3"
)
func AddSource(db *sql.DB, name string) error {
func AddSource(db *DB, name string) error {
_, err := db.Exec(`
insert into sources (name)
values (?)
@ -17,7 +17,7 @@ func AddSource(db *sql.DB, name string) error {
return err
}
func DeleteSource(db *sql.DB, name string) error {
func DeleteSource(db *DB, name string) error {
_, err := db.Exec(`
delete from sources
where name = ?
@ -27,7 +27,7 @@ func DeleteSource(db *sql.DB, name string) error {
}
func AddItem(
db *sql.DB,
db *DB,
source string,
id string,
title string,
@ -45,7 +45,7 @@ func AddItem(
}
// Deactivate an item, returning its previous active state.
func DeactivateItem(db *sql.DB, source string, id string) (bool, error) {
func DeactivateItem(db *DB, source string, id string) (bool, error) {
row := db.QueryRow(`
select active
from items
@ -68,7 +68,7 @@ func DeactivateItem(db *sql.DB, source string, id string) (bool, error) {
return active, nil
}
func GetAllActiveItems(db *sql.DB) ([]Item, error) {
func GetAllActiveItems(db *DB) ([]Item, error) {
rows, err := db.Query(`
select
source, id, created, active, title, author, body, link, time
@ -87,7 +87,7 @@ func GetAllActiveItems(db *sql.DB) ([]Item, error) {
return items, nil
}
func GetActiveItemsForSource(db *sql.DB, source string) ([]Item, error) {
func GetActiveItemsForSource(db *DB, source string) ([]Item, error) {
rows, err := db.Query(`
select
source, id, created, active, title, author, body, link, time

View File

@ -10,7 +10,6 @@ import (
func TestCreateSource(t *testing.T) {
db := EphemeralDb(t)
defer db.Close()
if err := AddSource(db, "one"); err != nil {
t.Fatal(err)
@ -62,7 +61,6 @@ func AssertItemIs(t *testing.T, item Item, expected string) {
func TestAddItem(t *testing.T) {
db := EphemeralDb(t)
defer db.Close()
if err := AddSource(db, "test"); err != nil {
t.Fatal(err)
}

View File

@ -1,4 +1,7 @@
create table sources(name text) strict;
create table sources(
name text not null,
primary key (name)
) strict;
create table items(
source text not null,
id text not null,