Refactor db access to ensure pragmas are set, fix foreign keys
This commit is contained in:
parent
1468c3adc4
commit
2a58c01319
@ -2,7 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
@ -58,7 +57,7 @@ func add() {
|
||||
addId = hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", getDbPath())
|
||||
db, err := core.OpenDb(getDbPath())
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open %s", dbPath)
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
@ -35,7 +34,7 @@ func init() {
|
||||
}
|
||||
|
||||
func deactivate() {
|
||||
db, err := sql.Open("sqlite3", getDbPath())
|
||||
db, err := core.OpenDb(getDbPath())
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open %s", dbPath)
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
@ -44,7 +43,7 @@ func feed() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", getDbPath())
|
||||
db, err := core.OpenDb(getDbPath())
|
||||
if err != nil {
|
||||
log.Fatalf("error: failed to open %s", dbPath)
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
@ -30,7 +29,7 @@ func init() {
|
||||
}
|
||||
|
||||
func migrate() {
|
||||
db, err := sql.Open("sqlite3", getDbPath())
|
||||
db, err := core.OpenDb(getDbPath())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
70
core/db.go
Normal file
70
core/db.go
Normal file
@ -0,0 +1,70 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"runtime"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
ro *sql.DB
|
||||
rw *sql.DB
|
||||
}
|
||||
|
||||
func (db *DB) Query(query string, args ...any) (*sql.Rows, error) {
|
||||
return db.ro.Query(query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) QueryRow(query string, args ...any) *sql.Row {
|
||||
return db.ro.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
|
||||
return db.rw.Exec(query, args...)
|
||||
}
|
||||
|
||||
func defaultPragma(db *sql.DB) (sql.Result, error) {
|
||||
return db.Exec(`
|
||||
pragma journal_mode = WAL;
|
||||
pragma busy_timeout = 5000;
|
||||
pragma synchronous = NORMAL;
|
||||
pragma cache_size = 1000000000;
|
||||
pragma foreign_keys = true;
|
||||
pragma temp_store = memory;
|
||||
pragma mmap_size = 3000000000;
|
||||
`)
|
||||
}
|
||||
|
||||
func OpenDb(dataSourceName string) (*DB, error) {
|
||||
ro, err := sql.Open("sqlite3", dataSourceName)
|
||||
if err != nil {
|
||||
defer ro.Close()
|
||||
return nil, err
|
||||
}
|
||||
ro.SetMaxOpenConns(max(4, runtime.NumCPU()))
|
||||
_, err = defaultPragma(ro)
|
||||
if err != nil {
|
||||
defer ro.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rw, err := sql.Open("sqlite3", dataSourceName)
|
||||
if err != nil {
|
||||
defer ro.Close()
|
||||
defer rw.Close()
|
||||
return nil, err
|
||||
}
|
||||
rw.SetMaxOpenConns(1)
|
||||
_, err = defaultPragma(rw)
|
||||
if err != nil {
|
||||
defer ro.Close()
|
||||
defer rw.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wrapper := new(DB)
|
||||
wrapper.ro = ro
|
||||
wrapper.rw = rw
|
||||
return wrapper, nil
|
||||
}
|
48
core/db_test.go
Normal file
48
core/db_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestDeleteSourceCascade(t *testing.T) {
|
||||
db := EphemeralDb(t)
|
||||
|
||||
if err := AddSource(db, "source1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := AddSource(db, "source2"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := AddItem(db, "source1", "item1", "", "", "", "", 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := AddItem(db, "source2", "item2", "", "", "", "", 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
items, err := GetAllActiveItems(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(items) != 2 {
|
||||
t.Fatal("Expected 2 items")
|
||||
}
|
||||
|
||||
if err := DeleteSource(db, "source1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
items, err = GetAllActiveItems(db)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatal("Expected only 1 item after source delete")
|
||||
}
|
||||
|
||||
err = AddItem(db, "source1", "item3", "", "", "", "", 0)
|
||||
if err == nil {
|
||||
t.Fatal("Unexpected success adding item for nonexistent source")
|
||||
}
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"log"
|
||||
|
||||
@ -12,7 +11,7 @@ import (
|
||||
var migrations embed.FS
|
||||
|
||||
// Idempotently initialize the database. Safe to call unconditionally.
|
||||
func InitDatabase(db *sql.DB) error {
|
||||
func InitDatabase(db *DB) error {
|
||||
rows, err := db.Query(`
|
||||
select exists (
|
||||
select 1
|
||||
@ -39,7 +38,7 @@ func InitDatabase(db *sql.DB) error {
|
||||
}
|
||||
|
||||
// Get a map of migration names to whether the migration has been applied.
|
||||
func GetPendingMigrations(db *sql.DB) (map[string]bool, error) {
|
||||
func GetPendingMigrations(db *DB) (map[string]bool, error) {
|
||||
allMigrations, err := migrations.ReadDir("sql")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -64,7 +63,7 @@ func GetPendingMigrations(db *sql.DB) (map[string]bool, error) {
|
||||
}
|
||||
|
||||
// Apply a migration by name.
|
||||
func ApplyMigration(db *sql.DB, name string) error {
|
||||
func ApplyMigration(db *DB, name string) error {
|
||||
data, err := migrations.ReadFile("sql/" + name)
|
||||
if err != nil {
|
||||
log.Fatalf("Missing migration %s", name)
|
||||
@ -79,7 +78,7 @@ func ApplyMigration(db *sql.DB, name string) error {
|
||||
}
|
||||
|
||||
// Apply all pending migrations.
|
||||
func MigrateDatabase(db *sql.DB) error {
|
||||
func MigrateDatabase(db *DB) error {
|
||||
pending, err := GetPendingMigrations(db)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -7,11 +7,18 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func EphemeralDb(t *testing.T) *sql.DB {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
func EphemeralDb(t *testing.T) *DB {
|
||||
// We don't use OpenDb here because you can't open two connections to the same memory mem
|
||||
mem, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err = defaultPragma(mem); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db := new(DB)
|
||||
db.ro = mem
|
||||
db.rw = mem
|
||||
if err = InitDatabase(db); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -22,11 +29,13 @@ func EphemeralDb(t *testing.T) *sql.DB {
|
||||
}
|
||||
|
||||
func TestInitIdempotency(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
mem, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
db := new(DB)
|
||||
db.ro = mem
|
||||
db.rw = mem
|
||||
if err = InitDatabase(db); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -37,7 +46,6 @@ func TestInitIdempotency(t *testing.T) {
|
||||
|
||||
func TestMigrations(t *testing.T) {
|
||||
db := EphemeralDb(t)
|
||||
defer db.Close()
|
||||
|
||||
allMigrations, err := migrations.ReadDir("sql")
|
||||
if err != nil {
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func AddSource(db *sql.DB, name string) error {
|
||||
func AddSource(db *DB, name string) error {
|
||||
_, err := db.Exec(`
|
||||
insert into sources (name)
|
||||
values (?)
|
||||
@ -17,7 +17,7 @@ func AddSource(db *sql.DB, name string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteSource(db *sql.DB, name string) error {
|
||||
func DeleteSource(db *DB, name string) error {
|
||||
_, err := db.Exec(`
|
||||
delete from sources
|
||||
where name = ?
|
||||
@ -27,7 +27,7 @@ func DeleteSource(db *sql.DB, name string) error {
|
||||
}
|
||||
|
||||
func AddItem(
|
||||
db *sql.DB,
|
||||
db *DB,
|
||||
source string,
|
||||
id string,
|
||||
title string,
|
||||
@ -45,7 +45,7 @@ func AddItem(
|
||||
}
|
||||
|
||||
// Deactivate an item, returning its previous active state.
|
||||
func DeactivateItem(db *sql.DB, source string, id string) (bool, error) {
|
||||
func DeactivateItem(db *DB, source string, id string) (bool, error) {
|
||||
row := db.QueryRow(`
|
||||
select active
|
||||
from items
|
||||
@ -68,7 +68,7 @@ func DeactivateItem(db *sql.DB, source string, id string) (bool, error) {
|
||||
return active, nil
|
||||
}
|
||||
|
||||
func GetAllActiveItems(db *sql.DB) ([]Item, error) {
|
||||
func GetAllActiveItems(db *DB) ([]Item, error) {
|
||||
rows, err := db.Query(`
|
||||
select
|
||||
source, id, created, active, title, author, body, link, time
|
||||
@ -87,7 +87,7 @@ func GetAllActiveItems(db *sql.DB) ([]Item, error) {
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func GetActiveItemsForSource(db *sql.DB, source string) ([]Item, error) {
|
||||
func GetActiveItemsForSource(db *DB, source string) ([]Item, error) {
|
||||
rows, err := db.Query(`
|
||||
select
|
||||
source, id, created, active, title, author, body, link, time
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
|
||||
func TestCreateSource(t *testing.T) {
|
||||
db := EphemeralDb(t)
|
||||
defer db.Close()
|
||||
|
||||
if err := AddSource(db, "one"); err != nil {
|
||||
t.Fatal(err)
|
||||
@ -62,7 +61,6 @@ func AssertItemIs(t *testing.T, item Item, expected string) {
|
||||
|
||||
func TestAddItem(t *testing.T) {
|
||||
db := EphemeralDb(t)
|
||||
defer db.Close()
|
||||
if err := AddSource(db, "test"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1,4 +1,7 @@
|
||||
create table sources(name text) strict;
|
||||
create table sources(
|
||||
name text not null,
|
||||
primary key (name)
|
||||
) strict;
|
||||
create table items(
|
||||
source text not null,
|
||||
id text not null,
|
||||
|
Loading…
Reference in New Issue
Block a user