Remove migrations code
This commit is contained in:
parent
cd00c0fedc
commit
bd488d7b47
@ -2,101 +2,12 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"embed"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed sql/*.sql
|
|
||||||
var migrations embed.FS
|
|
||||||
|
|
||||||
// Idempotently initialize the database. Safe to call unconditionally.
|
|
||||||
func InitDatabase(db *sql.DB) error {
|
|
||||||
rows, err := db.Query(`
|
|
||||||
select exists (
|
|
||||||
select 1
|
|
||||||
from sqlite_master
|
|
||||||
where type = 'table'
|
|
||||||
and name = 'migrations'
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var exists bool
|
|
||||||
for rows.Next() {
|
|
||||||
rows.Scan(&exists)
|
|
||||||
}
|
|
||||||
|
|
||||||
if exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = ApplyMigration(db, "0000_baseline.sql")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get a map of migration names to whether the migration has been applied.
|
|
||||||
func GetPendingMigrations(db *sql.DB) (map[string]bool, error) {
|
|
||||||
allMigrations, err := migrations.ReadDir("sql")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
complete := map[string]bool{}
|
|
||||||
for _, mig := range allMigrations {
|
|
||||||
complete[mig.Name()] = false
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := db.Query("select name from migrations")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for rows.Next() {
|
|
||||||
var name string
|
|
||||||
rows.Scan(&name)
|
|
||||||
complete[name] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return complete, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply a migration by name.
|
|
||||||
func ApplyMigration(db *sql.DB, name string) error {
|
|
||||||
data, err := migrations.ReadFile("sql/" + name)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Missing migration %s", name)
|
|
||||||
}
|
|
||||||
log.Printf("Applying migration %s", name)
|
|
||||||
_, err = db.Exec(string(data))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = db.Exec("insert into migrations (name) values (?)", name)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply all pending migrations.
|
|
||||||
func MigrateDatabase(db *sql.DB) error {
|
|
||||||
pending, err := GetPendingMigrations(db)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for name, complete := range pending {
|
|
||||||
if !complete {
|
|
||||||
err = ApplyMigration(db, name)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func AddSource(db *sql.DB, name string) error {
|
func AddSource(db *sql.DB, name string) error {
|
||||||
_, err := db.Exec(`
|
_, err := db.Exec(`
|
||||||
insert into sources (name)
|
insert into sources (name)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
@ -9,57 +8,6 @@ import (
|
|||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func EphemeralDb(t *testing.T) *sql.DB {
|
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err = InitDatabase(db); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err = MigrateDatabase(db); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInitIdempotency(t *testing.T) {
|
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
if err = InitDatabase(db); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err = InitDatabase(db); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrations(t *testing.T) {
|
|
||||||
db := EphemeralDb(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
allMigrations, err := migrations.ReadDir("sql")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := db.Query("select name from migrations")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
count := 0
|
|
||||||
for rows.Next() {
|
|
||||||
count += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if count != len(allMigrations) {
|
|
||||||
t.Fatalf("Expected %d migrations, got %d", len(allMigrations), count)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateSource(t *testing.T) {
|
func TestCreateSource(t *testing.T) {
|
||||||
db := EphemeralDb(t)
|
db := EphemeralDb(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
Loading…
Reference in New Issue
Block a user