Refactor db access to ensure pragmas are set, fix foreign keys
This commit is contained in:
parent
1468c3adc4
commit
2a58c01319
@ -2,7 +2,6 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"database/sql"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
@ -58,7 +57,7 @@ func add() {
|
|||||||
addId = hex.EncodeToString(bytes)
|
addId = hex.EncodeToString(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sql.Open("sqlite3", getDbPath())
|
db, err := core.OpenDb(getDbPath())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to open %s", dbPath)
|
log.Fatalf("Failed to open %s", dbPath)
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
@ -35,7 +34,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func deactivate() {
|
func deactivate() {
|
||||||
db, err := sql.Open("sqlite3", getDbPath())
|
db, err := core.OpenDb(getDbPath())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to open %s", dbPath)
|
log.Fatalf("Failed to open %s", dbPath)
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
@ -44,7 +43,7 @@ func feed() {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sql.Open("sqlite3", getDbPath())
|
db, err := core.OpenDb(getDbPath())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("error: failed to open %s", dbPath)
|
log.Fatalf("error: failed to open %s", dbPath)
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
@ -30,7 +29,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func migrate() {
|
func migrate() {
|
||||||
db, err := sql.Open("sqlite3", getDbPath())
|
db, err := core.OpenDb(getDbPath())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
70
core/db.go
Normal file
70
core/db.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DB struct {
|
||||||
|
ro *sql.DB
|
||||||
|
rw *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Query(query string, args ...any) (*sql.Rows, error) {
|
||||||
|
return db.ro.Query(query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) QueryRow(query string, args ...any) *sql.Row {
|
||||||
|
return db.ro.QueryRow(query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
|
||||||
|
return db.rw.Exec(query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultPragma(db *sql.DB) (sql.Result, error) {
|
||||||
|
return db.Exec(`
|
||||||
|
pragma journal_mode = WAL;
|
||||||
|
pragma busy_timeout = 5000;
|
||||||
|
pragma synchronous = NORMAL;
|
||||||
|
pragma cache_size = 1000000000;
|
||||||
|
pragma foreign_keys = true;
|
||||||
|
pragma temp_store = memory;
|
||||||
|
pragma mmap_size = 3000000000;
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenDb(dataSourceName string) (*DB, error) {
|
||||||
|
ro, err := sql.Open("sqlite3", dataSourceName)
|
||||||
|
if err != nil {
|
||||||
|
defer ro.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ro.SetMaxOpenConns(max(4, runtime.NumCPU()))
|
||||||
|
_, err = defaultPragma(ro)
|
||||||
|
if err != nil {
|
||||||
|
defer ro.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rw, err := sql.Open("sqlite3", dataSourceName)
|
||||||
|
if err != nil {
|
||||||
|
defer ro.Close()
|
||||||
|
defer rw.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rw.SetMaxOpenConns(1)
|
||||||
|
_, err = defaultPragma(rw)
|
||||||
|
if err != nil {
|
||||||
|
defer ro.Close()
|
||||||
|
defer rw.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapper := new(DB)
|
||||||
|
wrapper.ro = ro
|
||||||
|
wrapper.rw = rw
|
||||||
|
return wrapper, nil
|
||||||
|
}
|
48
core/db_test.go
Normal file
48
core/db_test.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeleteSourceCascade(t *testing.T) {
|
||||||
|
db := EphemeralDb(t)
|
||||||
|
|
||||||
|
if err := AddSource(db, "source1"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := AddSource(db, "source2"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := AddItem(db, "source1", "item1", "", "", "", "", 0); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := AddItem(db, "source2", "item2", "", "", "", "", 0); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
items, err := GetAllActiveItems(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(items) != 2 {
|
||||||
|
t.Fatal("Expected 2 items")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DeleteSource(db, "source1"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
items, err = GetAllActiveItems(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(items) != 1 {
|
||||||
|
t.Fatal("Expected only 1 item after source delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = AddItem(db, "source1", "item3", "", "", "", "", 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Unexpected success adding item for nonexistent source")
|
||||||
|
}
|
||||||
|
}
|
@ -1,7 +1,6 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"embed"
|
"embed"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
@ -12,7 +11,7 @@ import (
|
|||||||
var migrations embed.FS
|
var migrations embed.FS
|
||||||
|
|
||||||
// Idempotently initialize the database. Safe to call unconditionally.
|
// Idempotently initialize the database. Safe to call unconditionally.
|
||||||
func InitDatabase(db *sql.DB) error {
|
func InitDatabase(db *DB) error {
|
||||||
rows, err := db.Query(`
|
rows, err := db.Query(`
|
||||||
select exists (
|
select exists (
|
||||||
select 1
|
select 1
|
||||||
@ -39,7 +38,7 @@ func InitDatabase(db *sql.DB) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get a map of migration names to whether the migration has been applied.
|
// Get a map of migration names to whether the migration has been applied.
|
||||||
func GetPendingMigrations(db *sql.DB) (map[string]bool, error) {
|
func GetPendingMigrations(db *DB) (map[string]bool, error) {
|
||||||
allMigrations, err := migrations.ReadDir("sql")
|
allMigrations, err := migrations.ReadDir("sql")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -64,7 +63,7 @@ func GetPendingMigrations(db *sql.DB) (map[string]bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply a migration by name.
|
// Apply a migration by name.
|
||||||
func ApplyMigration(db *sql.DB, name string) error {
|
func ApplyMigration(db *DB, name string) error {
|
||||||
data, err := migrations.ReadFile("sql/" + name)
|
data, err := migrations.ReadFile("sql/" + name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Missing migration %s", name)
|
log.Fatalf("Missing migration %s", name)
|
||||||
@ -79,7 +78,7 @@ func ApplyMigration(db *sql.DB, name string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply all pending migrations.
|
// Apply all pending migrations.
|
||||||
func MigrateDatabase(db *sql.DB) error {
|
func MigrateDatabase(db *DB) error {
|
||||||
pending, err := GetPendingMigrations(db)
|
pending, err := GetPendingMigrations(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -7,11 +7,18 @@ import (
|
|||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func EphemeralDb(t *testing.T) *sql.DB {
|
func EphemeralDb(t *testing.T) *DB {
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
// We don't use OpenDb here because you can't open two connections to the same memory mem
|
||||||
|
mem, err := sql.Open("sqlite3", ":memory:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
if _, err = defaultPragma(mem); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
db := new(DB)
|
||||||
|
db.ro = mem
|
||||||
|
db.rw = mem
|
||||||
if err = InitDatabase(db); err != nil {
|
if err = InitDatabase(db); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -22,11 +29,13 @@ func EphemeralDb(t *testing.T) *sql.DB {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInitIdempotency(t *testing.T) {
|
func TestInitIdempotency(t *testing.T) {
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
mem, err := sql.Open("sqlite3", ":memory:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
db := new(DB)
|
||||||
|
db.ro = mem
|
||||||
|
db.rw = mem
|
||||||
if err = InitDatabase(db); err != nil {
|
if err = InitDatabase(db); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -37,7 +46,6 @@ func TestInitIdempotency(t *testing.T) {
|
|||||||
|
|
||||||
func TestMigrations(t *testing.T) {
|
func TestMigrations(t *testing.T) {
|
||||||
db := EphemeralDb(t)
|
db := EphemeralDb(t)
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
allMigrations, err := migrations.ReadDir("sql")
|
allMigrations, err := migrations.ReadDir("sql")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func AddSource(db *sql.DB, name string) error {
|
func AddSource(db *DB, name string) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
insert into sources (name)
|
insert into sources (name)
|
||||||
values (?)
|
values (?)
|
||||||
@ -17,7 +17,7 @@ func AddSource(db *sql.DB, name string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSource(db *sql.DB, name string) error {
|
func DeleteSource(db *DB, name string) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
delete from sources
|
delete from sources
|
||||||
where name = ?
|
where name = ?
|
||||||
@ -27,7 +27,7 @@ func DeleteSource(db *sql.DB, name string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AddItem(
|
func AddItem(
|
||||||
db *sql.DB,
|
db *DB,
|
||||||
source string,
|
source string,
|
||||||
id string,
|
id string,
|
||||||
title string,
|
title string,
|
||||||
@ -45,7 +45,7 @@ func AddItem(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Deactivate an item, returning its previous active state.
|
// Deactivate an item, returning its previous active state.
|
||||||
func DeactivateItem(db *sql.DB, source string, id string) (bool, error) {
|
func DeactivateItem(db *DB, source string, id string) (bool, error) {
|
||||||
row := db.QueryRow(`
|
row := db.QueryRow(`
|
||||||
select active
|
select active
|
||||||
from items
|
from items
|
||||||
@ -68,7 +68,7 @@ func DeactivateItem(db *sql.DB, source string, id string) (bool, error) {
|
|||||||
return active, nil
|
return active, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllActiveItems(db *sql.DB) ([]Item, error) {
|
func GetAllActiveItems(db *DB) ([]Item, error) {
|
||||||
rows, err := db.Query(`
|
rows, err := db.Query(`
|
||||||
select
|
select
|
||||||
source, id, created, active, title, author, body, link, time
|
source, id, created, active, title, author, body, link, time
|
||||||
@ -87,7 +87,7 @@ func GetAllActiveItems(db *sql.DB) ([]Item, error) {
|
|||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetActiveItemsForSource(db *sql.DB, source string) ([]Item, error) {
|
func GetActiveItemsForSource(db *DB, source string) ([]Item, error) {
|
||||||
rows, err := db.Query(`
|
rows, err := db.Query(`
|
||||||
select
|
select
|
||||||
source, id, created, active, title, author, body, link, time
|
source, id, created, active, title, author, body, link, time
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
|
|
||||||
func TestCreateSource(t *testing.T) {
|
func TestCreateSource(t *testing.T) {
|
||||||
db := EphemeralDb(t)
|
db := EphemeralDb(t)
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
if err := AddSource(db, "one"); err != nil {
|
if err := AddSource(db, "one"); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -62,7 +61,6 @@ func AssertItemIs(t *testing.T, item Item, expected string) {
|
|||||||
|
|
||||||
func TestAddItem(t *testing.T) {
|
func TestAddItem(t *testing.T) {
|
||||||
db := EphemeralDb(t)
|
db := EphemeralDb(t)
|
||||||
defer db.Close()
|
|
||||||
if err := AddSource(db, "test"); err != nil {
|
if err := AddSource(db, "test"); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
create table sources(name text) strict;
|
create table sources(
|
||||||
|
name text not null,
|
||||||
|
primary key (name)
|
||||||
|
) strict;
|
||||||
create table items(
|
create table items(
|
||||||
source text not null,
|
source text not null,
|
||||||
id text not null,
|
id text not null,
|
||||||
|
Loading…
Reference in New Issue
Block a user