diff --git a/cmd/add.go b/cmd/add.go index 2b3d401..597aa3c 100644 --- a/cmd/add.go +++ b/cmd/add.go @@ -2,7 +2,6 @@ package cmd import ( "crypto/rand" - "database/sql" "encoding/hex" "fmt" "log" @@ -58,7 +57,7 @@ func add() { addId = hex.EncodeToString(bytes) } - db, err := sql.Open("sqlite3", getDbPath()) + db, err := core.OpenDb(getDbPath()) if err != nil { log.Fatalf("Failed to open %s", dbPath) } diff --git a/cmd/deactivate.go b/cmd/deactivate.go index d5665b9..bd5a233 100644 --- a/cmd/deactivate.go +++ b/cmd/deactivate.go @@ -1,7 +1,6 @@ package cmd import ( - "database/sql" "fmt" "log" @@ -35,7 +34,7 @@ func init() { } func deactivate() { - db, err := sql.Open("sqlite3", getDbPath()) + db, err := core.OpenDb(getDbPath()) if err != nil { log.Fatalf("Failed to open %s", dbPath) } diff --git a/cmd/feed.go b/cmd/feed.go index 2c22087..0260036 100644 --- a/cmd/feed.go +++ b/cmd/feed.go @@ -1,7 +1,6 @@ package cmd import ( - "database/sql" "fmt" "log" @@ -44,7 +43,7 @@ func feed() { log.Fatal(err) } - db, err := sql.Open("sqlite3", getDbPath()) + db, err := core.OpenDb(getDbPath()) if err != nil { log.Fatalf("error: failed to open %s", dbPath) } diff --git a/cmd/migrate.go b/cmd/migrate.go index de19d60..6fe4b73 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -1,7 +1,6 @@ package cmd import ( - "database/sql" "fmt" "log" @@ -30,7 +29,7 @@ func init() { } func migrate() { - db, err := sql.Open("sqlite3", getDbPath()) + db, err := core.OpenDb(getDbPath()) if err != nil { log.Fatal(err) } diff --git a/core/db.go b/core/db.go new file mode 100644 index 0000000..ccc9e33 --- /dev/null +++ b/core/db.go @@ -0,0 +1,70 @@ +package core + +import ( + "database/sql" + "runtime" + + _ "github.com/mattn/go-sqlite3" +) + +type DB struct { + ro *sql.DB + rw *sql.DB +} + +func (db *DB) Query(query string, args ...any) (*sql.Rows, error) { + return db.ro.Query(query, args...) +} + +func (db *DB) QueryRow(query string, args ...any) *sql.Row { + return db.ro.QueryRow(query, args...) +} + +func (db *DB) Exec(query string, args ...any) (sql.Result, error) { + return db.rw.Exec(query, args...) +} + +func defaultPragma(db *sql.DB) (sql.Result, error) { + return db.Exec(` + pragma journal_mode = WAL; + pragma busy_timeout = 5000; + pragma synchronous = NORMAL; + pragma cache_size = 1000000000; + pragma foreign_keys = true; + pragma temp_store = memory; + pragma mmap_size = 3000000000; + `) +} + +func OpenDb(dataSourceName string) (*DB, error) { + ro, err := sql.Open("sqlite3", dataSourceName) + if err != nil { + defer ro.Close() + return nil, err + } + ro.SetMaxOpenConns(max(4, runtime.NumCPU())) + _, err = defaultPragma(ro) + if err != nil { + defer ro.Close() + return nil, err + } + + rw, err := sql.Open("sqlite3", dataSourceName) + if err != nil { + defer ro.Close() + defer rw.Close() + return nil, err + } + rw.SetMaxOpenConns(1) + _, err = defaultPragma(rw) + if err != nil { + defer ro.Close() + defer rw.Close() + return nil, err + } + + wrapper := new(DB) + wrapper.ro = ro + wrapper.rw = rw + return wrapper, nil +} diff --git a/core/db_test.go b/core/db_test.go new file mode 100644 index 0000000..9d376a8 --- /dev/null +++ b/core/db_test.go @@ -0,0 +1,48 @@ +package core + +import ( + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func TestDeleteSourceCascade(t *testing.T) { + db := EphemeralDb(t) + + if err := AddSource(db, "source1"); err != nil { + t.Fatal(err) + } + if err := AddSource(db, "source2"); err != nil { + t.Fatal(err) + } + if err := AddItem(db, "source1", "item1", "", "", "", "", 0); err != nil { + t.Fatal(err) + } + if err := AddItem(db, "source2", "item2", "", "", "", "", 0); err != nil { + t.Fatal(err) + } + + items, err := GetAllActiveItems(db) + if err != nil { + t.Fatal(err) + } + if len(items) != 2 { + t.Fatal("Expected 2 items") + } + + if err := DeleteSource(db, "source1"); err != nil { + t.Fatal(err) + } + items, err = GetAllActiveItems(db) + if err != nil { + t.Fatal(err) + } + if len(items) != 1 { + t.Fatal("Expected only 1 item after source delete") + } + + err = AddItem(db, "source1", "item3", "", "", "", "", 0) + if err == nil { + t.Fatal("Unexpected success adding item for nonexistent source") + } +} diff --git a/core/migrations.go b/core/migrations.go index 5f9b25b..5008a03 100644 --- a/core/migrations.go +++ b/core/migrations.go @@ -1,7 +1,6 @@ package core import ( - "database/sql" "embed" "log" @@ -12,7 +11,7 @@ import ( var migrations embed.FS // Idempotently initialize the database. Safe to call unconditionally. -func InitDatabase(db *sql.DB) error { +func InitDatabase(db *DB) error { rows, err := db.Query(` select exists ( select 1 @@ -39,7 +38,7 @@ func InitDatabase(db *sql.DB) error { } // Get a map of migration names to whether the migration has been applied. -func GetPendingMigrations(db *sql.DB) (map[string]bool, error) { +func GetPendingMigrations(db *DB) (map[string]bool, error) { allMigrations, err := migrations.ReadDir("sql") if err != nil { return nil, err @@ -64,7 +63,7 @@ func GetPendingMigrations(db *sql.DB) (map[string]bool, error) { } // Apply a migration by name. -func ApplyMigration(db *sql.DB, name string) error { +func ApplyMigration(db *DB, name string) error { data, err := migrations.ReadFile("sql/" + name) if err != nil { log.Fatalf("Missing migration %s", name) @@ -79,7 +78,7 @@ func ApplyMigration(db *sql.DB, name string) error { } // Apply all pending migrations. -func MigrateDatabase(db *sql.DB) error { +func MigrateDatabase(db *DB) error { pending, err := GetPendingMigrations(db) if err != nil { return err diff --git a/core/migrations_test.go b/core/migrations_test.go index 17cb9cc..1386f50 100644 --- a/core/migrations_test.go +++ b/core/migrations_test.go @@ -7,11 +7,18 @@ import ( _ "github.com/mattn/go-sqlite3" ) -func EphemeralDb(t *testing.T) *sql.DB { - db, err := sql.Open("sqlite3", ":memory:") +func EphemeralDb(t *testing.T) *DB { + // We don't use OpenDb here because you can't open two connections to the same memory mem + mem, err := sql.Open("sqlite3", ":memory:") if err != nil { t.Fatal(err) } + if _, err = defaultPragma(mem); err != nil { + t.Fatal(err) + } + db := new(DB) + db.ro = mem + db.rw = mem if err = InitDatabase(db); err != nil { t.Fatal(err) } @@ -22,11 +29,13 @@ func EphemeralDb(t *testing.T) *sql.DB { } func TestInitIdempotency(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + mem, err := sql.Open("sqlite3", ":memory:") if err != nil { t.Fatal(err) } - defer db.Close() + db := new(DB) + db.ro = mem + db.rw = mem if err = InitDatabase(db); err != nil { t.Fatal(err) } @@ -37,7 +46,6 @@ func TestInitIdempotency(t *testing.T) { func TestMigrations(t *testing.T) { db := EphemeralDb(t) - defer db.Close() allMigrations, err := migrations.ReadDir("sql") if err != nil { diff --git a/core/source.go b/core/source.go index 8f61f7b..6aafce5 100644 --- a/core/source.go +++ b/core/source.go @@ -8,7 +8,7 @@ import ( _ "github.com/mattn/go-sqlite3" ) -func AddSource(db *sql.DB, name string) error { +func AddSource(db *DB, name string) error { _, err := db.Exec(` insert into sources (name) values (?) @@ -17,7 +17,7 @@ func AddSource(db *sql.DB, name string) error { return err } -func DeleteSource(db *sql.DB, name string) error { +func DeleteSource(db *DB, name string) error { _, err := db.Exec(` delete from sources where name = ? @@ -27,7 +27,7 @@ func DeleteSource(db *sql.DB, name string) error { } func AddItem( - db *sql.DB, + db *DB, source string, id string, title string, @@ -45,7 +45,7 @@ func AddItem( } // Deactivate an item, returning its previous active state. -func DeactivateItem(db *sql.DB, source string, id string) (bool, error) { +func DeactivateItem(db *DB, source string, id string) (bool, error) { row := db.QueryRow(` select active from items @@ -68,7 +68,7 @@ func DeactivateItem(db *sql.DB, source string, id string) (bool, error) { return active, nil } -func GetAllActiveItems(db *sql.DB) ([]Item, error) { +func GetAllActiveItems(db *DB) ([]Item, error) { rows, err := db.Query(` select source, id, created, active, title, author, body, link, time @@ -87,7 +87,7 @@ func GetAllActiveItems(db *sql.DB) ([]Item, error) { return items, nil } -func GetActiveItemsForSource(db *sql.DB, source string) ([]Item, error) { +func GetActiveItemsForSource(db *DB, source string) ([]Item, error) { rows, err := db.Query(` select source, id, created, active, title, author, body, link, time diff --git a/core/source_test.go b/core/source_test.go index ed34bbf..d164584 100644 --- a/core/source_test.go +++ b/core/source_test.go @@ -10,7 +10,6 @@ import ( func TestCreateSource(t *testing.T) { db := EphemeralDb(t) - defer db.Close() if err := AddSource(db, "one"); err != nil { t.Fatal(err) @@ -62,7 +61,6 @@ func AssertItemIs(t *testing.T, item Item, expected string) { func TestAddItem(t *testing.T) { db := EphemeralDb(t) - defer db.Close() if err := AddSource(db, "test"); err != nil { t.Fatal(err) } diff --git a/core/sql/0001_initial_schema.sql b/core/sql/0001_initial_schema.sql index 17876f1..d6e5d47 100644 --- a/core/sql/0001_initial_schema.sql +++ b/core/sql/0001_initial_schema.sql @@ -1,4 +1,7 @@ -create table sources(name text) strict; +create table sources( + name text not null, + primary key (name) +) strict; create table items( source text not null, id text not null,