diff --git a/core/db.go b/core/db.go index e791d36..57490ba 100644 --- a/core/db.go +++ b/core/db.go @@ -2,7 +2,9 @@ package core import ( "database/sql" + "embed" "fmt" + "log" _ "github.com/mattn/go-sqlite3" ) @@ -19,56 +21,95 @@ type Item struct { time int } +//go:embed sql/*.sql +var migrations embed.FS + +// Idempotently initialize the database. Safe to call unconditionally. func InitDatabase(db *sql.DB) error { - db.Exec(` - create table migrations (name text) strict; - `) - - return nil -} - -func MigrateDatabase(db *sql.DB) error { rows, err := db.Query(` - select name from migrations; + select exists ( + select 1 + from sqlite_master + where type = 'table' + and name = 'migrations' + ) `) if err != nil { return err } - complete := map[string]bool{} + var exists bool + for rows.Next() { + rows.Scan(&exists) + } + if exists { + return nil + } + + err = ApplyMigration(db, "0000_baseline.sql") + return err +} + +// Get the names of existing migrations that haven't been applied yet. +func GetPendingMigrations(db *sql.DB) ([]string, error) { + allMigrations, err := migrations.ReadDir("sql") + if err != nil { + return nil, err + } + + complete := map[string]bool{} + for _, mig := range allMigrations { + complete[mig.Name()] = false + } + + rows, err := db.Query("select name from migrations") + if err != nil { + return nil, err + } for rows.Next() { var name string - err = rows.Scan(&name) - if err != nil { - return err - } + rows.Scan(&name) complete[name] = true } - if !complete["0000_initial_schema"] { - _, err = db.Exec(` - create table sources(name text) strict; - create table items( - source text not null, - id text not null, - created int not null default (unixepoch()), - active int, - title text, - author text, - body text, - link text, - time int, - primary key (source, id), - foreign key (source) references sources (name) on delete cascade - ) strict; - insert into migrations (name) values ('0000_initial_schema'); - `) + var pending []string + for name, isComplete := range complete { + if !isComplete { + pending = append(pending, name) + } + } + + return pending, nil +} + +// Apply a migration by name. +func ApplyMigration(db *sql.DB, name string) error { + data, err := migrations.ReadFile("sql/" + name) + if err != nil { + log.Fatalf("Missing migration %s", name) + } + log.Printf("Applying migration %s", name) + _, err = db.Exec(string(data)) + if err != nil { + return err + } + _, err = db.Exec("insert into migrations (name) values (?)", name) + return err +} + +// Apply all pending migrations. +func MigrateDatabase(db *sql.DB) error { + pending, err := GetPendingMigrations(db) + if err != nil { + return err + } + for _, name := range pending { + err = ApplyMigration(db, name) if err != nil { return err } } - return nil } diff --git a/core/db_test.go b/core/db_test.go index 9ded57b..a87c666 100644 --- a/core/db_test.go +++ b/core/db_test.go @@ -23,17 +23,40 @@ func EphemeralDb(t *testing.T) *sql.DB { return db } +func TestInitIdempotency(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + if err = InitDatabase(db); err != nil { + t.Fatal(err) + } + if err = InitDatabase(db); err != nil { + t.Fatal(err) + } +} + func TestMigrations(t *testing.T) { db := EphemeralDb(t) defer db.Close() - var count int - row := db.QueryRow("select count(name) from migrations") - if err := row.Scan(&count); err != nil { + allMigrations, err := migrations.ReadDir("sql") + if err != nil { t.Fatal(err) } - if count != 1 { - t.Fatalf("Unexpected migration count: %d", count) + + rows, err := db.Query("select name from migrations") + if err != nil { + t.Fatal(err) + } + count := 0 + for rows.Next() { + count += 1 + } + + if count != len(allMigrations) { + t.Fatalf("Expected %d migrations, got %d", len(allMigrations), count) } } diff --git a/core/sql/0000_baseline.sql b/core/sql/0000_baseline.sql new file mode 100644 index 0000000..13e067f --- /dev/null +++ b/core/sql/0000_baseline.sql @@ -0,0 +1 @@ +create table migrations (name text) strict; diff --git a/core/sql/0001_initial_schema.sql b/core/sql/0001_initial_schema.sql new file mode 100644 index 0000000..17876f1 --- /dev/null +++ b/core/sql/0001_initial_schema.sql @@ -0,0 +1,14 @@ +create table sources(name text) strict; +create table items( + source text not null, + id text not null, + created int not null default (unixepoch()), + active int not null, + title text, + author text, + body text, + link text, + time int, + primary key (source, id), + foreign key (source) references sources (name) on delete cascade +) strict;