From 5798190254a69245356a4c0a18bf65d72788e5b2 Mon Sep 17 00:00:00 2001 From: Tim Van Baak Date: Thu, 16 Jan 2025 13:46:30 -0800 Subject: [PATCH] Basic adding and deactivating items --- db.go | 97 ++++++++++++++++++++++++++++++++++++++++++++ db_test.go | 115 ++++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 201 insertions(+), 11 deletions(-) diff --git a/db.go b/db.go index a40cea4..009a2f0 100644 --- a/db.go +++ b/db.go @@ -2,10 +2,23 @@ package main import ( "database/sql" + "fmt" _ "github.com/mattn/go-sqlite3" ) +type Item struct { + source string + id string + created int + active bool + title string + author string + body string + link string + time int +} + func InitDatabase(db *sql.DB) error { db.Exec(` create table migrations (name text) strict; @@ -58,3 +71,87 @@ func MigrateDatabase(db *sql.DB) error { return nil } + +func AddSource(db *sql.DB, name string) error { + _, err := db.Exec(` + insert into sources (name) + values (?) + `, name) + + return err +} + +func DeleteSource(db *sql.DB, name string) error { + _, err := db.Exec(` + delete from sources + where name = ? + `, name) + + return err +} + +func AddItem( + db *sql.DB, + source string, + id string, + title string, + author string, + body string, + link string, + time int, +) error { + _, err := db.Exec(` + insert into items (source, id, active, title, author, body, link, time) + values (?, ?, ?, ?, ?, ?, ?, ?) + `, source, id, true, title, author, body, link, time) + + return err +} + +func DeactivateItem(db *sql.DB, source string, id string) error { + res, err := db.Exec(` + update items + set active = 0 + where source = ? and id = ? + `, source, id) + if err != nil { + return err + } + num, err := res.RowsAffected() + if err != nil { + return err + } + if num == 0 { + return fmt.Errorf("item %s/%s not found", source, id) + } + return nil +} + +func GetActiveItems(db *sql.DB, source string) ([]Item, error) { + rows, err := db.Query(` + select + source, + id, + created, + active, + title, + author, + body, + link, + time + from items + where + source = ? + and active <> 0 + `, source) + if err != nil { + return nil, err + } + var items []Item + for rows.Next() { + var item Item + rows.Scan(&item.source, &item.id, &item.created, &item.active, &item.title, &item.author, &item.body, &item.link, &item.time) + items = append(items, item) + } + return items, nil +} diff --git a/db_test.go b/db_test.go index 8f80e4d..d6bde7b 100644 --- a/db_test.go +++ b/db_test.go @@ -2,31 +2,124 @@ package main import ( "database/sql" + "fmt" + "slices" "testing" _ "github.com/mattn/go-sqlite3" ) -func TestMigrations(t *testing.T) { +func EphemeralDb(t *testing.T) *sql.DB { db, err := sql.Open("sqlite3", ":memory:") if err != nil { t.Fatal(err) } - defer db.Close() + if err = InitDatabase(db); err != nil { + t.Fatal(err) + } + if err = MigrateDatabase(db); err != nil { + t.Fatal(err) + } + return db +} - err = InitDatabase(db) - if err != nil { - t.Fatal(err) - } - err = MigrateDatabase(db) - if err != nil { - t.Fatal(err) - } +func TestMigrations(t *testing.T) { + db := EphemeralDb(t) + defer db.Close() var count int row := db.QueryRow("select count(name) from migrations") - row.Scan(&count) + if err := row.Scan(&count); err != nil { + t.Fatal(err) + } if count != 1 { t.Fatalf("Unexpected migration count: %d", count) } } + +func TestCreateSource(t *testing.T) { + db := EphemeralDb(t) + defer db.Close() + + if err := AddSource(db, "one"); err != nil { + t.Fatal(err) + } + if err := AddSource(db, "two"); err != nil { + t.Fatal(err) + } + if err := AddSource(db, "three"); err != nil { + t.Fatal(err) + } + if err := DeleteSource(db, "two"); err != nil { + t.Fatal(err) + } + + rows, err := db.Query("select name from sources") + if err != nil { + t.Fatal(err) + } + var names []string + expected := []string{"one", "three"} + for rows.Next() { + var name string + rows.Scan(&name) + names = append(names, name) + } + for i := 0; i < len(expected); i += 1 { + if !slices.Contains(names, expected[i]) { + t.Fatalf("missing %s, have: %v", expected[i], names) + } + } +} + +func AssertItemIs(t *testing.T, item Item, expected string) { + actual := fmt.Sprintf( + "%s/%s/%t/%s/%s/%s/%s/%d", + item.source, + item.id, + item.active, + item.title, + item.author, + item.body, + item.link, + item.time, + ) + if actual != expected { + t.Fatalf("expected %s, got %s", expected, actual) + } +} + +func TestAddItem(t *testing.T) { + db := EphemeralDb(t) + defer db.Close() + if err := AddSource(db, "test"); err != nil { + t.Fatal(err) + } + + if err := AddItem(db, "test", "one", "", "", "", "", 0); err != nil { + t.Fatal(err) + } + if err := AddItem(db, "test", "two", "title", "author", "body", "link", 123456); err != nil { + t.Fatal(err) + } + items, err := GetActiveItems(db, "test") + if err != nil { + t.Fatal(err) + } + if len(items) != 2 { + t.Fatal("should get two items") + } + AssertItemIs(t, items[0], "test/one/true/////0") + AssertItemIs(t, items[1], "test/two/true/title/author/body/link/123456") + + if err = DeactivateItem(db, "test", "one"); err != nil { + t.Fatal(err) + } + items, err = GetActiveItems(db, "test") + if err != nil { + t.Fatal(err) + } + if len(items) != 1 { + t.Fatal("should get one item") + } +}