Refactor migrations to their own files
This commit is contained in:
parent
0c1b978264
commit
96ab254812
107
core/db.go
107
core/db.go
@ -2,7 +2,9 @@ package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
@ -19,56 +21,95 @@ type Item struct {
|
||||
time int
|
||||
}
|
||||
|
||||
//go:embed sql/*.sql
|
||||
var migrations embed.FS
|
||||
|
||||
// Idempotently initialize the database. Safe to call unconditionally.
|
||||
func InitDatabase(db *sql.DB) error {
|
||||
db.Exec(`
|
||||
create table migrations (name text) strict;
|
||||
`)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateDatabase(db *sql.DB) error {
|
||||
rows, err := db.Query(`
|
||||
select name from migrations;
|
||||
select exists (
|
||||
select 1
|
||||
from sqlite_master
|
||||
where type = 'table'
|
||||
and name = 'migrations'
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
complete := map[string]bool{}
|
||||
var exists bool
|
||||
for rows.Next() {
|
||||
rows.Scan(&exists)
|
||||
}
|
||||
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = ApplyMigration(db, "0000_baseline.sql")
|
||||
return err
|
||||
}
|
||||
|
||||
// Get the names of existing migrations that haven't been applied yet.
|
||||
func GetPendingMigrations(db *sql.DB) ([]string, error) {
|
||||
allMigrations, err := migrations.ReadDir("sql")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
complete := map[string]bool{}
|
||||
for _, mig := range allMigrations {
|
||||
complete[mig.Name()] = false
|
||||
}
|
||||
|
||||
rows, err := db.Query("select name from migrations")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var name string
|
||||
err = rows.Scan(&name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows.Scan(&name)
|
||||
complete[name] = true
|
||||
}
|
||||
|
||||
if !complete["0000_initial_schema"] {
|
||||
_, err = db.Exec(`
|
||||
create table sources(name text) strict;
|
||||
create table items(
|
||||
source text not null,
|
||||
id text not null,
|
||||
created int not null default (unixepoch()),
|
||||
active int,
|
||||
title text,
|
||||
author text,
|
||||
body text,
|
||||
link text,
|
||||
time int,
|
||||
primary key (source, id),
|
||||
foreign key (source) references sources (name) on delete cascade
|
||||
) strict;
|
||||
insert into migrations (name) values ('0000_initial_schema');
|
||||
`)
|
||||
var pending []string
|
||||
for name, isComplete := range complete {
|
||||
if !isComplete {
|
||||
pending = append(pending, name)
|
||||
}
|
||||
}
|
||||
|
||||
return pending, nil
|
||||
}
|
||||
|
||||
// Apply a migration by name.
|
||||
func ApplyMigration(db *sql.DB, name string) error {
|
||||
data, err := migrations.ReadFile("sql/" + name)
|
||||
if err != nil {
|
||||
log.Fatalf("Missing migration %s", name)
|
||||
}
|
||||
log.Printf("Applying migration %s", name)
|
||||
_, err = db.Exec(string(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = db.Exec("insert into migrations (name) values (?)", name)
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply all pending migrations.
|
||||
func MigrateDatabase(db *sql.DB) error {
|
||||
pending, err := GetPendingMigrations(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range pending {
|
||||
err = ApplyMigration(db, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -23,17 +23,40 @@ func EphemeralDb(t *testing.T) *sql.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
func TestInitIdempotency(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
if err = InitDatabase(db); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = InitDatabase(db); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrations(t *testing.T) {
|
||||
db := EphemeralDb(t)
|
||||
defer db.Close()
|
||||
|
||||
var count int
|
||||
row := db.QueryRow("select count(name) from migrations")
|
||||
if err := row.Scan(&count); err != nil {
|
||||
allMigrations, err := migrations.ReadDir("sql")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("Unexpected migration count: %d", count)
|
||||
|
||||
rows, err := db.Query("select name from migrations")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
count += 1
|
||||
}
|
||||
|
||||
if count != len(allMigrations) {
|
||||
t.Fatalf("Expected %d migrations, got %d", len(allMigrations), count)
|
||||
}
|
||||
}
|
||||
|
||||
|
1
core/sql/0000_baseline.sql
Normal file
1
core/sql/0000_baseline.sql
Normal file
@ -0,0 +1 @@
|
||||
create table migrations (name text) strict;
|
14
core/sql/0001_initial_schema.sql
Normal file
14
core/sql/0001_initial_schema.sql
Normal file
@ -0,0 +1,14 @@
|
||||
create table sources(name text) strict;
|
||||
create table items(
|
||||
source text not null,
|
||||
id text not null,
|
||||
created int not null default (unixepoch()),
|
||||
active int not null,
|
||||
title text,
|
||||
author text,
|
||||
body text,
|
||||
link text,
|
||||
time int,
|
||||
primary key (source, id),
|
||||
foreign key (source) references sources (name) on delete cascade
|
||||
) strict;
|
Loading…
Reference in New Issue
Block a user