From 40db5520106361eb2ae6c051323b7f9beceb95b9 Mon Sep 17 00:00:00 2001 From: Dhruv Thakur <13575379+dhth@users.noreply.github.com> Date: Sun, 25 Aug 2024 02:39:04 +0530 Subject: [PATCH] test: add database tests (#35) --- .github/workflows/build.yml | 13 +- .github/workflows/run.yml | 37 +++ cmd/db.go | 34 --- cmd/guide.go | 2 +- cmd/import.go | 26 +- cmd/root.go | 59 ++-- cmd/tasks.go | 2 +- internal/persistence/init.go | 39 +++ .../persistence/migrations.go | 20 +- .../persistence/migrations_test.go | 2 +- internal/persistence/queries.go | 244 ++++++---------- internal/persistence/queries_test.go | 260 ++++++++++++++++++ internal/types/types.go | 1 - internal/ui/cmds.go | 18 +- 14 files changed, 493 insertions(+), 264 deletions(-) create mode 100644 .github/workflows/run.yml create mode 100644 internal/persistence/init.go rename cmd/db_migrations.go => internal/persistence/migrations.go (78%) rename cmd/db_migrations_test.go => internal/persistence/migrations_test.go (93%) create mode 100644 internal/persistence/queries_test.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 934cbdf..1c91c61 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -17,11 +17,7 @@ env: jobs: build: - name: build - strategy: - matrix: - os: [ubuntu-latest, macos-latest] - runs-on: ${{ matrix.os }} + runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Go @@ -32,10 +28,3 @@ jobs: run: go build -v ./... - name: go test run: go test -v ./... - - name: run omm - run: | - go build . - cat assets/sample-tasks.txt | ./omm import - ./omm 'test: a task' - ./omm tasks - ./.github/scripts/checknumtasks.sh "$(./omm tasks | wc -l | xargs)" 11 diff --git a/.github/workflows/run.yml b/.github/workflows/run.yml new file mode 100644 index 0000000..5d7960d --- /dev/null +++ b/.github/workflows/run.yml @@ -0,0 +1,37 @@ +name: run + +on: + push: + branches: [ "main" ] + pull_request: + paths: + - "go.*" + - "**/*.go" + - ".github/workflows/*.yml" + +permissions: + contents: read + +env: + GO_VERSION: '1.22.5' + +jobs: + run: + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: ${{ env.GO_VERSION }} + - name: build + run: go build . + - name: run + run: | + cat assets/sample-tasks.txt | ./omm import + ./omm 'test: a task' + ./omm tasks + ./.github/scripts/checknumtasks.sh "$(./omm tasks | wc -l | xargs)" 11 diff --git a/cmd/db.go b/cmd/db.go index ac102df..63d0770 100644 --- a/cmd/db.go +++ b/cmd/db.go @@ -2,7 +2,6 @@ package cmd import ( "database/sql" - "time" ) func getDB(dbpath string) (*sql.DB, error) { @@ -14,36 +13,3 @@ func getDB(dbpath string) (*sql.DB, error) { db.SetMaxIdleConns(1) return db, err } - -func initDB(db *sql.DB) error { - // these init queries cannot be changed once omm is released; only further - // migrations can be added, which are run when omm sees a difference between - // the values in the db_versions table and latestDBVersion - _, err := db.Exec(` -CREATE TABLE IF NOT EXISTS db_versions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - version INTEGER NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -); - -CREATE TABLE IF NOT EXISTS task ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - summary TEXT NOT NULL, - active BOOLEAN NOT NULL DEFAULT true, - created_at TIMESTAMP NOT NULL, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -); - -CREATE TABLE task_sequence ( - id INTEGER PRIMARY KEY, - sequence JSON NOT NULL -); - -INSERT INTO task_sequence (id, sequence) VALUES (1, '[]'); - -INSERT INTO db_versions (version, created_at) -VALUES (1, ?); -`, time.Now().UTC()) - - return err -} diff --git a/cmd/guide.go b/cmd/guide.go index 77dbe9c..052cb6a 100644 --- a/cmd/guide.go +++ b/cmd/guide.go @@ -90,7 +90,7 @@ func insertGuideTasks(db *sql.DB) error { } } - err = pers.InsertTasksIntoDB(db, tasks) + _, err = pers.InsertTasks(db, tasks, true) return err } diff --git a/cmd/import.go b/cmd/import.go index fa9874d..375949d 100644 --- a/cmd/import.go +++ b/cmd/import.go @@ -7,12 +7,13 @@ import ( "time" pers "github.com/dhth/omm/internal/persistence" + "github.com/dhth/omm/internal/types" ) var errWillExceedCapacity = errors.New("import will exceed capacity") func importTask(db *sql.DB, taskSummary string) error { - numTasks, err := pers.FetchNumActiveTasksFromDB(db) + numTasks, err := pers.FetchNumActiveTasksShown(db) if err != nil { return err } @@ -21,11 +22,18 @@ func importTask(db *sql.DB, taskSummary string) error { } now := time.Now() - return pers.ImportTaskIntoDB(db, taskSummary, true, now, now) + task := types.Task{ + Summary: taskSummary, + Active: true, + CreatedAt: now, + UpdatedAt: now, + } + _, err = pers.InsertTasks(db, []types.Task{task}, true) + return err } func importTasks(db *sql.DB, taskSummaries []string) error { - numTasks, err := pers.FetchNumActiveTasksFromDB(db) + numTasks, err := pers.FetchNumActiveTasksShown(db) if err != nil { return err } @@ -34,5 +42,15 @@ func importTasks(db *sql.DB, taskSummaries []string) error { } now := time.Now() - return pers.ImportTaskSummariesIntoDB(db, taskSummaries, true, now, now) + tasks := make([]types.Task, len(taskSummaries)) + for i, summ := range taskSummaries { + tasks[i] = types.Task{ + Summary: summ, + Active: true, + CreatedAt: now, + UpdatedAt: now, + } + } + _, err = pers.InsertTasks(db, tasks, true) + return err } diff --git a/cmd/root.go b/cmd/root.go index bddf99e..0a8ee41 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -37,24 +37,23 @@ const ( ) var ( - errConfigFileExtIncorrect = errors.New("config file must be a TOML file") - errConfigFileDoesntExist = errors.New("config file does not exist") - errDBFileExtIncorrect = errors.New("db file needs to end with .db") - - errMaxImportLimitExceeded = errors.New("import limit exceeded") - errNothingToImport = errors.New("nothing to import") - - errListDensityIncorrect = errors.New("list density is incorrect; valid values: compact/spacious") - + errCouldntGetHomeDir = errors.New("couldn't get home directory") + errConfigFileExtIncorrect = errors.New("config file must be a TOML file") + errConfigFileDoesntExist = errors.New("config file does not exist") + errDBFileExtIncorrect = errors.New("db file needs to end with .db") + errMaxImportLimitExceeded = errors.New("import limit exceeded") + errNothingToImport = errors.New("nothing to import") + errListDensityIncorrect = errors.New("list density is incorrect; valid values: compact/spacious") errCouldntCreateDBDirectory = errors.New("couldn't create directory for database") errCouldntCreateDB = errors.New("couldn't create database") errCouldntInitializeDB = errors.New("couldn't initialize database") errCouldntOpenDB = errors.New("couldn't open database") + errCouldntSetupGuide = errors.New("couldn't set up guided walkthrough") //go:embed assets/updates.txt updateContents string - reportIssueMsg = fmt.Sprintf("Let %s know about this error via %s.", author, repoIssuesURL) + reportIssueMsg = fmt.Sprintf("This isn't supposed to happen; let %s know about this error via \n%s.", author, repoIssuesURL) maxImportNumMsg = fmt.Sprintf(`A maximum of %d tasks that can be imported at a time. Archive/Delete tasks that are not active using ctrl+d/ctrl+x. @@ -68,14 +67,22 @@ Archive/Delete tasks that are not active using ctrl+d/ctrl+x. func Execute(version string) error { rootCmd, err := NewRootCommand() - - rootCmd.Version = version if err != nil { fmt.Fprintf(os.Stderr, "Error: %s\n", err) - os.Exit(1) + switch { + case errors.Is(err, errCouldntGetHomeDir): + fmt.Printf("\n%s\n", reportIssueMsg) + } + return err } + rootCmd.Version = version - return rootCmd.Execute() + err = rootCmd.Execute() + switch { + case errors.Is(err, errCouldntSetupGuide): + fmt.Printf("\n%s\n", reportIssueMsg) + } + return err } func setupDB(dbPathFull string) (*sql.DB, error) { @@ -96,11 +103,11 @@ func setupDB(dbPathFull string) (*sql.DB, error) { return nil, fmt.Errorf("%w: %s", errCouldntCreateDB, err.Error()) } - err = initDB(db) + err = pers.InitDB(db) if err != nil { return nil, fmt.Errorf("%w: %s", errCouldntInitializeDB, err.Error()) } - err = upgradeDB(db, 1) + err = pers.UpgradeDB(db, 1) if err != nil { return nil, err } @@ -109,7 +116,7 @@ func setupDB(dbPathFull string) (*sql.DB, error) { if err != nil { return nil, fmt.Errorf("%w: %s", errCouldntOpenDB, err.Error()) } - err = upgradeDBIfNeeded(db) + err = pers.UpgradeDBIfNeeded(db) if err != nil { return nil, err } @@ -213,19 +220,19 @@ Clean up error: %s %s `, reportIssueMsg) - case errors.Is(err, errCouldntFetchDBVersion): + case errors.Is(err, pers.ErrCouldntFetchDBVersion): fmt.Fprintf(os.Stderr, `Couldn't get omm's latest database version. This is a fatal error. %s `, reportIssueMsg) - case errors.Is(err, errDBDowngraded): + case errors.Is(err, pers.ErrDBDowngraded): fmt.Fprintf(os.Stderr, `Looks like you downgraded omm. You should either delete omm's database file (you will lose data by doing that), or upgrade omm to the latest version. %s `, reportIssueMsg) - case errors.Is(err, errDBMigrationFailed): + case errors.Is(err, pers.ErrDBMigrationFailed): fmt.Fprintf(os.Stderr, `Something went wrong migrating omm's database. This is not supposed to happen. You can try running omm by passing it a custom database file path (using --db-path; this will create a new database) to see if that fixes things. If that @@ -360,10 +367,7 @@ Sorry for breaking the upgrade step! PreRunE: func(_ *cobra.Command, _ []string) error { guideErr := insertGuideTasks(db) if guideErr != nil { - return fmt.Errorf(`Failed to set up a guided walkthrough. -%s - -Error: %w`, reportIssueMsg, guideErr) + return fmt.Errorf("%w: %s", errCouldntSetupGuide, guideErr.Error()) } return nil @@ -405,12 +409,7 @@ Error: %w`, reportIssueMsg, guideErr) var configPathAdditionalCxt, dbPathAdditionalCxt string hd, err := os.UserHomeDir() if err != nil { - return nil, fmt.Errorf(`Couldn't get your home directory. This is a fatal error; -use --dbpath to specify database path manually - -%s - -Error: %w`, reportIssueMsg, err) + return nil, fmt.Errorf("%w: %s", errCouldntGetHomeDir, err.Error()) } switch ros { diff --git a/cmd/tasks.go b/cmd/tasks.go index 575a19d..9de3f01 100644 --- a/cmd/tasks.go +++ b/cmd/tasks.go @@ -9,7 +9,7 @@ import ( ) func printTasks(db *sql.DB, limit uint8, writer io.Writer) error { - tasks, err := pers.FetchActiveTasksFromDB(db, int(limit)) + tasks, err := pers.FetchActiveTasks(db, int(limit)) if err != nil { return err } diff --git a/internal/persistence/init.go b/internal/persistence/init.go new file mode 100644 index 0000000..b40acdc --- /dev/null +++ b/internal/persistence/init.go @@ -0,0 +1,39 @@ +package persistence + +import ( + "database/sql" + "time" +) + +func InitDB(db *sql.DB) error { + // these init queries cannot be changed once omm is released; only further + // migrations can be added, which are run when omm sees a difference between + // the values in the db_versions table and latestDBVersion + _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS db_versions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + version INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS task ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + summary TEXT NOT NULL, + active BOOLEAN NOT NULL DEFAULT true, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE task_sequence ( + id INTEGER PRIMARY KEY, + sequence JSON NOT NULL +); + +INSERT INTO task_sequence (id, sequence) VALUES (1, '[]'); + +INSERT INTO db_versions (version, created_at) +VALUES (1, ?); +`, time.Now().UTC()) + + return err +} diff --git a/cmd/db_migrations.go b/internal/persistence/migrations.go similarity index 78% rename from cmd/db_migrations.go rename to internal/persistence/migrations.go index f59c2e7..62b3a4a 100644 --- a/cmd/db_migrations.go +++ b/internal/persistence/migrations.go @@ -1,4 +1,4 @@ -package cmd +package persistence import ( "database/sql" @@ -12,9 +12,9 @@ const ( ) var ( - errDBDowngraded = errors.New("database downgraded") - errDBMigrationFailed = errors.New("database migration failed") - errCouldntFetchDBVersion = errors.New("couldn't fetch version") + ErrDBDowngraded = errors.New("database downgraded") + ErrDBMigrationFailed = errors.New("database migration failed") + ErrCouldntFetchDBVersion = errors.New("couldn't fetch version") ) type dbVersionInfo struct { @@ -54,18 +54,18 @@ LIMIT 1; return dbVersion, err } -func upgradeDBIfNeeded(db *sql.DB) error { +func UpgradeDBIfNeeded(db *sql.DB) error { latestVersionInDB, err := fetchLatestDBVersion(db) if err != nil { - return fmt.Errorf("%w: %s", errCouldntFetchDBVersion, err.Error()) + return fmt.Errorf("%w: %s", ErrCouldntFetchDBVersion, err.Error()) } if latestVersionInDB.version > latestDBVersion { - return errDBDowngraded + return ErrDBDowngraded } if latestVersionInDB.version < latestDBVersion { - err = upgradeDB(db, latestVersionInDB.version) + err = UpgradeDB(db, latestVersionInDB.version) if err != nil { return err } @@ -74,13 +74,13 @@ func upgradeDBIfNeeded(db *sql.DB) error { return nil } -func upgradeDB(db *sql.DB, currentVersion int) error { +func UpgradeDB(db *sql.DB, currentVersion int) error { migrations := getMigrations() for i := currentVersion + 1; i <= latestDBVersion; i++ { migrateQuery := migrations[i] migrateErr := runMigration(db, migrateQuery, i) if migrateErr != nil { - return fmt.Errorf("%w (version %d): %v", errDBMigrationFailed, i, migrateErr.Error()) + return fmt.Errorf("%w (version %d): %v", ErrDBMigrationFailed, i, migrateErr.Error()) } } return nil diff --git a/cmd/db_migrations_test.go b/internal/persistence/migrations_test.go similarity index 93% rename from cmd/db_migrations_test.go rename to internal/persistence/migrations_test.go index 03c40fe..6a75da2 100644 --- a/cmd/db_migrations_test.go +++ b/internal/persistence/migrations_test.go @@ -1,4 +1,4 @@ -package cmd +package persistence import ( "testing" diff --git a/internal/persistence/queries.go b/internal/persistence/queries.go index 60ea6ea..aaf84b7 100644 --- a/internal/persistence/queries.go +++ b/internal/persistence/queries.go @@ -13,7 +13,53 @@ const ( ContextMaxBytes = 4096 // 4KB seems to be sufficient for context ) -func FetchNumActiveTasksFromDB(db *sql.DB) (int, error) { +func fetchTaskSequence(db *sql.DB) ([]uint64, error) { + var seq []byte + seqRow := db.QueryRow("SELECT sequence from task_sequence where id=1;") + + err := seqRow.Scan(&seq) + if err != nil { + return nil, err + } + + var seqItems []uint64 + err = json.Unmarshal(seq, &seqItems) + if err != nil { + return nil, err + } + return seqItems, nil +} + +func fetchNumActiveTasks(db *sql.DB) (int, error) { + var rowCount int + err := db.QueryRow("SELECT count(*) from task where active is true").Scan(&rowCount) + return rowCount, err +} + +func fetchNumTotalTasks(db *sql.DB) (int, error) { + var rowCount int + err := db.QueryRow("SELECT count(*) from task").Scan(&rowCount) + return rowCount, err +} + +func fetchTaskByID(db *sql.DB, ID int64) (types.Task, error) { + var entry types.Task + row := db.QueryRow(` +SELECT id, summary, active, context, created_at, updated_at +from task +WHERE id=?; +`, ID) + err := row.Scan(&entry.ID, + &entry.Summary, + &entry.Active, + &entry.Context, + &entry.CreatedAt, + &entry.UpdatedAt, + ) + return entry, err +} + +func FetchNumActiveTasksShown(db *sql.DB) (int, error) { row := db.QueryRow(` SELECT json_array_length(sequence) AS num_tasks FROM task_sequence where id=1; @@ -28,7 +74,7 @@ FROM task_sequence where id=1; return numTasks, nil } -func UpdateTaskSequenceInDB(db *sql.DB, sequence []uint64) error { +func UpdateTaskSequence(db *sql.DB, sequence []uint64) error { sequenceJSON, err := json.Marshal(sequence) if err != nil { return err @@ -52,7 +98,7 @@ WHERE id = 1; return nil } -func InsertTaskInDB(db *sql.DB, summary string, createdAt, updatedAt time.Time) (uint64, error) { +func InsertTask(db *sql.DB, summary string, createdAt, updatedAt time.Time) (uint64, error) { stmt, err := db.Prepare(` INSERT INTO task (summary, active, created_at, updated_at) VALUES (?, true, ?, ?); @@ -75,107 +121,38 @@ VALUES (?, true, ?, ?); return uint64(li), nil } -func ImportTaskIntoDB(db *sql.DB, summary string, active bool, createdAt, updatedAt time.Time) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - _ = tx.Rollback() - }() - - query := `INSERT INTO task (summary, active, created_at, updated_at) -VALUES (?, ?, ?, ?);` - - res, err := tx.Exec(query, summary, active, createdAt.UTC(), updatedAt.UTC()) - if err != nil { - return err - } - - lastInsertID, err := res.LastInsertId() - if err != nil { - return err - } - - var seq []byte - seqRow := tx.QueryRow("SELECT sequence from task_sequence where id=1;") - - err = seqRow.Scan(&seq) - if err != nil { - return err - } - - var seqItems []int - err = json.Unmarshal(seq, &seqItems) - if err != nil { - return err - } - - newTaskID := make([]int, 1) - newTaskID[0] = int(lastInsertID) - updatedSeqItems := append(newTaskID, seqItems...) - - sequenceJSON, err := json.Marshal(updatedSeqItems) - if err != nil { - return err - } - - seqUpdateStmt, err := tx.Prepare(` -UPDATE task_sequence -SET sequence = ? -WHERE id = 1; -`) - if err != nil { - return err - } - defer seqUpdateStmt.Close() - - _, err = seqUpdateStmt.Exec(sequenceJSON) - if err != nil { - return err - } - - err = tx.Commit() - if err != nil { - return err - } - return nil -} - -func ImportTaskSummariesIntoDB(db *sql.DB, tasks []string, active bool, createdAt, updatedAt time.Time) error { +func InsertTasks(db *sql.DB, tasks []types.Task, insertAtTop bool) (int64, error) { tx, err := db.Begin() if err != nil { - return err + return -1, err } defer func() { _ = tx.Rollback() }() - query := `INSERT INTO task (summary, active, created_at, updated_at) + query := `INSERT INTO task (summary, context, active, created_at, updated_at) VALUES ` values := make([]interface{}, 0, len(tasks)*4) - ca := createdAt.UTC() - ua := updatedAt.UTC() - - for i, ts := range tasks { + for i, t := range tasks { if i > 0 { query += "," } - query += "(?, ?, ?, ?)" - values = append(values, ts, active, ca, ua) + query += "(?, ?, ?, ?, ?)" + values = append(values, t.Summary, t.Context, t.Active, t.CreatedAt.UTC(), t.UpdatedAt.UTC()) } + query += ";" res, err := tx.Exec(query, values...) if err != nil { - return err + return -1, err } lastInsertID, err := res.LastInsertId() if err != nil { - return err + return -1, err } var seq []byte @@ -183,89 +160,34 @@ VALUES ` err = seqRow.Scan(&seq) if err != nil { - return err + return -1, err } var seqItems []int err = json.Unmarshal(seq, &seqItems) if err != nil { - return err - } - - newTaskIDs := make([]int, len(tasks)) - counter := 0 - for i := int(lastInsertID) - len(tasks) + 1; i <= int(lastInsertID); i++ { - newTaskIDs[counter] = i - counter++ - } - updatedSeqItems := append(newTaskIDs, seqItems...) - - sequenceJSON, err := json.Marshal(updatedSeqItems) - if err != nil { - return err - } - - seqUpdateStmt, err := tx.Prepare(` -UPDATE task_sequence -SET sequence = ? -WHERE id = 1; -`) - if err != nil { - return err - } - defer seqUpdateStmt.Close() - - _, err = seqUpdateStmt.Exec(sequenceJSON) - if err != nil { - return err - } - - err = tx.Commit() - if err != nil { - return err - } - return nil -} - -func InsertTasksIntoDB(db *sql.DB, tasks []types.Task) error { - tx, err := db.Begin() - if err != nil { - return err + return -1, err } - defer func() { - _ = tx.Rollback() - }() - - query := `INSERT INTO task (summary, context, active, created_at, updated_at) -VALUES ` - - values := make([]interface{}, 0, len(tasks)*4) - - var seqItems []int - seqCounter := 1 - for i, t := range tasks { - if i > 0 { - query += "," - } - query += "(?, ?, ?, ?, ?)" - values = append(values, t.Summary, t.Context, t.Active, t.CreatedAt.UTC(), t.UpdatedAt.UTC()) + var newTaskIDs []int + taskID := int(lastInsertID) - len(tasks) + 1 + for _, t := range tasks { if t.Active { - seqItems = append(seqItems, seqCounter) + newTaskIDs = append(newTaskIDs, taskID) } - seqCounter++ + taskID++ } - query += ";" - - _, err = tx.Exec(query, values...) - if err != nil { - return err + var updatedSeqItems []int + if insertAtTop { + updatedSeqItems = append(newTaskIDs, seqItems...) + } else { + updatedSeqItems = append(seqItems, newTaskIDs...) } - sequenceJSON, err := json.Marshal(seqItems) + sequenceJSON, err := json.Marshal(updatedSeqItems) if err != nil { - return err + return -1, err } seqUpdateStmt, err := tx.Prepare(` @@ -274,23 +196,23 @@ SET sequence = ? WHERE id = 1; `) if err != nil { - return err + return -1, err } defer seqUpdateStmt.Close() _, err = seqUpdateStmt.Exec(sequenceJSON) if err != nil { - return err + return -1, err } err = tx.Commit() if err != nil { - return err + return -1, err } - return nil + return lastInsertID, nil } -func UpdateTaskSummaryInDB(db *sql.DB, id uint64, summary string, updatedAt time.Time) error { +func UpdateTaskSummary(db *sql.DB, id uint64, summary string, updatedAt time.Time) error { stmt, err := db.Prepare(` UPDATE task SET summary = ?, @@ -309,7 +231,7 @@ WHERE id = ? return nil } -func UpdateTaskContextInDB(db *sql.DB, id uint64, context string, updatedAt time.Time) error { +func UpdateTaskContext(db *sql.DB, id uint64, context string, updatedAt time.Time) error { stmt, err := db.Prepare(` UPDATE task SET context = ?, @@ -328,7 +250,7 @@ WHERE id = ? return nil } -func UnsetTaskContextInDB(db *sql.DB, id uint64, updatedAt time.Time) error { +func UnsetTaskContext(db *sql.DB, id uint64, updatedAt time.Time) error { stmt, err := db.Prepare(` UPDATE task SET context = NULL, @@ -347,7 +269,7 @@ WHERE id = ? return nil } -func ChangeTaskStatusInDB(db *sql.DB, id uint64, active bool, updatedAt time.Time) error { +func ChangeTaskStatus(db *sql.DB, id uint64, active bool, updatedAt time.Time) error { stmt, err := db.Prepare(` UPDATE task SET active = ?, @@ -366,7 +288,7 @@ WHERE id = ? return nil } -func FetchActiveTasksFromDB(db *sql.DB, limit int) ([]types.Task, error) { +func FetchActiveTasks(db *sql.DB, limit int) ([]types.Task, error) { var tasks []types.Task rows, err := db.Query(` @@ -407,7 +329,7 @@ LIMIT ?; return tasks, nil } -func FetchInActiveTasksFromDB(db *sql.DB, limit int) ([]types.Task, error) { +func FetchInActiveTasks(db *sql.DB, limit int) ([]types.Task, error) { var tasks []types.Task rows, err := db.Query(` @@ -446,7 +368,7 @@ LIMIT ?; return tasks, nil } -func DeleteTaskInDB(db *sql.DB, id uint64) error { +func DeleteTask(db *sql.DB, id uint64) error { stmt, err := db.Prepare(` DELETE from task WHERE id=?; diff --git a/internal/persistence/queries_test.go b/internal/persistence/queries_test.go new file mode 100644 index 0000000..5241190 --- /dev/null +++ b/internal/persistence/queries_test.go @@ -0,0 +1,260 @@ +package persistence + +import ( + "database/sql" + "encoding/json" + "fmt" + "os" + "testing" + "time" + + "github.com/dhth/omm/internal/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + _ "modernc.org/sqlite" // sqlite driver +) + +var testDB *sql.DB + +func TestMain(m *testing.M) { + var err error + testDB, err = sql.Open("sqlite", ":memory:") + if err != nil { + panic(err) + } + + err = InitDB(testDB) + if err != nil { + panic(err) + } + err = UpgradeDB(testDB, 1) + if err != nil { + panic(err) + } + code := m.Run() + + testDB.Close() + + os.Exit(code) +} + +func cleanupDB(t *testing.T) { + var err error + for _, tbl := range []string{"task"} { + _, err = testDB.Exec(fmt.Sprintf("DELETE FROM %s", tbl)) + if err != nil { + t.Fatalf("failed to clean up table %q: %v", tbl, err) + } + _, err := testDB.Exec("DELETE FROM sqlite_sequence WHERE name=?;", tbl) + if err != nil { + t.Fatalf("failed to reset auto increment for table %q: %v", tbl, err) + } + } + _, err = testDB.Exec(`UPDATE task_sequence +SET sequence = '[]' +WHERE id = 1;`) + if err != nil { + t.Fatalf("failed to clean up table task_sequence: %v", err) + } +} + +func getSampleTasks() ([]types.Task, int, int) { + numActive := 3 + numInactive := 2 + + tasks := make([]types.Task, numActive+numInactive) + contexts := make([]string, numActive+numInactive) + now := time.Now().UTC() + counter := 0 + for range numActive { + contexts[counter] = fmt.Sprintf("context for task %d", counter) + tasks[counter] = types.Task{ + Summary: fmt.Sprintf("prefix: task %d", counter), + Active: true, + Context: &contexts[counter], + CreatedAt: now, + UpdatedAt: now, + } + counter++ + } + for range numInactive { + contexts[counter] = fmt.Sprintf("context for task %d", counter) + tasks[counter] = types.Task{ + Summary: fmt.Sprintf("prefix: task %d", counter), + Active: false, + Context: &contexts[counter], + CreatedAt: now, + UpdatedAt: now, + } + counter++ + } + + return tasks, numActive, numInactive +} + +func seedDB(t *testing.T, db *sql.DB) (int, int) { + t.Helper() + + tasks, na, ni := getSampleTasks() + + for _, task := range tasks { + _, err := db.Exec(` +INSERT INTO task (summary, active, created_at, updated_at) +VALUES (?, ?, ?, ?)`, task.Summary, task.Active, task.CreatedAt, task.UpdatedAt) + if err != nil { + t.Fatalf("failed to insert data into table \"task\": %v", err) + } + } + + seqItems := make([]int, na) + for i := range na { + seqItems[i] = i + 1 + } + sequenceJSON, err := json.Marshal(seqItems) + if err != nil { + t.Fatalf("failed to marshall JSON data for seeding: %v", err) + } + + _, err = db.Exec(` +UPDATE task_sequence +SET sequence = ? +WHERE id = 1; +`, sequenceJSON) + if err != nil { + t.Fatalf("failed to insert data into table \"task_sequence\": %v", err) + } + + return na, ni +} + +func TestInsertTasksWorksWithEmptyTaskList(t *testing.T) { + t.Cleanup(func() { cleanupDB(t) }) + + // GIVEN + // WHEN + now := time.Now().UTC() + tasks := []types.Task{ + { + Summary: "prefix: new task 1", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new inactive task 1", + Active: false, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new task 3", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + } + lastID, err := InsertTasks(testDB, tasks, true) + assert.Equal(t, lastID, int64(3), "last ID is not correct") + require.NoError(t, err) + + // THEN + numActiveRes, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numActiveRes, 2, "number of active tasks didn't increase by the correct amount") + + numTotalRes, err := fetchNumTotalTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numTotalRes, 3, "number of total tasks didn't increase by the correct amount") + + lastTask, err := fetchTaskByID(testDB, lastID) + require.NoError(t, err) + assert.Equal(t, tasks[2].Active, lastTask.Active) + assert.Equal(t, tasks[2].Summary, lastTask.Summary) + assert.Equal(t, tasks[2].Context, lastTask.Context) + + seq, err := fetchTaskSequence(testDB) + require.NoError(t, err) + assert.Equal(t, seq, []uint64{1, 3}, "task sequence isn't correct") +} + +func TestInsertTasksAddsTasksAtTheTop(t *testing.T) { + t.Cleanup(func() { cleanupDB(t) }) + + // GIVEN + na, ni := seedDB(t, testDB) + + // WHEN + now := time.Now().UTC() + tasks := []types.Task{ + { + Summary: "prefix: new task 1", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new inactive task 1", + Active: false, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new task 3", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + } + + _, err := InsertTasks(testDB, tasks, true) + require.NoError(t, err) + + // THEN + numActiveRes, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numActiveRes, na+2, "number of active tasks didn't increase by the correct amount") + + numTotalRes, err := fetchNumTotalTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numTotalRes, na+ni+3, "number of total tasks didn't increase by the correct amount") + + seq, err := fetchTaskSequence(testDB) + require.NoError(t, err) + assert.Equal(t, seq, []uint64{6, 8, 1, 2, 3}, "task sequence isn't correct") +} + +func TestInsertTasksAddsTasksAtTheEnd(t *testing.T) { + t.Cleanup(func() { cleanupDB(t) }) + + // GIVEN + na, _ := seedDB(t, testDB) + + // WHEN + now := time.Now().UTC() + tasks := []types.Task{ + { + Summary: "prefix: new task 1", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + { + Summary: "prefix: new task 2", + Active: true, + CreatedAt: now, + UpdatedAt: now, + }, + } + + _, err := InsertTasks(testDB, tasks, false) + require.NoError(t, err) + + // THEN + numActiveRes, err := fetchNumActiveTasks(testDB) + require.NoError(t, err) + assert.Equal(t, numActiveRes, na+2, "number of active tasks didn't increase by the correct amount") + + seq, err := fetchTaskSequence(testDB) + require.NoError(t, err) + assert.Equal(t, seq, []uint64{1, 2, 3, 6, 7}, "task sequence isn't correct") +} diff --git a/internal/types/types.go b/internal/types/types.go index b3a7729..9d575ec 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -39,7 +39,6 @@ var ( ) type Task struct { - ItemTitle string ID uint64 Summary string Context *string diff --git a/internal/ui/cmds.go b/internal/ui/cmds.go index da222fb..f79083e 100644 --- a/internal/ui/cmds.go +++ b/internal/ui/cmds.go @@ -21,21 +21,21 @@ func hideHelp(interval time.Duration) tea.Cmd { func updateTaskSequence(db *sql.DB, sequence []uint64) tea.Cmd { return func() tea.Msg { - err := pers.UpdateTaskSequenceInDB(db, sequence) + err := pers.UpdateTaskSequence(db, sequence) return taskSequenceUpdatedMsg{err} } } func createTask(db *sql.DB, summary string, createdAt, updatedAt time.Time) tea.Cmd { return func() tea.Msg { - id, err := pers.InsertTaskInDB(db, summary, createdAt, updatedAt) + id, err := pers.InsertTask(db, summary, createdAt, updatedAt) return taskCreatedMsg{id, summary, createdAt, updatedAt, err} } } func deleteTask(db *sql.DB, id uint64, index int, active bool) tea.Cmd { return func() tea.Msg { - err := pers.DeleteTaskInDB(db, id) + err := pers.DeleteTask(db, id) return taskDeletedMsg{id, index, active, err} } } @@ -43,7 +43,7 @@ func deleteTask(db *sql.DB, id uint64, index int, active bool) tea.Cmd { func updateTaskSummary(db *sql.DB, listIndex int, id uint64, summary string) tea.Cmd { return func() tea.Msg { now := time.Now() - err := pers.UpdateTaskSummaryInDB(db, id, summary, now) + err := pers.UpdateTaskSummary(db, id, summary, now) return taskSummaryUpdatedMsg{listIndex, id, summary, now, err} } } @@ -53,9 +53,9 @@ func updateTaskContext(db *sql.DB, listIndex int, id uint64, context string, lis var err error now := time.Now() if context == "" { - err = pers.UnsetTaskContextInDB(db, id, now) + err = pers.UnsetTaskContext(db, id, now) } else { - err = pers.UpdateTaskContextInDB(db, id, context, now) + err = pers.UpdateTaskContext(db, id, context, now) } return taskContextUpdatedMsg{listIndex, list, id, context, now, err} } @@ -63,7 +63,7 @@ func updateTaskContext(db *sql.DB, listIndex int, id uint64, context string, lis func changeTaskStatus(db *sql.DB, listIndex int, id uint64, active bool, updatedAt time.Time) tea.Cmd { return func() tea.Msg { - err := pers.ChangeTaskStatusInDB(db, id, active, updatedAt) + err := pers.ChangeTaskStatus(db, id, active, updatedAt) return taskStatusChangedMsg{listIndex, id, active, updatedAt, err} } } @@ -74,9 +74,9 @@ func fetchTasks(db *sql.DB, active bool, limit int) tea.Cmd { var err error switch active { case true: - tasks, err = pers.FetchActiveTasksFromDB(db, limit) + tasks, err = pers.FetchActiveTasks(db, limit) case false: - tasks, err = pers.FetchInActiveTasksFromDB(db, limit) + tasks, err = pers.FetchInActiveTasks(db, limit) } return tasksFetched{tasks, active, err} }