Refactor migrations to their own files

This commit is contained in:
Tim Van Baak 2025-01-16 21:11:07 -08:00
parent 0c1b978264
commit 96ab254812
4 changed files with 117 additions and 38 deletions

View File

@ -2,7 +2,9 @@ package core
import (
"database/sql"
"embed"
"fmt"
"log"
_ "github.com/mattn/go-sqlite3"
)
@ -19,56 +21,95 @@ type Item struct {
time int
}
//go:embed sql/*.sql
var migrations embed.FS
// Idempotently initialize the database. Safe to call unconditionally.
func InitDatabase(db *sql.DB) error {
db.Exec(`
create table migrations (name text) strict;
`)
return nil
}
func MigrateDatabase(db *sql.DB) error {
rows, err := db.Query(`
select name from migrations;
select exists (
select 1
from sqlite_master
where type = 'table'
and name = 'migrations'
)
`)
if err != nil {
return err
}
complete := map[string]bool{}
var exists bool
for rows.Next() {
rows.Scan(&exists)
}
if exists {
return nil
}
err = ApplyMigration(db, "0000_baseline.sql")
return err
}
// Get the names of existing migrations that haven't been applied yet.
func GetPendingMigrations(db *sql.DB) ([]string, error) {
allMigrations, err := migrations.ReadDir("sql")
if err != nil {
return nil, err
}
complete := map[string]bool{}
for _, mig := range allMigrations {
complete[mig.Name()] = false
}
rows, err := db.Query("select name from migrations")
if err != nil {
return nil, err
}
for rows.Next() {
var name string
err = rows.Scan(&name)
if err != nil {
return err
}
rows.Scan(&name)
complete[name] = true
}
if !complete["0000_initial_schema"] {
_, err = db.Exec(`
create table sources(name text) strict;
create table items(
source text not null,
id text not null,
created int not null default (unixepoch()),
active int,
title text,
author text,
body text,
link text,
time int,
primary key (source, id),
foreign key (source) references sources (name) on delete cascade
) strict;
insert into migrations (name) values ('0000_initial_schema');
`)
var pending []string
for name, isComplete := range complete {
if !isComplete {
pending = append(pending, name)
}
}
return pending, nil
}
// Apply a migration by name.
func ApplyMigration(db *sql.DB, name string) error {
data, err := migrations.ReadFile("sql/" + name)
if err != nil {
log.Fatalf("Missing migration %s", name)
}
log.Printf("Applying migration %s", name)
_, err = db.Exec(string(data))
if err != nil {
return err
}
_, err = db.Exec("insert into migrations (name) values (?)", name)
return err
}
// Apply all pending migrations.
func MigrateDatabase(db *sql.DB) error {
pending, err := GetPendingMigrations(db)
if err != nil {
return err
}
for _, name := range pending {
err = ApplyMigration(db, name)
if err != nil {
return err
}
}
return nil
}

View File

@ -23,17 +23,40 @@ func EphemeralDb(t *testing.T) *sql.DB {
return db
}
func TestInitIdempotency(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
if err = InitDatabase(db); err != nil {
t.Fatal(err)
}
if err = InitDatabase(db); err != nil {
t.Fatal(err)
}
}
func TestMigrations(t *testing.T) {
db := EphemeralDb(t)
defer db.Close()
var count int
row := db.QueryRow("select count(name) from migrations")
if err := row.Scan(&count); err != nil {
allMigrations, err := migrations.ReadDir("sql")
if err != nil {
t.Fatal(err)
}
if count != 1 {
t.Fatalf("Unexpected migration count: %d", count)
rows, err := db.Query("select name from migrations")
if err != nil {
t.Fatal(err)
}
count := 0
for rows.Next() {
count += 1
}
if count != len(allMigrations) {
t.Fatalf("Expected %d migrations, got %d", len(allMigrations), count)
}
}

View File

@ -0,0 +1 @@
create table migrations (name text) strict;

View File

@ -0,0 +1,14 @@
create table sources(name text) strict;
create table items(
source text not null,
id text not null,
created int not null default (unixepoch()),
active int not null,
title text,
author text,
body text,
link text,
time int,
primary key (source, id),
foreign key (source) references sources (name) on delete cascade
) strict;