From 44d3444bad0d8d2a4297d9828b76540d92de54dc Mon Sep 17 00:00:00 2001 From: Dhruv Thakur <13575379+dhth@users.noreply.github.com> Date: Tue, 13 Aug 2024 08:52:35 +0200 Subject: [PATCH 1/6] test: add groundwork for database tests --- cmd/db.go | 34 ---- cmd/guide.go | 2 +- cmd/import.go | 8 +- cmd/root.go | 12 +- cmd/tasks.go | 2 +- internal/persistence/init.go | 39 ++++ .../persistence/migrations.go | 20 +- .../persistence/migrations_test.go | 2 +- internal/persistence/queries.go | 72 +++++-- internal/persistence/queries_test.go | 180 ++++++++++++++++++ internal/types/types.go | 1 - internal/ui/cmds.go | 18 +- 12 files changed, 310 insertions(+), 80 deletions(-) create mode 100644 internal/persistence/init.go rename cmd/db_migrations.go => internal/persistence/migrations.go (78%) rename cmd/db_migrations_test.go => internal/persistence/migrations_test.go (93%) create mode 100644 internal/persistence/queries_test.go diff --git a/cmd/db.go b/cmd/db.go index ac102df..63d0770 100644 --- a/cmd/db.go +++ b/cmd/db.go @@ -2,7 +2,6 @@ package cmd import ( "database/sql" - "time" ) func getDB(dbpath string) (*sql.DB, error) { @@ -14,36 +13,3 @@ func getDB(dbpath string) (*sql.DB, error) { db.SetMaxIdleConns(1) return db, err } - -func initDB(db *sql.DB) error { - // these init queries cannot be changed once omm is released; only further - // migrations can be added, which are run when omm sees a difference between - // the values in the db_versions table and latestDBVersion - _, err := db.Exec(` -CREATE TABLE IF NOT EXISTS db_versions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - version INTEGER NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -); - -CREATE TABLE IF NOT EXISTS task ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - summary TEXT NOT NULL, - active BOOLEAN NOT NULL DEFAULT true, - created_at TIMESTAMP NOT NULL, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -); - -CREATE TABLE task_sequence ( - id INTEGER PRIMARY KEY, - sequence JSON NOT NULL -); - -INSERT INTO task_sequence (id, sequence) VALUES (1, '[]'); - -INSERT INTO db_versions (version, created_at) -VALUES (1, ?); -`, time.Now().UTC()) - - return err -} diff --git a/cmd/guide.go b/cmd/guide.go index 77dbe9c..36ab72d 100644 --- a/cmd/guide.go +++ b/cmd/guide.go @@ -90,7 +90,7 @@ func insertGuideTasks(db *sql.DB) error { } } - err = pers.InsertTasksIntoDB(db, tasks) + err = pers.InsertTasks(db, tasks) return err } diff --git a/cmd/import.go b/cmd/import.go index fa9874d..461312c 100644 --- a/cmd/import.go +++ b/cmd/import.go @@ -12,7 +12,7 @@ import ( var errWillExceedCapacity = errors.New("import will exceed capacity") func importTask(db *sql.DB, taskSummary string) error { - numTasks, err := pers.FetchNumActiveTasksFromDB(db) + numTasks, err := pers.FetchNumActiveTasksShown(db) if err != nil { return err } @@ -21,11 +21,11 @@ func importTask(db *sql.DB, taskSummary string) error { } now := time.Now() - return pers.ImportTaskIntoDB(db, taskSummary, true, now, now) + return pers.ImportTask(db, taskSummary, true, now, now) } func importTasks(db *sql.DB, taskSummaries []string) error { - numTasks, err := pers.FetchNumActiveTasksFromDB(db) + numTasks, err := pers.FetchNumActiveTasksShown(db) if err != nil { return err } @@ -34,5 +34,5 @@ func importTasks(db *sql.DB, taskSummaries []string) error { } now := time.Now() - return pers.ImportTaskSummariesIntoDB(db, taskSummaries, true, now, now) + return pers.ImportTaskSummaries(db, taskSummaries, true, now, now) } diff --git a/cmd/root.go b/cmd/root.go index bddf99e..d903646 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -96,11 +96,11 @@ func setupDB(dbPathFull string) (*sql.DB, error) { return nil, fmt.Errorf("%w: %s", errCouldntCreateDB, err.Error()) } - err = initDB(db) + err = pers.InitDB(db) if err != nil { return nil, fmt.Errorf("%w: %s", errCouldntInitializeDB, err.Error()) } - err = upgradeDB(db, 1) + err = pers.UpgradeDB(db, 1) if err != nil { return nil, err } @@ -109,7 +109,7 @@ func setupDB(dbPathFull string) (*sql.DB, error) { if err != nil { return nil, fmt.Errorf("%w: %s", errCouldntOpenDB, err.Error()) } - err = upgradeDBIfNeeded(db) + err = pers.UpgradeDBIfNeeded(db) if err != nil { return nil, err } @@ -213,19 +213,19 @@ Clean up error: %s %s `, reportIssueMsg) - case errors.Is(err, errCouldntFetchDBVersion): + case errors.Is(err, pers.ErrCouldntFetchDBVersion): fmt.Fprintf(os.Stderr, `Couldn't get omm's latest database version. This is a fatal error. %s `, reportIssueMsg) - case errors.Is(err, errDBDowngraded): + case errors.Is(err, pers.ErrDBDowngraded): fmt.Fprintf(os.Stderr, `Looks like you downgraded omm. You should either delete omm's database file (you will lose data by doing that), or upgrade omm to the latest version. %s `, reportIssueMsg) - case errors.Is(err, errDBMigrationFailed): + case errors.Is(err, pers.ErrDBMigrationFailed): fmt.Fprintf(os.Stderr, `Something went wrong migrating omm's database. This is not supposed to happen. You can try running omm by passing it a custom database file path (using --db-path; this will create a new database) to see if that fixes things. If that diff --git a/cmd/tasks.go b/cmd/tasks.go index 575a19d..9de3f01 100644 --- a/cmd/tasks.go +++ b/cmd/tasks.go @@ -9,7 +9,7 @@ import ( ) func printTasks(db *sql.DB, limit uint8, writer io.Writer) error { - tasks, err := pers.FetchActiveTasksFromDB(db, int(limit)) + tasks, err := pers.FetchActiveTasks(db, int(limit)) if err != nil { return err } diff --git a/internal/persistence/init.go b/internal/persistence/init.go new file mode 100644 index 0000000..b40acdc --- /dev/null +++ b/internal/persistence/init.go @@ -0,0 +1,39 @@ +package persistence + +import ( + "database/sql" + "time" +) + +func InitDB(db *sql.DB) error { + // these init queries cannot be changed once omm is released; only further + // migrations can be added, which are run when omm sees a difference between + // the values in the db_versions table and latestDBVersion + _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS db_versions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + version INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS task ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + summary TEXT NOT NULL, + active BOOLEAN NOT NULL DEFAULT true, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE task_sequence ( + id INTEGER PRIMARY KEY, + sequence JSON NOT NULL +); + +INSERT INTO task_sequence (id, sequence) VALUES (1, '[]'); + +INSERT INTO db_versions (version, created_at) +VALUES (1, ?); +`, time.Now().UTC()) + + return err +} diff --git a/cmd/db_migrations.go b/internal/persistence/migrations.go similarity index 78% rename from cmd/db_migrations.go rename to internal/persistence/migrations.go index f59c2e7..62b3a4a 100644 --- a/cmd/db_migrations.go +++ b/internal/persistence/migrations.go @@ -1,4 +1,4 @@ -package cmd +package persistence import ( "database/sql" @@ -12,9 +12,9 @@ const ( ) var ( - errDBDowngraded = errors.New("database downgraded") - errDBMigrationFailed = errors.New("database migration failed") - errCouldntFetchDBVersion = errors.New("couldn't fetch version") + ErrDBDowngraded = errors.New("database downgraded") + ErrDBMigrationFailed = errors.New("database migration failed") + ErrCouldntFetchDBVersion = errors.New("couldn't fetch version") ) type dbVersionInfo struct { @@ -54,18 +54,18 @@ LIMIT 1; return dbVersion, err } -func upgradeDBIfNeeded(db *sql.DB) error { +func UpgradeDBIfNeeded(db *sql.DB) error { latestVersionInDB, err := fetchLatestDBVersion(db) if err != nil { - return fmt.Errorf("%w: %s", errCouldntFetchDBVersion, err.Error()) + return fmt.Errorf("%w: %s", ErrCouldntFetchDBVersion, err.Error()) } if latestVersionInDB.version > latestDBVersion { - return errDBDowngraded + return ErrDBDowngraded } if latestVersionInDB.version < latestDBVersion { - err = upgradeDB(db, latestVersionInDB.version) + err = UpgradeDB(db, latestVersionInDB.version) if err != nil { return err } @@ -74,13 +74,13 @@ func upgradeDBIfNeeded(db *sql.DB) error { return nil } -func upgradeDB(db *sql.DB, currentVersion int) error { +func UpgradeDB(db *sql.DB, currentVersion int) error { migrations := getMigrations() for i := currentVersion + 1; i <= latestDBVersion; i++ { migrateQuery := migrations[i] migrateErr := runMigration(db, migrateQuery, i) if migrateErr != nil { - return fmt.Errorf("%w (version %d): %v", errDBMigrationFailed, i, migrateErr.Error()) + return fmt.Errorf("%w (version %d): %v", ErrDBMigrationFailed, i, migrateErr.Error()) } } return nil diff --git a/cmd/db_migrations_test.go b/internal/persistence/migrations_test.go similarity index 93% rename from cmd/db_migrations_test.go rename to internal/persistence/migrations_test.go index 03c40fe..6a75da2 100644 --- a/cmd/db_migrations_test.go +++ b/internal/persistence/migrations_test.go @@ -1,4 +1,4 @@ -package cmd +package persistence import ( "testing" diff --git a/internal/persistence/queries.go b/internal/persistence/queries.go index 60ea6ea..06fa5e8 100644 --- a/internal/persistence/queries.go +++ b/internal/persistence/queries.go @@ -13,7 +13,53 @@ const ( ContextMaxBytes = 4096 // 4KB seems to be sufficient for context ) -func FetchNumActiveTasksFromDB(db *sql.DB) (int, error) { +func fetchTaskSequence(db *sql.DB) ([]uint64, error) { + var seq []byte + seqRow := db.QueryRow("SELECT sequence from task_sequence where id=1;") + + err := seqRow.Scan(&seq) + if err != nil { + return nil, err + } + + var seqItems []uint64 + err = json.Unmarshal(seq, &seqItems) + if err != nil { + return nil, err + } + return seqItems, nil +} + +func fetchNumActiveTasks(db *sql.DB) (int, error) { + var rowCount int + err := db.QueryRow("SELECT count(*) from task where active is true").Scan(&rowCount) + return rowCount, err +} + +func fetchNumTotalTasks(db *sql.DB) (int, error) { + var rowCount int + err := db.QueryRow("SELECT count(*) from task").Scan(&rowCount) + return rowCount, err +} + +func fetchTaskByID(db *sql.DB, ID int) (types.Task, error) { + var entry types.Task + row := db.QueryRow(` +SELECT id, summary, active, context, created_at, updated_at +from task +WHERE id=?; +`, ID) + err := row.Scan(&entry.ID, + &entry.Summary, + &entry.Active, + &entry.Context, + &entry.CreatedAt, + &entry.UpdatedAt, + ) + return entry, err +} + +func FetchNumActiveTasksShown(db *sql.DB) (int, error) { row := db.QueryRow(` SELECT json_array_length(sequence) AS num_tasks FROM task_sequence where id=1; @@ -28,7 +74,7 @@ FROM task_sequence where id=1; return numTasks, nil } -func UpdateTaskSequenceInDB(db *sql.DB, sequence []uint64) error { +func UpdateTaskSequence(db *sql.DB, sequence []uint64) error { sequenceJSON, err := json.Marshal(sequence) if err != nil { return err @@ -52,7 +98,7 @@ WHERE id = 1; return nil } -func InsertTaskInDB(db *sql.DB, summary string, createdAt, updatedAt time.Time) (uint64, error) { +func InsertTask(db *sql.DB, summary string, createdAt, updatedAt time.Time) (uint64, error) { stmt, err := db.Prepare(` INSERT INTO task (summary, active, created_at, updated_at) VALUES (?, true, ?, ?); @@ -75,7 +121,7 @@ VALUES (?, true, ?, ?); return uint64(li), nil } -func ImportTaskIntoDB(db *sql.DB, summary string, active bool, createdAt, updatedAt time.Time) error { +func ImportTask(db *sql.DB, summary string, active bool, createdAt, updatedAt time.Time) error { tx, err := db.Begin() if err != nil { return err @@ -142,7 +188,7 @@ WHERE id = 1; return nil } -func ImportTaskSummariesIntoDB(db *sql.DB, tasks []string, active bool, createdAt, updatedAt time.Time) error { +func ImportTaskSummaries(db *sql.DB, tasks []string, active bool, createdAt, updatedAt time.Time) error { tx, err := db.Begin() if err != nil { return err @@ -227,7 +273,7 @@ WHERE id = 1; return nil } -func InsertTasksIntoDB(db *sql.DB, tasks []types.Task) error { +func InsertTasks(db *sql.DB, tasks []types.Task) error { tx, err := db.Begin() if err != nil { return err @@ -290,7 +336,7 @@ WHERE id = 1; return nil } -func UpdateTaskSummaryInDB(db *sql.DB, id uint64, summary string, updatedAt time.Time) error { +func UpdateTaskSummary(db *sql.DB, id uint64, summary string, updatedAt time.Time) error { stmt, err := db.Prepare(` UPDATE task SET summary = ?, @@ -309,7 +355,7 @@ WHERE id = ? return nil } -func UpdateTaskContextInDB(db *sql.DB, id uint64, context string, updatedAt time.Time) error { +func UpdateTaskContext(db *sql.DB, id uint64, context string, updatedAt time.Time) error { stmt, err := db.Prepare(` UPDATE task SET context = ?, @@ -328,7 +374,7 @@ WHERE id = ? return nil } -func UnsetTaskContextInDB(db *sql.DB, id uint64, updatedAt time.Time) error { +func UnsetTaskContext(db *sql.DB, id uint64, updatedAt time.Time) error { stmt, err := db.Prepare(` UPDATE task SET context = NULL, @@ -347,7 +393,7 @@ WHERE id = ? return nil } -func ChangeTaskStatusInDB(db *sql.DB, id uint64, active bool, updatedAt time.Time) error { +func ChangeTaskStatus(db *sql.DB, id uint64, active bool, updatedAt time.Time) error { stmt, err := db.Prepare(` UPDATE task SET active = ?, @@ -366,7 +412,7 @@ WHERE id = ? return nil } -func FetchActiveTasksFromDB(db *sql.DB, limit int) ([]types.Task, error) { +func FetchActiveTasks(db *sql.DB, limit int) ([]types.Task, error) { var tasks []types.Task rows, err := db.Query(` @@ -407,7 +453,7 @@ LIMIT ?; return tasks, nil } -func FetchInActiveTasksFromDB(db *sql.DB, limit int) ([]types.Task, error) { +func FetchInActiveTasks(db *sql.DB, limit int) ([]types.Task, error) { var tasks []types.Task rows, err := db.Query(` @@ -446,7 +492,7 @@ LIMIT ?; return tasks, nil } -func DeleteTaskInDB(db *sql.DB, id uint64) error { +func DeleteTask(db *sql.DB, id uint64) error { stmt, err := db.Prepare(` DELETE from task WHERE id=?; diff --git a/internal/persistence/queries_test.go b/internal/persistence/queries_test.go new file mode 100644 index 0000000..f5774e0 --- /dev/null +++ b/internal/persistence/queries_test.go @@ -0,0 +1,180 @@ +package persistence + +import ( + "database/sql" + "encoding/json" + "fmt" + "os" + "testing" + "time" + + "github.com/dhth/omm/internal/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + _ "modernc.org/sqlite" // sqlite driver +) + +var ( + testDB *sql.DB + numSeedActive = 3 + numSeedInActive = 2 +) + +func TestMain(m *testing.M) { + var err error + testDB, err = sql.Open("sqlite", ":memory:") + if err != nil { + panic(err) + } + + err = InitDB(testDB) + if err != nil { + panic(err) + } + err = UpgradeDB(testDB, 1) + if err != nil { + panic(err) + } + code := m.Run() + + testDB.Close() + + os.Exit(code) +} + +func cleanupDB(t *testing.T) { + var err error + for _, tbl := range []string{"task", "task_sequence"} { + _, err = testDB.Exec(fmt.Sprintf("DELETE FROM %s", tbl)) + if err != nil { + t.Fatalf("failed to clean up table %q: %v", tbl, err) + } + } +} + +func seedDB(t *testing.T, db *sql.DB) { + t.Helper() + + tasks := make([]types.Task, numSeedActive+numSeedInActive) + contexts := make([]string, numSeedActive+numSeedInActive) + now := time.Now().UTC() + counter := 0 + for range numSeedActive { + contexts[counter] = fmt.Sprintf("context for task %d", counter) + tasks[counter] = types.Task{ + Summary: fmt.Sprintf("prefix: task %d", counter), + Active: true, + Context: &contexts[counter], + CreatedAt: now, + UpdatedAt: now, + } + counter++ + } + for range numSeedInActive { + contexts[counter] = fmt.Sprintf("context for task %d", counter) + tasks[counter] = types.Task{ + Summary: fmt.Sprintf("prefix: task %d", counter), + Active: false, + Context: &contexts[counter], + CreatedAt: now, + UpdatedAt: now, + } + counter++ + } + for _, task := range tasks { + _, err := db.Exec(` +INSERT INTO task (summary, active, created_at, updated_at) +VALUES (?, ?, ?, ?)`, task.Summary, task.Active, task.CreatedAt, task.UpdatedAt) + if err != nil { + t.Fatalf("failed to insert data into table \"task\": %v", err) + } + } + + seqItems := make([]int, numSeedActive) + for i := range numSeedActive { + seqItems[i] = i + 1 + } + sequenceJSON, err := json.Marshal(seqItems) + if err != nil { + t.Fatalf("failed to marshall JSON data for seeding: %v", err) + } + + _, err = db.Exec(` +UPDATE task_sequence +SET sequence = ? +WHERE id = 1; +`, sequenceJSON) + if err != nil { + t.Fatalf("failed to insert data into table \"task_sequence\": %v", err) + } +} + +func TestImportTask(t *testing.T) { + t.Cleanup(func() { cleanupDB(t) }) + + // GIVEN + seedDB(t, testDB) + numActiveTasksBefore, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + numTotalTasksBefore, err := fetchNumTotalTasks(testDB) + require.NoError(t, err) + + // WHEN + summary := "prefix: an imported task" + now := time.Now().UTC() + err = ImportTask(testDB, summary, true, now, now) + require.NoError(t, err) + + // THEN + numActiveTasksAfter, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numActiveTasksAfter, numActiveTasksBefore+1, "number of active tasks didn't increase by 1") + + task, err := fetchTaskByID(testDB, numTotalTasksBefore+1) + require.NoError(t, err) + assert.True(t, task.Active) + assert.Equal(t, summary, task.Summary) + + seq, err := fetchTaskSequence(testDB) + require.NoError(t, err) + require.Equal(t, numActiveTasksAfter, len(seq), "number of tasks in task sequence doesn't match number of active tasks") + assert.Equal(t, seq[0], task.ID, "newly added task is not shown at the top of the list") +} + +func TestImportTaskSummaries(t *testing.T) { + t.Cleanup(func() { cleanupDB(t) }) + + // GIVEN + seedDB(t, testDB) + numActiveTasksBefore, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + numTotalTasksBefore, err := fetchNumTotalTasks(testDB) + require.NoError(t, err) + + // WHEN + summaries := []string{ + "prefix: imported task 1", + "prefix: imported task 2", + "prefix: imported task 3", + } + now := time.Now().UTC() + err = ImportTaskSummaries(testDB, summaries, true, now, now) + require.NoError(t, err) + + // THEN + numActiveTasksAfter, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numActiveTasksAfter, numActiveTasksBefore+len(summaries), "number of active tasks didn't increase by the correct amount") + + task, err := fetchTaskByID(testDB, numTotalTasksBefore+1) + require.NoError(t, err) + assert.True(t, task.Active) + assert.Equal(t, summaries[0], task.Summary) + + seq, err := fetchTaskSequence(testDB) + require.NoError(t, err) + require.Equal(t, numActiveTasksAfter, len(seq), "number of tasks in task sequence doesn't match number of active tasks") + for i := range summaries { + assert.Equal(t, numTotalTasksBefore+i+1, int(seq[i]), "task at sequence position %d is incorrect", i+1) + } +} diff --git a/internal/types/types.go b/internal/types/types.go index b3a7729..9d575ec 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -39,7 +39,6 @@ var ( ) type Task struct { - ItemTitle string ID uint64 Summary string Context *string diff --git a/internal/ui/cmds.go b/internal/ui/cmds.go index da222fb..f79083e 100644 --- a/internal/ui/cmds.go +++ b/internal/ui/cmds.go @@ -21,21 +21,21 @@ func hideHelp(interval time.Duration) tea.Cmd { func updateTaskSequence(db *sql.DB, sequence []uint64) tea.Cmd { return func() tea.Msg { - err := pers.UpdateTaskSequenceInDB(db, sequence) + err := pers.UpdateTaskSequence(db, sequence) return taskSequenceUpdatedMsg{err} } } func createTask(db *sql.DB, summary string, createdAt, updatedAt time.Time) tea.Cmd { return func() tea.Msg { - id, err := pers.InsertTaskInDB(db, summary, createdAt, updatedAt) + id, err := pers.InsertTask(db, summary, createdAt, updatedAt) return taskCreatedMsg{id, summary, createdAt, updatedAt, err} } } func deleteTask(db *sql.DB, id uint64, index int, active bool) tea.Cmd { return func() tea.Msg { - err := pers.DeleteTaskInDB(db, id) + err := pers.DeleteTask(db, id) return taskDeletedMsg{id, index, active, err} } } @@ -43,7 +43,7 @@ func deleteTask(db *sql.DB, id uint64, index int, active bool) tea.Cmd { func updateTaskSummary(db *sql.DB, listIndex int, id uint64, summary string) tea.Cmd { return func() tea.Msg { now := time.Now() - err := pers.UpdateTaskSummaryInDB(db, id, summary, now) + err := pers.UpdateTaskSummary(db, id, summary, now) return taskSummaryUpdatedMsg{listIndex, id, summary, now, err} } } @@ -53,9 +53,9 @@ func updateTaskContext(db *sql.DB, listIndex int, id uint64, context string, lis var err error now := time.Now() if context == "" { - err = pers.UnsetTaskContextInDB(db, id, now) + err = pers.UnsetTaskContext(db, id, now) } else { - err = pers.UpdateTaskContextInDB(db, id, context, now) + err = pers.UpdateTaskContext(db, id, context, now) } return taskContextUpdatedMsg{listIndex, list, id, context, now, err} } @@ -63,7 +63,7 @@ func updateTaskContext(db *sql.DB, listIndex int, id uint64, context string, lis func changeTaskStatus(db *sql.DB, listIndex int, id uint64, active bool, updatedAt time.Time) tea.Cmd { return func() tea.Msg { - err := pers.ChangeTaskStatusInDB(db, id, active, updatedAt) + err := pers.ChangeTaskStatus(db, id, active, updatedAt) return taskStatusChangedMsg{listIndex, id, active, updatedAt, err} } } @@ -74,9 +74,9 @@ func fetchTasks(db *sql.DB, active bool, limit int) tea.Cmd { var err error switch active { case true: - tasks, err = pers.FetchActiveTasksFromDB(db, limit) + tasks, err = pers.FetchActiveTasks(db, limit) case false: - tasks, err = pers.FetchInActiveTasksFromDB(db, limit) + tasks, err = pers.FetchInActiveTasks(db, limit) } return tasksFetched{tasks, active, err} } From 1f5851eefac05a21c8ef8344f75dc3e6478d0e58 Mon Sep 17 00:00:00 2001 From: Dhruv Thakur <13575379+dhth@users.noreply.github.com> Date: Sat, 24 Aug 2024 15:30:09 +0200 Subject: [PATCH 2/6] fix: clean up for tests --- cmd/import.go | 6 ++-- internal/persistence/queries.go | 52 ++++++++++++---------------- internal/persistence/queries_test.go | 39 +++++++++++++-------- 3 files changed, 52 insertions(+), 45 deletions(-) diff --git a/cmd/import.go b/cmd/import.go index 461312c..8a591ac 100644 --- a/cmd/import.go +++ b/cmd/import.go @@ -21,7 +21,8 @@ func importTask(db *sql.DB, taskSummary string) error { } now := time.Now() - return pers.ImportTask(db, taskSummary, true, now, now) + _, err = pers.ImportTask(db, taskSummary, true, now, now) + return err } func importTasks(db *sql.DB, taskSummaries []string) error { @@ -34,5 +35,6 @@ func importTasks(db *sql.DB, taskSummaries []string) error { } now := time.Now() - return pers.ImportTaskSummaries(db, taskSummaries, true, now, now) + _, err = pers.ImportTaskSummaries(db, taskSummaries, true, now, now) + return err } diff --git a/internal/persistence/queries.go b/internal/persistence/queries.go index 06fa5e8..479e916 100644 --- a/internal/persistence/queries.go +++ b/internal/persistence/queries.go @@ -36,13 +36,7 @@ func fetchNumActiveTasks(db *sql.DB) (int, error) { return rowCount, err } -func fetchNumTotalTasks(db *sql.DB) (int, error) { - var rowCount int - err := db.QueryRow("SELECT count(*) from task").Scan(&rowCount) - return rowCount, err -} - -func fetchTaskByID(db *sql.DB, ID int) (types.Task, error) { +func fetchTaskByID(db *sql.DB, ID int64) (types.Task, error) { var entry types.Task row := db.QueryRow(` SELECT id, summary, active, context, created_at, updated_at @@ -121,10 +115,10 @@ VALUES (?, true, ?, ?); return uint64(li), nil } -func ImportTask(db *sql.DB, summary string, active bool, createdAt, updatedAt time.Time) error { +func ImportTask(db *sql.DB, summary string, active bool, createdAt, updatedAt time.Time) (int64, error) { tx, err := db.Begin() if err != nil { - return err + return -1, err } defer func() { _ = tx.Rollback() @@ -135,12 +129,12 @@ VALUES (?, ?, ?, ?);` res, err := tx.Exec(query, summary, active, createdAt.UTC(), updatedAt.UTC()) if err != nil { - return err + return -1, err } lastInsertID, err := res.LastInsertId() if err != nil { - return err + return -1, err } var seq []byte @@ -148,13 +142,13 @@ VALUES (?, ?, ?, ?);` err = seqRow.Scan(&seq) if err != nil { - return err + return -1, err } var seqItems []int err = json.Unmarshal(seq, &seqItems) if err != nil { - return err + return -1, err } newTaskID := make([]int, 1) @@ -163,7 +157,7 @@ VALUES (?, ?, ?, ?);` sequenceJSON, err := json.Marshal(updatedSeqItems) if err != nil { - return err + return -1, err } seqUpdateStmt, err := tx.Prepare(` @@ -172,26 +166,26 @@ SET sequence = ? WHERE id = 1; `) if err != nil { - return err + return -1, err } defer seqUpdateStmt.Close() _, err = seqUpdateStmt.Exec(sequenceJSON) if err != nil { - return err + return -1, err } err = tx.Commit() if err != nil { - return err + return -1, err } - return nil + return lastInsertID, nil } -func ImportTaskSummaries(db *sql.DB, tasks []string, active bool, createdAt, updatedAt time.Time) error { +func ImportTaskSummaries(db *sql.DB, tasks []string, active bool, createdAt, updatedAt time.Time) (int64, error) { tx, err := db.Begin() if err != nil { - return err + return -1, err } defer func() { _ = tx.Rollback() @@ -216,12 +210,12 @@ VALUES ` res, err := tx.Exec(query, values...) if err != nil { - return err + return -1, err } lastInsertID, err := res.LastInsertId() if err != nil { - return err + return -1, err } var seq []byte @@ -229,13 +223,13 @@ VALUES ` err = seqRow.Scan(&seq) if err != nil { - return err + return -1, err } var seqItems []int err = json.Unmarshal(seq, &seqItems) if err != nil { - return err + return -1, err } newTaskIDs := make([]int, len(tasks)) @@ -248,7 +242,7 @@ VALUES ` sequenceJSON, err := json.Marshal(updatedSeqItems) if err != nil { - return err + return -1, err } seqUpdateStmt, err := tx.Prepare(` @@ -257,20 +251,20 @@ SET sequence = ? WHERE id = 1; `) if err != nil { - return err + return -1, err } defer seqUpdateStmt.Close() _, err = seqUpdateStmt.Exec(sequenceJSON) if err != nil { - return err + return -1, err } err = tx.Commit() if err != nil { - return err + return -1, err } - return nil + return lastInsertID, nil } func InsertTasks(db *sql.DB, tasks []types.Task) error { diff --git a/internal/persistence/queries_test.go b/internal/persistence/queries_test.go index f5774e0..fe49693 100644 --- a/internal/persistence/queries_test.go +++ b/internal/persistence/queries_test.go @@ -44,12 +44,18 @@ func TestMain(m *testing.M) { func cleanupDB(t *testing.T) { var err error - for _, tbl := range []string{"task", "task_sequence"} { + for _, tbl := range []string{"task"} { _, err = testDB.Exec(fmt.Sprintf("DELETE FROM %s", tbl)) if err != nil { t.Fatalf("failed to clean up table %q: %v", tbl, err) } } + _, err = testDB.Exec(`UPDATE task_sequence +SET sequence = '[]' +WHERE id = 1;`) + if err != nil { + t.Fatalf("failed to clean up table task_sequence: %v", err) + } } func seedDB(t *testing.T, db *sql.DB) { @@ -116,13 +122,11 @@ func TestImportTask(t *testing.T) { seedDB(t, testDB) numActiveTasksBefore, err := fetchNumActiveTasks(testDB) require.NoError(t, err) - numTotalTasksBefore, err := fetchNumTotalTasks(testDB) - require.NoError(t, err) // WHEN summary := "prefix: an imported task" now := time.Now().UTC() - err = ImportTask(testDB, summary, true, now, now) + lastID, err := ImportTask(testDB, summary, true, now, now) require.NoError(t, err) // THEN @@ -130,7 +134,7 @@ func TestImportTask(t *testing.T) { require.NoError(t, err) assert.Equal(t, numActiveTasksAfter, numActiveTasksBefore+1, "number of active tasks didn't increase by 1") - task, err := fetchTaskByID(testDB, numTotalTasksBefore+1) + task, err := fetchTaskByID(testDB, lastID) require.NoError(t, err) assert.True(t, task.Active) assert.Equal(t, summary, task.Summary) @@ -148,33 +152,40 @@ func TestImportTaskSummaries(t *testing.T) { seedDB(t, testDB) numActiveTasksBefore, err := fetchNumActiveTasks(testDB) require.NoError(t, err) - numTotalTasksBefore, err := fetchNumTotalTasks(testDB) - require.NoError(t, err) // WHEN - summaries := []string{ + newTaskSummaries := []string{ "prefix: imported task 1", "prefix: imported task 2", "prefix: imported task 3", } now := time.Now().UTC() - err = ImportTaskSummaries(testDB, summaries, true, now, now) + lastID, err := ImportTaskSummaries(testDB, newTaskSummaries, true, now, now) require.NoError(t, err) // THEN numActiveTasksAfter, err := fetchNumActiveTasks(testDB) require.NoError(t, err) - assert.Equal(t, numActiveTasksAfter, numActiveTasksBefore+len(summaries), "number of active tasks didn't increase by the correct amount") + assert.Equal(t, numActiveTasksAfter, numActiveTasksBefore+len(newTaskSummaries), "number of active tasks didn't increase by the correct amount") - task, err := fetchTaskByID(testDB, numTotalTasksBefore+1) + task, err := fetchTaskByID(testDB, lastID) require.NoError(t, err) assert.True(t, task.Active) - assert.Equal(t, summaries[0], task.Summary) + assert.Equal(t, newTaskSummaries[2], task.Summary) seq, err := fetchTaskSequence(testDB) require.NoError(t, err) require.Equal(t, numActiveTasksAfter, len(seq), "number of tasks in task sequence doesn't match number of active tasks") - for i := range summaries { - assert.Equal(t, numTotalTasksBefore+i+1, int(seq[i]), "task at sequence position %d is incorrect", i+1) + + // ensure new task sequence is correct + // that is: + // imported task 1 + // imported task 2 + // imported task 3 + // ... old sequence + currentID := int(lastID) - len(newTaskSummaries) + 1 + for i := range len(newTaskSummaries) { + assert.Equal(t, currentID, int(seq[i]), "task at sequence position %d is incorrect", i+1) + currentID++ } } From 2f9f1badd1f1cf0144f7e2d69f013def18d0f9ab Mon Sep 17 00:00:00 2001 From: Dhruv Thakur <13575379+dhth@users.noreply.github.com> Date: Sat, 24 Aug 2024 17:14:01 +0200 Subject: [PATCH 3/6] refactor: remove redundant func --- cmd/guide.go | 2 +- cmd/import.go | 12 +- internal/persistence/queries.go | 103 +++++------------ internal/persistence/queries_test.go | 166 +++++++++++++++++++-------- 4 files changed, 159 insertions(+), 124 deletions(-) diff --git a/cmd/guide.go b/cmd/guide.go index 36ab72d..052cb6a 100644 --- a/cmd/guide.go +++ b/cmd/guide.go @@ -90,7 +90,7 @@ func insertGuideTasks(db *sql.DB) error { } } - err = pers.InsertTasks(db, tasks) + _, err = pers.InsertTasks(db, tasks, true) return err } diff --git a/cmd/import.go b/cmd/import.go index 8a591ac..63f6416 100644 --- a/cmd/import.go +++ b/cmd/import.go @@ -7,6 +7,7 @@ import ( "time" pers "github.com/dhth/omm/internal/persistence" + "github.com/dhth/omm/internal/types" ) var errWillExceedCapacity = errors.New("import will exceed capacity") @@ -35,6 +36,15 @@ func importTasks(db *sql.DB, taskSummaries []string) error { } now := time.Now() - _, err = pers.ImportTaskSummaries(db, taskSummaries, true, now, now) + tasks := make([]types.Task, len(taskSummaries)) + for i, summ := range taskSummaries { + tasks[i] = types.Task{ + Summary: summ, + Active: true, + CreatedAt: now, + UpdatedAt: now, + } + } + _, err = pers.InsertTasks(db, tasks, true) return err } diff --git a/internal/persistence/queries.go b/internal/persistence/queries.go index 479e916..47347aa 100644 --- a/internal/persistence/queries.go +++ b/internal/persistence/queries.go @@ -36,6 +36,12 @@ func fetchNumActiveTasks(db *sql.DB) (int, error) { return rowCount, err } +func fetchNumTotalTasks(db *sql.DB) (int, error) { + var rowCount int + err := db.QueryRow("SELECT count(*) from task").Scan(&rowCount) + return rowCount, err +} + func fetchTaskByID(db *sql.DB, ID int64) (types.Task, error) { var entry types.Task row := db.QueryRow(` @@ -182,7 +188,7 @@ WHERE id = 1; return lastInsertID, nil } -func ImportTaskSummaries(db *sql.DB, tasks []string, active bool, createdAt, updatedAt time.Time) (int64, error) { +func InsertTasks(db *sql.DB, tasks []types.Task, insertAtTop bool) (int64, error) { tx, err := db.Begin() if err != nil { return -1, err @@ -191,21 +197,19 @@ func ImportTaskSummaries(db *sql.DB, tasks []string, active bool, createdAt, upd _ = tx.Rollback() }() - query := `INSERT INTO task (summary, active, created_at, updated_at) + query := `INSERT INTO task (summary, context, active, created_at, updated_at) VALUES ` values := make([]interface{}, 0, len(tasks)*4) - ca := createdAt.UTC() - ua := updatedAt.UTC() - - for i, ts := range tasks { + for i, t := range tasks { if i > 0 { query += "," } - query += "(?, ?, ?, ?)" - values = append(values, ts, active, ca, ua) + query += "(?, ?, ?, ?, ?)" + values = append(values, t.Summary, t.Context, t.Active, t.CreatedAt.UTC(), t.UpdatedAt.UTC()) } + query += ";" res, err := tx.Exec(query, values...) @@ -232,13 +236,21 @@ VALUES ` return -1, err } - newTaskIDs := make([]int, len(tasks)) - counter := 0 - for i := int(lastInsertID) - len(tasks) + 1; i <= int(lastInsertID); i++ { - newTaskIDs[counter] = i - counter++ + var newTaskIDs []int + taskID := int(lastInsertID) - len(tasks) + 1 + for _, t := range tasks { + if t.Active { + newTaskIDs = append(newTaskIDs, taskID) + } + taskID++ + } + + var updatedSeqItems []int + if insertAtTop { + updatedSeqItems = append(newTaskIDs, seqItems...) + } else { + updatedSeqItems = append(seqItems, newTaskIDs...) } - updatedSeqItems := append(newTaskIDs, seqItems...) sequenceJSON, err := json.Marshal(updatedSeqItems) if err != nil { @@ -267,69 +279,6 @@ WHERE id = 1; return lastInsertID, nil } -func InsertTasks(db *sql.DB, tasks []types.Task) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - _ = tx.Rollback() - }() - - query := `INSERT INTO task (summary, context, active, created_at, updated_at) -VALUES ` - - values := make([]interface{}, 0, len(tasks)*4) - - var seqItems []int - seqCounter := 1 - for i, t := range tasks { - if i > 0 { - query += "," - } - query += "(?, ?, ?, ?, ?)" - values = append(values, t.Summary, t.Context, t.Active, t.CreatedAt.UTC(), t.UpdatedAt.UTC()) - - if t.Active { - seqItems = append(seqItems, seqCounter) - } - seqCounter++ - } - - query += ";" - - _, err = tx.Exec(query, values...) - if err != nil { - return err - } - - sequenceJSON, err := json.Marshal(seqItems) - if err != nil { - return err - } - - seqUpdateStmt, err := tx.Prepare(` -UPDATE task_sequence -SET sequence = ? -WHERE id = 1; -`) - if err != nil { - return err - } - defer seqUpdateStmt.Close() - - _, err = seqUpdateStmt.Exec(sequenceJSON) - if err != nil { - return err - } - - err = tx.Commit() - if err != nil { - return err - } - return nil -} - func UpdateTaskSummary(db *sql.DB, id uint64, summary string, updatedAt time.Time) error { stmt, err := db.Prepare(` UPDATE task diff --git a/internal/persistence/queries_test.go b/internal/persistence/queries_test.go index fe49693..9d0d06a 100644 --- a/internal/persistence/queries_test.go +++ b/internal/persistence/queries_test.go @@ -14,11 +14,7 @@ import ( _ "modernc.org/sqlite" // sqlite driver ) -var ( - testDB *sql.DB - numSeedActive = 3 - numSeedInActive = 2 -) +var testDB *sql.DB func TestMain(m *testing.M) { var err error @@ -49,6 +45,10 @@ func cleanupDB(t *testing.T) { if err != nil { t.Fatalf("failed to clean up table %q: %v", tbl, err) } + _, err := testDB.Exec("DELETE FROM sqlite_sequence WHERE name=?;", tbl) + if err != nil { + t.Fatalf("failed to reset auto increment for table %q: %v", tbl, err) + } } _, err = testDB.Exec(`UPDATE task_sequence SET sequence = '[]' @@ -58,14 +58,15 @@ WHERE id = 1;`) } } -func seedDB(t *testing.T, db *sql.DB) { - t.Helper() +func getSampleTasks() ([]types.Task, int, int) { + numActive := 3 + numInactive := 2 - tasks := make([]types.Task, numSeedActive+numSeedInActive) - contexts := make([]string, numSeedActive+numSeedInActive) + tasks := make([]types.Task, numActive+numInactive) + contexts := make([]string, numActive+numInactive) now := time.Now().UTC() counter := 0 - for range numSeedActive { + for range numActive { contexts[counter] = fmt.Sprintf("context for task %d", counter) tasks[counter] = types.Task{ Summary: fmt.Sprintf("prefix: task %d", counter), @@ -76,7 +77,7 @@ func seedDB(t *testing.T, db *sql.DB) { } counter++ } - for range numSeedInActive { + for range numInactive { contexts[counter] = fmt.Sprintf("context for task %d", counter) tasks[counter] = types.Task{ Summary: fmt.Sprintf("prefix: task %d", counter), @@ -87,6 +88,15 @@ func seedDB(t *testing.T, db *sql.DB) { } counter++ } + + return tasks, numActive, numInactive +} + +func seedDB(t *testing.T, db *sql.DB) (int, int) { + t.Helper() + + tasks, na, ni := getSampleTasks() + for _, task := range tasks { _, err := db.Exec(` INSERT INTO task (summary, active, created_at, updated_at) @@ -96,8 +106,8 @@ VALUES (?, ?, ?, ?)`, task.Summary, task.Active, task.CreatedAt, task.UpdatedAt) } } - seqItems := make([]int, numSeedActive) - for i := range numSeedActive { + seqItems := make([]int, na) + for i := range na { seqItems[i] = i + 1 } sequenceJSON, err := json.Marshal(seqItems) @@ -113,15 +123,15 @@ WHERE id = 1; if err != nil { t.Fatalf("failed to insert data into table \"task_sequence\": %v", err) } + + return na, ni } func TestImportTask(t *testing.T) { t.Cleanup(func() { cleanupDB(t) }) // GIVEN - seedDB(t, testDB) - numActiveTasksBefore, err := fetchNumActiveTasks(testDB) - require.NoError(t, err) + na, _ := seedDB(t, testDB) // WHEN summary := "prefix: an imported task" @@ -132,7 +142,7 @@ func TestImportTask(t *testing.T) { // THEN numActiveTasksAfter, err := fetchNumActiveTasks(testDB) require.NoError(t, err) - assert.Equal(t, numActiveTasksAfter, numActiveTasksBefore+1, "number of active tasks didn't increase by 1") + assert.Equal(t, numActiveTasksAfter, na+1, "number of active tasks didn't increase by 1") task, err := fetchTaskByID(testDB, lastID) require.NoError(t, err) @@ -141,51 +151,117 @@ func TestImportTask(t *testing.T) { seq, err := fetchTaskSequence(testDB) require.NoError(t, err) - require.Equal(t, numActiveTasksAfter, len(seq), "number of tasks in task sequence doesn't match number of active tasks") - assert.Equal(t, seq[0], task.ID, "newly added task is not shown at the top of the list") + assert.Equal(t, seq, []uint64{6, 1, 2, 3}, "task sequence isn't correct") } -func TestImportTaskSummaries(t *testing.T) { +func TestInsertTasksWorksWithEmptyTaskList(t *testing.T) { t.Cleanup(func() { cleanupDB(t) }) // GIVEN - seedDB(t, testDB) - numActiveTasksBefore, err := fetchNumActiveTasks(testDB) + // WHEN + tasks, na, ni := getSampleTasks() + lastID, err := InsertTasks(testDB, tasks, true) + assert.Equal(t, lastID, int64(na+ni), "last ID is not correct") require.NoError(t, err) + // THEN + numActiveRes, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numActiveRes, na, "number of active tasks didn't increase by the correct amount") + + numTotalRes, err := fetchNumTotalTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numTotalRes, na+ni, "number of total tasks didn't increase by the correct amount") + + lastTask, err := fetchTaskByID(testDB, lastID) + require.NoError(t, err) + assert.Equal(t, tasks[len(tasks)-1].Active, lastTask.Active) + assert.Equal(t, tasks[len(tasks)-1].Summary, lastTask.Summary) + assert.Equal(t, tasks[len(tasks)-1].Context, lastTask.Context) + + seq, err := fetchTaskSequence(testDB) + require.NoError(t, err) + assert.Equal(t, seq, []uint64{1, 2, 3}, "task sequence isn't correct") +} + +func TestInsertTasksAddsTasksAtTheTop(t *testing.T) { + t.Cleanup(func() { cleanupDB(t) }) + + // GIVEN + na, ni := seedDB(t, testDB) + // WHEN - newTaskSummaries := []string{ - "prefix: imported task 1", - "prefix: imported task 2", - "prefix: imported task 3", - } now := time.Now().UTC() - lastID, err := ImportTaskSummaries(testDB, newTaskSummaries, true, now, now) + tasks := []types.Task{ + { + Summary: "prefix: new task 1", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new inactive task 1", + Active: false, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new task 3", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + } + + _, err := InsertTasks(testDB, tasks, true) require.NoError(t, err) // THEN - numActiveTasksAfter, err := fetchNumActiveTasks(testDB) + numActiveRes, err := fetchNumActiveTasks(testDB) require.NoError(t, err) - assert.Equal(t, numActiveTasksAfter, numActiveTasksBefore+len(newTaskSummaries), "number of active tasks didn't increase by the correct amount") + assert.Equal(t, numActiveRes, na+2, "number of active tasks didn't increase by the correct amount") - task, err := fetchTaskByID(testDB, lastID) + numTotalRes, err := fetchNumTotalTasks(testDB) require.NoError(t, err) - assert.True(t, task.Active) - assert.Equal(t, newTaskSummaries[2], task.Summary) + assert.Equal(t, numTotalRes, na+ni+3, "number of total tasks didn't increase by the correct amount") seq, err := fetchTaskSequence(testDB) require.NoError(t, err) - require.Equal(t, numActiveTasksAfter, len(seq), "number of tasks in task sequence doesn't match number of active tasks") - - // ensure new task sequence is correct - // that is: - // imported task 1 - // imported task 2 - // imported task 3 - // ... old sequence - currentID := int(lastID) - len(newTaskSummaries) + 1 - for i := range len(newTaskSummaries) { - assert.Equal(t, currentID, int(seq[i]), "task at sequence position %d is incorrect", i+1) - currentID++ + assert.Equal(t, seq, []uint64{6, 8, 1, 2, 3}, "task sequence isn't correct") +} + +func TestInsertTasksAddsTasksAtTheEnd(t *testing.T) { + t.Cleanup(func() { cleanupDB(t) }) + + // GIVEN + na, _ := seedDB(t, testDB) + + // WHEN + now := time.Now().UTC() + tasks := []types.Task{ + { + Summary: "prefix: new task 1", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new task 2", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, } + + _, err := InsertTasks(testDB, tasks, false) + require.NoError(t, err) + + // THEN + numActiveRes, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numActiveRes, na+2, "number of active tasks didn't increase by the correct amount") + + seq, err := fetchTaskSequence(testDB) + require.NoError(t, err) + assert.Equal(t, seq, []uint64{1, 2, 3, 6, 7}, "task sequence isn't correct") } From 7befdf8cb5c26e9baf364825d76135da4fd02ccc Mon Sep 17 00:00:00 2001 From: Dhruv Thakur <13575379+dhth@users.noreply.github.com> Date: Sat, 24 Aug 2024 22:31:25 +0200 Subject: [PATCH 4/6] refactor: remove redundant function --- cmd/import.go | 8 +++- internal/persistence/queries.go | 67 ---------------------------- internal/persistence/queries_test.go | 63 ++++++++++++-------------- 3 files changed, 35 insertions(+), 103 deletions(-) diff --git a/cmd/import.go b/cmd/import.go index 63f6416..375949d 100644 --- a/cmd/import.go +++ b/cmd/import.go @@ -22,7 +22,13 @@ func importTask(db *sql.DB, taskSummary string) error { } now := time.Now() - _, err = pers.ImportTask(db, taskSummary, true, now, now) + task := types.Task{ + Summary: taskSummary, + Active: true, + CreatedAt: now, + UpdatedAt: now, + } + _, err = pers.InsertTasks(db, []types.Task{task}, true) return err } diff --git a/internal/persistence/queries.go b/internal/persistence/queries.go index 47347aa..aaf84b7 100644 --- a/internal/persistence/queries.go +++ b/internal/persistence/queries.go @@ -121,73 +121,6 @@ VALUES (?, true, ?, ?); return uint64(li), nil } -func ImportTask(db *sql.DB, summary string, active bool, createdAt, updatedAt time.Time) (int64, error) { - tx, err := db.Begin() - if err != nil { - return -1, err - } - defer func() { - _ = tx.Rollback() - }() - - query := `INSERT INTO task (summary, active, created_at, updated_at) -VALUES (?, ?, ?, ?);` - - res, err := tx.Exec(query, summary, active, createdAt.UTC(), updatedAt.UTC()) - if err != nil { - return -1, err - } - - lastInsertID, err := res.LastInsertId() - if err != nil { - return -1, err - } - - var seq []byte - seqRow := tx.QueryRow("SELECT sequence from task_sequence where id=1;") - - err = seqRow.Scan(&seq) - if err != nil { - return -1, err - } - - var seqItems []int - err = json.Unmarshal(seq, &seqItems) - if err != nil { - return -1, err - } - - newTaskID := make([]int, 1) - newTaskID[0] = int(lastInsertID) - updatedSeqItems := append(newTaskID, seqItems...) - - sequenceJSON, err := json.Marshal(updatedSeqItems) - if err != nil { - return -1, err - } - - seqUpdateStmt, err := tx.Prepare(` -UPDATE task_sequence -SET sequence = ? -WHERE id = 1; -`) - if err != nil { - return -1, err - } - defer seqUpdateStmt.Close() - - _, err = seqUpdateStmt.Exec(sequenceJSON) - if err != nil { - return -1, err - } - - err = tx.Commit() - if err != nil { - return -1, err - } - return lastInsertID, nil -} - func InsertTasks(db *sql.DB, tasks []types.Task, insertAtTop bool) (int64, error) { tx, err := db.Begin() if err != nil { diff --git a/internal/persistence/queries_test.go b/internal/persistence/queries_test.go index 9d0d06a..5241190 100644 --- a/internal/persistence/queries_test.go +++ b/internal/persistence/queries_test.go @@ -127,61 +127,54 @@ WHERE id = 1; return na, ni } -func TestImportTask(t *testing.T) { - t.Cleanup(func() { cleanupDB(t) }) - - // GIVEN - na, _ := seedDB(t, testDB) - - // WHEN - summary := "prefix: an imported task" - now := time.Now().UTC() - lastID, err := ImportTask(testDB, summary, true, now, now) - require.NoError(t, err) - - // THEN - numActiveTasksAfter, err := fetchNumActiveTasks(testDB) - require.NoError(t, err) - assert.Equal(t, numActiveTasksAfter, na+1, "number of active tasks didn't increase by 1") - - task, err := fetchTaskByID(testDB, lastID) - require.NoError(t, err) - assert.True(t, task.Active) - assert.Equal(t, summary, task.Summary) - - seq, err := fetchTaskSequence(testDB) - require.NoError(t, err) - assert.Equal(t, seq, []uint64{6, 1, 2, 3}, "task sequence isn't correct") -} - func TestInsertTasksWorksWithEmptyTaskList(t *testing.T) { t.Cleanup(func() { cleanupDB(t) }) // GIVEN // WHEN - tasks, na, ni := getSampleTasks() + now := time.Now().UTC() + tasks := []types.Task{ + { + Summary: "prefix: new task 1", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new inactive task 1", + Active: false, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new task 3", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + } lastID, err := InsertTasks(testDB, tasks, true) - assert.Equal(t, lastID, int64(na+ni), "last ID is not correct") + assert.Equal(t, lastID, int64(3), "last ID is not correct") require.NoError(t, err) // THEN numActiveRes, err := fetchNumActiveTasks(testDB) require.NoError(t, err) - assert.Equal(t, numActiveRes, na, "number of active tasks didn't increase by the correct amount") + assert.Equal(t, numActiveRes, 2, "number of active tasks didn't increase by the correct amount") numTotalRes, err := fetchNumTotalTasks(testDB) require.NoError(t, err) - assert.Equal(t, numTotalRes, na+ni, "number of total tasks didn't increase by the correct amount") + assert.Equal(t, numTotalRes, 3, "number of total tasks didn't increase by the correct amount") lastTask, err := fetchTaskByID(testDB, lastID) require.NoError(t, err) - assert.Equal(t, tasks[len(tasks)-1].Active, lastTask.Active) - assert.Equal(t, tasks[len(tasks)-1].Summary, lastTask.Summary) - assert.Equal(t, tasks[len(tasks)-1].Context, lastTask.Context) + assert.Equal(t, tasks[2].Active, lastTask.Active) + assert.Equal(t, tasks[2].Summary, lastTask.Summary) + assert.Equal(t, tasks[2].Context, lastTask.Context) seq, err := fetchTaskSequence(testDB) require.NoError(t, err) - assert.Equal(t, seq, []uint64{1, 2, 3}, "task sequence isn't correct") + assert.Equal(t, seq, []uint64{1, 3}, "task sequence isn't correct") } func TestInsertTasksAddsTasksAtTheTop(t *testing.T) { From 26e1cdacb847d22b43aa257cf75ec785aeada3de Mon Sep 17 00:00:00 2001 From: Dhruv Thakur <13575379+dhth@users.noreply.github.com> Date: Sat, 24 Aug 2024 22:35:15 +0200 Subject: [PATCH 5/6] ci: add separate run action --- .github/workflows/build.yml | 13 +------------ .github/workflows/run.yml | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 12 deletions(-) create mode 100644 .github/workflows/run.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 934cbdf..1c91c61 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -17,11 +17,7 @@ env: jobs: build: - name: build - strategy: - matrix: - os: [ubuntu-latest, macos-latest] - runs-on: ${{ matrix.os }} + runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Go @@ -32,10 +28,3 @@ jobs: run: go build -v ./... - name: go test run: go test -v ./... - - name: run omm - run: | - go build . - cat assets/sample-tasks.txt | ./omm import - ./omm 'test: a task' - ./omm tasks - ./.github/scripts/checknumtasks.sh "$(./omm tasks | wc -l | xargs)" 11 diff --git a/.github/workflows/run.yml b/.github/workflows/run.yml new file mode 100644 index 0000000..5d7960d --- /dev/null +++ b/.github/workflows/run.yml @@ -0,0 +1,37 @@ +name: run + +on: + push: + branches: [ "main" ] + pull_request: + paths: + - "go.*" + - "**/*.go" + - ".github/workflows/*.yml" + +permissions: + contents: read + +env: + GO_VERSION: '1.22.5' + +jobs: + run: + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + - name: build + run: go build . + - name: run + run: | + cat assets/sample-tasks.txt | ./omm import + ./omm 'test: a task' + ./omm tasks + ./.github/scripts/checknumtasks.sh "$(./omm tasks | wc -l | xargs)" 11 From 776608dcce669f9e3bd2b76dd4a9355467398dd3 Mon Sep 17 00:00:00 2001 From: Dhruv Thakur <13575379+dhth@users.noreply.github.com> Date: Sat, 24 Aug 2024 23:07:17 +0200 Subject: [PATCH 6/6] refactor: error handling --- cmd/root.go | 47 +++++++++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index d903646..0a8ee41 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -37,24 +37,23 @@ const ( ) var ( - errConfigFileExtIncorrect = errors.New("config file must be a TOML file") - errConfigFileDoesntExist = errors.New("config file does not exist") - errDBFileExtIncorrect = errors.New("db file needs to end with .db") - - errMaxImportLimitExceeded = errors.New("import limit exceeded") - errNothingToImport = errors.New("nothing to import") - - errListDensityIncorrect = errors.New("list density is incorrect; valid values: compact/spacious") - + errCouldntGetHomeDir = errors.New("couldn't get home directory") + errConfigFileExtIncorrect = errors.New("config file must be a TOML file") + errConfigFileDoesntExist = errors.New("config file does not exist") + errDBFileExtIncorrect = errors.New("db file needs to end with .db") + errMaxImportLimitExceeded = errors.New("import limit exceeded") + errNothingToImport = errors.New("nothing to import") + errListDensityIncorrect = errors.New("list density is incorrect; valid values: compact/spacious") errCouldntCreateDBDirectory = errors.New("couldn't create directory for database") errCouldntCreateDB = errors.New("couldn't create database") errCouldntInitializeDB = errors.New("couldn't initialize database") errCouldntOpenDB = errors.New("couldn't open database") + errCouldntSetupGuide = errors.New("couldn't set up guided walkthrough") //go:embed assets/updates.txt updateContents string - reportIssueMsg = fmt.Sprintf("Let %s know about this error via %s.", author, repoIssuesURL) + reportIssueMsg = fmt.Sprintf("This isn't supposed to happen; let %s know about this error via \n%s.", author, repoIssuesURL) maxImportNumMsg = fmt.Sprintf(`A maximum of %d tasks that can be imported at a time. Archive/Delete tasks that are not active using ctrl+d/ctrl+x. @@ -68,14 +67,22 @@ Archive/Delete tasks that are not active using ctrl+d/ctrl+x. func Execute(version string) error { rootCmd, err := NewRootCommand() - - rootCmd.Version = version if err != nil { fmt.Fprintf(os.Stderr, "Error: %s\n", err) - os.Exit(1) + switch { + case errors.Is(err, errCouldntGetHomeDir): + fmt.Printf("\n%s\n", reportIssueMsg) + } + return err } + rootCmd.Version = version - return rootCmd.Execute() + err = rootCmd.Execute() + switch { + case errors.Is(err, errCouldntSetupGuide): + fmt.Printf("\n%s\n", reportIssueMsg) + } + return err } func setupDB(dbPathFull string) (*sql.DB, error) { @@ -360,10 +367,7 @@ Sorry for breaking the upgrade step! PreRunE: func(_ *cobra.Command, _ []string) error { guideErr := insertGuideTasks(db) if guideErr != nil { - return fmt.Errorf(`Failed to set up a guided walkthrough. -%s - -Error: %w`, reportIssueMsg, guideErr) + return fmt.Errorf("%w: %s", errCouldntSetupGuide, guideErr.Error()) } return nil @@ -405,12 +409,7 @@ Error: %w`, reportIssueMsg, guideErr) var configPathAdditionalCxt, dbPathAdditionalCxt string hd, err := os.UserHomeDir() if err != nil { - return nil, fmt.Errorf(`Couldn't get your home directory. This is a fatal error; -use --dbpath to specify database path manually - -%s - -Error: %w`, reportIssueMsg, err) + return nil, fmt.Errorf("%w: %s", errCouldntGetHomeDir, err.Error()) } switch ros {