package core

import (
	"database/sql"
	"testing"

	_ "github.com/mattn/go-sqlite3"
)

// EphemeralDb returns a DB backed by a fresh SQLite in-memory database that has
// already gone through defaultPragma, InitDatabase and MigrateDatabase. It is
// meant for tests: any setup failure aborts the calling test via t.Fatal, and
// the underlying *sql.DB is closed automatically when the test finishes.
func EphemeralDb(t *testing.T) DB {
	t.Helper()

	// We don't use OpenDb here: every connection to ":memory:" gets its own
	// private database, so separate read-only and read-write handles would not
	// see the same data. Instead a single *sql.DB backs both roles.
	// NOTE(review): database/sql may still open an extra connection from this
	// pool under concurrent use, and that connection would see an empty
	// database; consider mem.SetMaxOpenConns(1) if that is ever observed.
	mem, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(func() {
		if err := mem.Close(); err != nil {
			t.Error(err)
		}
	})

	if _, err = defaultPragma(mem); err != nil {
		t.Fatal(err)
	}

	db := new(RoRwDb)
	db.ro = mem
	db.rw = mem

	if err = InitDatabase(db); err != nil {
		t.Fatal(err)
	}
	if err = MigrateDatabase(db); err != nil {
		t.Fatal(err)
	}
	return db
}

// TestInitIdempotency checks that InitDatabase can be run twice against the
// same (un-migrated, pragma-free) in-memory database without returning an error.
func TestInitIdempotency(t *testing.T) {
	mem, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer mem.Close()

	// As in EphemeralDb, one handle serves both roles so that both see the
	// same in-memory database.
	db := new(RoRwDb)
	db.ro = mem
	db.rw = mem

	for call := 1; call <= 2; call++ {
		if err = InitDatabase(db); err != nil {
			t.Fatalf("InitDatabase call %d: %v", call, err)
		}
	}
}

// TestMigrations checks that, after EphemeralDb has initialized and migrated a
// database, the migrations table holds exactly one row per entry that
// migrations.ReadDir("sql") reports.
func TestMigrations(t *testing.T) {
	db := EphemeralDb(t)

	allMigrations, err := migrations.ReadDir("sql")
	if err != nil {
		t.Fatal(err)
	}

	rows, err := db.Query("select name from migrations")
	if err != nil {
		t.Fatal(err)
	}
	defer rows.Close()

	count := 0
	for rows.Next() {
		count++
	}
	// Next also returns false on an iteration error; without this check such a
	// failure would only surface as a misleading count mismatch below.
	if err = rows.Err(); err != nil {
		t.Fatal(err)
	}

	if count != len(allMigrations) {
		t.Fatalf("Expected %d migrations, got %d", len(allMigrations), count)
	}
}