Refactor migrations to their own files

This commit is contained in:
Tim Van Baak 2025-01-16 21:11:07 -08:00
parent 0c1b978264
commit 96ab254812
4 changed files with 117 additions and 38 deletions

View File

@ -2,7 +2,9 @@ package core
import ( import (
"database/sql" "database/sql"
"embed"
"fmt" "fmt"
"log"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -19,56 +21,95 @@ type Item struct {
time int time int
} }
func InitDatabase(db *sql.DB) error { //go:embed sql/*.sql
db.Exec(` var migrations embed.FS
create table migrations (name text) strict;
`)
// Idempotently initialize the database. Safe to call unconditionally.
func InitDatabase(db *sql.DB) error {
rows, err := db.Query(`
select exists (
select 1
from sqlite_master
where type = 'table'
and name = 'migrations'
)
`)
if err != nil {
return err
}
var exists bool
for rows.Next() {
rows.Scan(&exists)
}
if exists {
return nil return nil
} }
func MigrateDatabase(db *sql.DB) error { err = ApplyMigration(db, "0000_baseline.sql")
rows, err := db.Query(`
select name from migrations;
`)
if err != nil {
return err return err
} }
// Get the names of existing migrations that haven't been applied yet.
func GetPendingMigrations(db *sql.DB) ([]string, error) {
allMigrations, err := migrations.ReadDir("sql")
if err != nil {
return nil, err
}
complete := map[string]bool{} complete := map[string]bool{}
for _, mig := range allMigrations {
complete[mig.Name()] = false
}
rows, err := db.Query("select name from migrations")
if err != nil {
return nil, err
}
for rows.Next() { for rows.Next() {
var name string var name string
err = rows.Scan(&name) rows.Scan(&name)
if err != nil {
return err
}
complete[name] = true complete[name] = true
} }
if !complete["0000_initial_schema"] { var pending []string
_, err = db.Exec(` for name, isComplete := range complete {
create table sources(name text) strict; if !isComplete {
create table items( pending = append(pending, name)
source text not null, }
id text not null, }
created int not null default (unixepoch()),
active int, return pending, nil
title text, }
author text,
body text, // Apply a migration by name.
link text, func ApplyMigration(db *sql.DB, name string) error {
time int, data, err := migrations.ReadFile("sql/" + name)
primary key (source, id), if err != nil {
foreign key (source) references sources (name) on delete cascade log.Fatalf("Missing migration %s", name)
) strict; }
insert into migrations (name) values ('0000_initial_schema'); log.Printf("Applying migration %s", name)
`) _, err = db.Exec(string(data))
if err != nil {
return err
}
_, err = db.Exec("insert into migrations (name) values (?)", name)
return err
}
// Apply all pending migrations.
func MigrateDatabase(db *sql.DB) error {
pending, err := GetPendingMigrations(db)
if err != nil {
return err
}
for _, name := range pending {
err = ApplyMigration(db, name)
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
} }

View File

@ -23,17 +23,40 @@ func EphemeralDb(t *testing.T) *sql.DB {
return db return db
} }
func TestInitIdempotency(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
if err = InitDatabase(db); err != nil {
t.Fatal(err)
}
if err = InitDatabase(db); err != nil {
t.Fatal(err)
}
}
func TestMigrations(t *testing.T) { func TestMigrations(t *testing.T) {
db := EphemeralDb(t) db := EphemeralDb(t)
defer db.Close() defer db.Close()
var count int allMigrations, err := migrations.ReadDir("sql")
row := db.QueryRow("select count(name) from migrations") if err != nil {
if err := row.Scan(&count); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if count != 1 {
t.Fatalf("Unexpected migration count: %d", count) rows, err := db.Query("select name from migrations")
if err != nil {
t.Fatal(err)
}
count := 0
for rows.Next() {
count += 1
}
if count != len(allMigrations) {
t.Fatalf("Expected %d migrations, got %d", len(allMigrations), count)
} }
} }

View File

@ -0,0 +1 @@
create table migrations (name text) strict;

View File

@ -0,0 +1,14 @@
create table sources(name text) strict;
create table items(
source text not null,
id text not null,
created int not null default (unixepoch()),
active int not null,
title text,
author text,
body text,
link text,
time int,
primary key (source, id),
foreign key (source) references sources (name) on delete cascade
) strict;