From 927e71d2e6245acb628194362f4ef4d982590a75 Mon Sep 17 00:00:00 2001 From: CHEN YI YOU Date: Sun, 1 Dec 2024 15:04:25 +0800 Subject: [PATCH] Feat/get acc info from monthly trans (#184) * feat: add interface and mock * feat: add IsSameMonth * feat: get acc info from monthly trans * test: modify testing * fix: remove checking time range * feat: get acc info when no monthly trans --------- Co-authored-by: Eyo Chen --- cmd/api/main.go | 2 +- internal/domain/time.go | 18 ++ internal/handler/interfaces/interfaces.go | 2 +- internal/handler/transaction/transaction.go | 5 +- internal/usecase/interfaces/interfaces.go | 3 + internal/usecase/transaction/transaction.go | 54 ++++- .../usecase/transaction/transaction_test.go | 220 ++++++++++++++---- internal/usecase/usecase.go | 19 +- mocks/MonthlyTransRepo.go | 28 +++ mocks/TransactionUC.go | 18 +- pkg/validator/transaction.go | 1 - 11 files changed, 286 insertions(+), 84 deletions(-) create mode 100644 internal/domain/time.go diff --git a/cmd/api/main.go b/cmd/api/main.go index 1db195d..cdee458 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -56,7 +56,7 @@ func main() { // Setup adapter, usecase, and handler adapter := adapter.New(mysqlDB, redisClient, s3Client, presignClient, os.Getenv("AWS_BUCKET")) - usecase := usecase.New(adapter.User, adapter.MainCateg, adapter.SubCateg, adapter.Icon, adapter.Transaction, adapter.RedisService, adapter.UserIcon, adapter.S3Service) + usecase := usecase.New(adapter.User, adapter.MainCateg, adapter.SubCateg, adapter.Icon, adapter.Transaction, adapter.MonthlyTrans, adapter.RedisService, adapter.UserIcon, adapter.S3Service) handler := handler.New(usecase.User, usecase.MainCateg, usecase.SubCateg, usecase.Transaction, usecase.Icon, usecase.UserIcon, usecase.InitData) if err := initServe(handler); err != nil { logger.Fatal("Unable to start server", "error", err) diff --git a/internal/domain/time.go b/internal/domain/time.go new file mode 100644 index 0000000..50f3199 --- /dev/null +++ b/internal/domain/time.go @@ -0,0 +1,18 @@ +package domain + +import "time" + +// IsSameMonth returns true if the two given times are in the same month. +func IsSameMonth(t1, t2 string) bool { + date1, err := time.Parse(time.DateOnly, t1) + if err != nil { + return false + } + + date2, err := time.Parse(time.DateOnly, t2) + if err != nil { + return false + } + + return date1.Year() == date2.Year() && date1.Month() == date2.Month() +} diff --git a/internal/handler/interfaces/interfaces.go b/internal/handler/interfaces/interfaces.go index 5aefa57..9197344 100644 --- a/internal/handler/interfaces/interfaces.go +++ b/internal/handler/interfaces/interfaces.go @@ -66,7 +66,7 @@ type TransactionUC interface { Delete(ctx context.Context, id int64, user domain.User) error // GetAccInfo returns the accunulated information by user id. - GetAccInfo(ctx context.Context, query domain.GetAccInfoQuery, user domain.User) (domain.AccInfo, error) + GetAccInfo(ctx context.Context, user domain.User, query domain.GetAccInfoQuery, timeRange domain.TimeRangeType) (domain.AccInfo, error) // GetBarChartData returns bar chart data. GetBarChartData(ctx context.Context, chartDateRange domain.ChartDateRange, timeRangeType domain.TimeRangeType, transactionType domain.TransactionType, mainCategIDs []int64, user domain.User) (domain.ChartData, error) diff --git a/internal/handler/transaction/transaction.go b/internal/handler/transaction/transaction.go index 5d5411f..bbf6391 100644 --- a/internal/handler/transaction/transaction.go +++ b/internal/handler/transaction/transaction.go @@ -195,6 +195,9 @@ func (h *Hlr) Delete(w http.ResponseWriter, r *http.Request) { func (h *Hlr) GetAccInfo(w http.ResponseWriter, r *http.Request) { query := genGetAccInfoQuery(r) + rawTimeRangeType := r.URL.Query().Get("time_range") + timeRangeType := domain.CvtToTimeRangeType(rawTimeRangeType) + v := validator.New() if !v.GetAccInfo(query) { errutil.VildateErrorResponse(w, r, v.Error) @@ -203,7 +206,7 @@ func (h *Hlr) GetAccInfo(w http.ResponseWriter, r *http.Request) { user := ctxutil.GetUser(r) ctx := r.Context() - info, err := h.transaction.GetAccInfo(ctx, query, *user) + info, err := h.transaction.GetAccInfo(ctx, *user, query, timeRangeType) if err != nil { errutil.ServerErrorResponse(w, r, err) return diff --git a/internal/usecase/interfaces/interfaces.go b/internal/usecase/interfaces/interfaces.go index 3e262de..1991510 100644 --- a/internal/usecase/interfaces/interfaces.go +++ b/internal/usecase/interfaces/interfaces.go @@ -134,6 +134,9 @@ type TransactionRepo interface { type MonthlyTransRepo interface { // Create inserts new monthly transactions into the database. Create(ctx context.Context, date time.Time, trans []domain.MonthlyAggregatedData) error + + // GetByUserIDAndMonthDate returns monthly aggregated data by user id and month date. + GetByUserIDAndMonthDate(ctx context.Context, userID int64, monthDate time.Time) (domain.AccInfo, error) } // RedisService is the interface that wraps the basic methods for redis service. diff --git a/internal/usecase/transaction/transaction.go b/internal/usecase/transaction/transaction.go index 75bf49f..a774af7 100644 --- a/internal/usecase/transaction/transaction.go +++ b/internal/usecase/transaction/transaction.go @@ -2,6 +2,7 @@ package transaction import ( "context" + "errors" "time" "github.com/eyo-chen/expense-tracker-go/internal/domain" @@ -14,25 +15,34 @@ const ( PackageName = "usecase/transaction" ) +var ( + now = func() time.Time { + return time.Now() + } +) + type UC struct { - Transaction interfaces.TransactionRepo - MainCateg interfaces.MainCategRepo - SubCateg interfaces.SubCategRepo - Redis interfaces.RedisService - S3 interfaces.S3Service + Transaction interfaces.TransactionRepo + MainCateg interfaces.MainCategRepo + SubCateg interfaces.SubCategRepo + MonthlyTrans interfaces.MonthlyTransRepo + Redis interfaces.RedisService + S3 interfaces.S3Service } func New(t interfaces.TransactionRepo, m interfaces.MainCategRepo, s interfaces.SubCategRepo, + mt interfaces.MonthlyTransRepo, r interfaces.RedisService, s3 interfaces.S3Service) *UC { return &UC{ - Transaction: t, - MainCateg: m, - SubCateg: s, - Redis: r, - S3: s3, + Transaction: t, + MainCateg: m, + SubCateg: s, + MonthlyTrans: mt, + Redis: r, + S3: s3, } } @@ -174,8 +184,28 @@ func (u *UC) Delete(ctx context.Context, id int64, user domain.User) error { return u.Transaction.Delete(ctx, id) } -func (u *UC) GetAccInfo(ctx context.Context, query domain.GetAccInfoQuery, user domain.User) (domain.AccInfo, error) { - return u.Transaction.GetAccInfo(ctx, query, user.ID) +func (u *UC) GetAccInfo(ctx context.Context, user domain.User, query domain.GetAccInfoQuery, timeRange domain.TimeRangeType) (domain.AccInfo, error) { + if timeRange != domain.TimeRangeTypeOneMonth || + query.StartDate == nil || !domain.IsSameMonth(now().Format(time.DateOnly), *query.StartDate) { + return u.Transaction.GetAccInfo(ctx, query, user.ID) + } + + t, err := time.Parse(time.DateOnly, *query.StartDate) + if err != nil { + logger.Error("time.Parse failed", "package", PackageName, "err", err) + return domain.AccInfo{}, err + } + + accInfo, err := u.MonthlyTrans.GetByUserIDAndMonthDate(ctx, user.ID, t) + if errors.Is(err, domain.ErrDataNotFound) { + return u.Transaction.GetAccInfo(ctx, query, user.ID) + } + if err != nil { + return domain.AccInfo{}, err + } + + return accInfo, nil + } func (u *UC) GetBarChartData(ctx context.Context, chartDateRange domain.ChartDateRange, timeRangeType domain.TimeRangeType, transactionType domain.TransactionType, mainCategIDs []int64, user domain.User) (domain.ChartData, error) { diff --git a/internal/usecase/transaction/transaction_test.go b/internal/usecase/transaction/transaction_test.go index ae82ea9..ced78d4 100644 --- a/internal/usecase/transaction/transaction_test.go +++ b/internal/usecase/transaction/transaction_test.go @@ -23,12 +23,13 @@ var ( type TransactionSuite struct { suite.Suite - uc *UC - mockTransactionRepo *mocks.TransactionRepo - mockMainCategRepo *mocks.MainCategRepo - mockSubCategRepo *mocks.SubCategRepo - mockRedis *mocks.RedisService - mockS3 *mocks.S3Service + uc *UC + mockTransactionRepo *mocks.TransactionRepo + mockMonthlyTransRepo *mocks.MonthlyTransRepo + mockMainCategRepo *mocks.MainCategRepo + mockSubCategRepo *mocks.SubCategRepo + mockRedis *mocks.RedisService + mockS3 *mocks.S3Service } func TestTransactionSuite(t *testing.T) { @@ -39,13 +40,26 @@ func (s *TransactionSuite) SetupSuite() { logger.Register() } +func setNow(t time.Time) { + now = func() time.Time { + return t + } +} + +func resetNow() { + now = func() time.Time { + return time.Now() + } +} + func (s *TransactionSuite) SetupTest() { s.mockTransactionRepo = mocks.NewTransactionRepo(s.T()) + s.mockMonthlyTransRepo = mocks.NewMonthlyTransRepo(s.T()) s.mockMainCategRepo = mocks.NewMainCategRepo(s.T()) s.mockSubCategRepo = mocks.NewSubCategRepo(s.T()) s.mockRedis = mocks.NewRedisService(s.T()) s.mockS3 = mocks.NewS3Service(s.T()) - s.uc = New(s.mockTransactionRepo, s.mockMainCategRepo, s.mockSubCategRepo, s.mockRedis, s.mockS3) + s.uc = New(s.mockTransactionRepo, s.mockMainCategRepo, s.mockSubCategRepo, s.mockMonthlyTransRepo, s.mockRedis, s.mockS3) } func (s *TransactionSuite) TearDownTest() { @@ -111,12 +125,14 @@ func create_GetMainCategFail_ReturnError(s *TransactionSuite, desc string) { Note: "note", } + mockErr := errors.New("get main category fail") + // prepare mock services - s.mockMainCategRepo.Mock.On("GetByID", transInput.MainCategID, transInput.UserID).Return(nil, errors.New("get main category fail")).Once() + s.mockMainCategRepo.Mock.On("GetByID", transInput.MainCategID, transInput.UserID).Return(nil, mockErr).Once() // action, assertion err := s.uc.Create(mockCtx, transInput) - s.Require().Equal(errors.New("get main category fail"), err, desc) + s.Require().ErrorIs(err, mockErr, desc) } func create_TypeNotMatch_ReturnError(s *TransactionSuite, desc string) { @@ -139,7 +155,7 @@ func create_TypeNotMatch_ReturnError(s *TransactionSuite, desc string) { // action, assertion err := s.uc.Create(mockCtx, transInput) - s.Require().EqualError(err, domain.ErrTypeNotConsistent.Error(), desc) + s.Require().ErrorIs(err, domain.ErrTypeNotConsistent, desc) } func create_GetSubCategFail_ReturnError(s *TransactionSuite, desc string) { @@ -157,13 +173,15 @@ func create_GetSubCategFail_ReturnError(s *TransactionSuite, desc string) { Note: "note", } + mockErr := errors.New("get subcategory fail") + // prepare mock services s.mockMainCategRepo.Mock.On("GetByID", transInput.MainCategID, transInput.UserID).Return(&mainCateg, nil).Once() - s.mockSubCategRepo.Mock.On("GetByID", transInput.SubCategID, transInput.UserID).Return(nil, errors.New("get subcategory fail")).Once() + s.mockSubCategRepo.Mock.On("GetByID", transInput.SubCategID, transInput.UserID).Return(nil, mockErr).Once() // action, assertion err := s.uc.Create(mockCtx, transInput) - s.Require().EqualError(err, "get subcategory fail", desc) + s.Require().ErrorIs(err, mockErr, desc) } func create_MainCategNotMatch_ReturnError(s *TransactionSuite, desc string) { @@ -207,14 +225,16 @@ func create_CreateFail_ReturnError(s *TransactionSuite, desc string) { Note: "note", } + mockErr := errors.New("create fail") + // prepare mock services s.mockMainCategRepo.Mock.On("GetByID", transInput.MainCategID, transInput.UserID).Return(&mainCateg, nil).Once() s.mockSubCategRepo.Mock.On("GetByID", transInput.SubCategID, transInput.UserID).Return(&subCateg, nil).Once() - s.mockTransactionRepo.Mock.On("Create", mockCtx, transInput).Return(errors.New("create fail")).Once() + s.mockTransactionRepo.Mock.On("Create", mockCtx, transInput).Return(mockErr).Once() // action, assertion err := s.uc.Create(mockCtx, transInput) - s.Require().EqualError(err, "create fail", desc) + s.Require().ErrorIs(err, mockErr, desc) } func (s *TransactionSuite) TestGetAll() { @@ -255,12 +275,13 @@ func getAll_GetTransFail_ReturnError(s *TransactionSuite, desc string) { mockDecodedNextKeys := domain.DecodedNextKeys{} mockOpt := domain.GetTransOpt{} mockUser := domain.User{ID: 1} + mockErr := errors.New("get transactions fail") s.mockTransactionRepo.On("GetAll", mockCtx, mockOpt, int64(1)). - Return(nil, mockDecodedNextKeys, errors.New("error")).Once() + Return(nil, mockDecodedNextKeys, mockErr).Once() result, cursor, err := s.uc.GetAll(mockCtx, mockOpt, mockUser) - s.Require().Equal(errors.New("error"), err, desc) + s.Require().ErrorIs(err, mockErr, desc) s.Require().Nil(result, desc) s.Require().Empty(cursor, desc) } @@ -368,7 +389,7 @@ func getAll_WithCustomIcon_ReturnTransactions(s *TransactionSuite, desc string) result, cursor, err := s.uc.GetAll(mockCtx, mockOpt, mockUser) s.Require().NoError(err, desc) s.Require().Equal(mockTrans, result, desc) - s.Require().Equal(domain.Cursor{}, cursor, desc) + s.Require().Empty(cursor, desc) } func getAll_GetByFuncFail_ReturnError(s *TransactionSuite, desc string) { @@ -451,12 +472,13 @@ func update_GetMainCategFail_ReturnError(s *TransactionSuite, desc string) { Date: mockTimeNow, Note: "note", } + mockErr := errors.New("error") s.mockMainCategRepo.On("GetByID", trans.MainCategID, user.ID). - Return(nil, errors.New("error")).Once() + Return(nil, mockErr).Once() err := s.uc.Update(mockCtx, trans, user) - s.Require().Equal(errors.New("error"), err, desc) + s.Require().ErrorIs(err, mockErr, desc) } func update_TypeNotMatch_ReturnError(s *TransactionSuite, desc string) { @@ -476,7 +498,7 @@ func update_TypeNotMatch_ReturnError(s *TransactionSuite, desc string) { Return(&mainCateg, nil).Once() err := s.uc.Update(mockCtx, trans, user) - s.Require().Equal(domain.ErrTypeNotConsistent, err, desc) + s.Require().ErrorIs(err, domain.ErrTypeNotConsistent, desc) } func update_GetSubCategFail_ReturnError(s *TransactionSuite, desc string) { @@ -491,15 +513,15 @@ func update_GetSubCategFail_ReturnError(s *TransactionSuite, desc string) { Date: mockTimeNow, Note: "note", } - + mockErr := errors.New("error") s.mockMainCategRepo.On("GetByID", trans.MainCategID, user.ID). Return(&mainCateg, nil).Once() s.mockSubCategRepo.On("GetByID", trans.SubCategID, user.ID). - Return(nil, errors.New("error")).Once() + Return(nil, mockErr).Once() err := s.uc.Update(mockCtx, trans, user) - s.Require().Equal(errors.New("error"), err, desc) + s.Require().ErrorIs(err, mockErr, desc) } func update_MainCategNotMatch_ReturnError(s *TransactionSuite, desc string) { @@ -523,7 +545,7 @@ func update_MainCategNotMatch_ReturnError(s *TransactionSuite, desc string) { Return(&subCateg, nil).Once() err := s.uc.Update(mockCtx, trans, user) - s.Require().Equal(domain.ErrMainCategNotConsistent, err, desc) + s.Require().ErrorIs(err, domain.ErrMainCategNotConsistent, desc) } func update_GetTransFail_UpdateSuccessfully(s *TransactionSuite, desc string) { @@ -539,6 +561,7 @@ func update_GetTransFail_UpdateSuccessfully(s *TransactionSuite, desc string) { Date: mockTimeNow, Note: "note", } + mockErr := errors.New("error") s.mockMainCategRepo.On("GetByID", trans.MainCategID, user.ID). Return(&mainCateg, nil).Once() @@ -547,10 +570,10 @@ func update_GetTransFail_UpdateSuccessfully(s *TransactionSuite, desc string) { Return(&subCateg, nil).Once() s.mockTransactionRepo.On("GetByIDAndUserID", mockCtx, trans.ID, user.ID). - Return(domain.Transaction{}, errors.New("error")).Once() + Return(domain.Transaction{}, mockErr).Once() err := s.uc.Update(mockCtx, trans, user) - s.Require().Equal(errors.New("error"), err, desc) + s.Require().ErrorIs(err, mockErr, desc) } func update_UpdateFail_UpdateSuccessfully(s *TransactionSuite, desc string) { @@ -566,6 +589,7 @@ func update_UpdateFail_UpdateSuccessfully(s *TransactionSuite, desc string) { Date: mockTimeNow, Note: "note", } + mockErr := errors.New("error") s.mockMainCategRepo.On("GetByID", trans.MainCategID, user.ID). Return(&mainCateg, nil).Once() @@ -577,10 +601,10 @@ func update_UpdateFail_UpdateSuccessfully(s *TransactionSuite, desc string) { Return(domain.Transaction{}, nil).Once() s.mockTransactionRepo.On("Update", mockCtx, trans). - Return(errors.New("error")).Once() + Return(mockErr).Once() err := s.uc.Update(mockCtx, trans, user) - s.Require().Equal(errors.New("error"), err, desc) + s.Require().ErrorIs(err, mockErr, desc) } func (s *TransactionSuite) TestDelete() { @@ -617,18 +641,23 @@ func delete_CheckPermessionFail_ReturnError(s *TransactionSuite, desc string) { ID: 1, } + mockErr := errors.New("error") + s.mockTransactionRepo. On("GetByIDAndUserID", mockCtx, int64(1), user.ID). - Return(domain.Transaction{}, errors.New("error")).Once() + Return(domain.Transaction{}, mockErr).Once() err := s.uc.Delete(mockCtx, int64(1), user) - s.Require().Equal(errors.New("error"), err, desc) + s.Require().ErrorIs(err, mockErr, desc) } func (s *TransactionSuite) TestGetAccInfo() { for scenario, fn := range map[string]func(s *TransactionSuite, desc string){ - "when no error, return acc info": getAccInfo_NoError_ReturnAccInfo, - "when get acc info fail, return error": getAccInfo_GetAccInfoFail_ReturnError, + "when no error, return acc info": getAccInfo_NoError_ReturnAccInfo, + "when unspecified time range, return acc info": getAccInfo_UnspecifiedTimeRange_ReturnAccInfo, + "when get monthly trans, return acc info": getAccInfo_GetMonthlyTrans_ReturnAccInfo, + "when get acc info fail, return error": getAccInfo_GetAccInfoFail_ReturnError, + "when get monthly trans not found, return acc info": getAccInfo_GetMonthlyTransNotFound_ReturnAccInfo, } { s.Run(testutil.GetFunName(fn), func() { s.SetupTest() @@ -652,23 +681,110 @@ func getAccInfo_NoError_ReturnAccInfo(s *TransactionSuite, desc string) { s.mockTransactionRepo.On("GetAccInfo", mockCtx, query, user.ID). Return(accInfo, nil).Once() - result, err := s.uc.GetAccInfo(mockCtx, query, user) + result, err := s.uc.GetAccInfo(mockCtx, user, query, domain.TimeRangeTypeOneMonth) s.Require().NoError(err, desc) s.Require().Equal(accInfo, result, desc) } -func getAccInfo_GetAccInfoFail_ReturnError(s *TransactionSuite, desc string) { +func getAccInfo_UnspecifiedTimeRange_ReturnAccInfo(s *TransactionSuite, desc string) { startDate := "2024-03-01" endDate := "2024-03-31" user := domain.User{ID: 1} query := domain.GetAccInfoQuery{StartDate: &startDate, EndDate: &endDate} + accInfo := domain.AccInfo{ + TotalIncome: 100, + TotalExpense: 200, + TotalBalance: -100, + } s.mockTransactionRepo.On("GetAccInfo", mockCtx, query, user.ID). - Return(domain.AccInfo{}, errors.New("get acc info fail")).Once() + Return(accInfo, nil).Once() - result, err := s.uc.GetAccInfo(mockCtx, query, user) - s.Require().EqualError(err, "get acc info fail", desc) - s.Require().Equal(domain.AccInfo{}, result, desc) + result, err := s.uc.GetAccInfo(mockCtx, user, query, domain.TimeRangeTypeUnSpecified) + s.Require().NoError(err, desc) + s.Require().Equal(accInfo, result, desc) +} + +func getAccInfo_GetMonthlyTrans_ReturnAccInfo(s *TransactionSuite, desc string) { + startDate := "2024-10-01" + endDate := "2024-10-31" + user := domain.User{ID: 1} + query := domain.GetAccInfoQuery{StartDate: &startDate, EndDate: &endDate} + accInfo := domain.AccInfo{ + TotalIncome: 100, + TotalExpense: 200, + TotalBalance: -100, + } + + startTime, err := time.Parse(time.DateOnly, startDate) + s.Require().NoError(err) + + // set now to target month + setNow(startTime) + + s.mockMonthlyTransRepo.On("GetByUserIDAndMonthDate", mockCtx, user.ID, startTime). + Return(accInfo, nil).Once() + + result, err := s.uc.GetAccInfo(mockCtx, user, query, domain.TimeRangeTypeOneMonth) + s.Require().NoError(err, desc) + s.Require().Equal(accInfo, result, desc) + + // reset now + resetNow() +} + +func getAccInfo_GetAccInfoFail_ReturnError(s *TransactionSuite, desc string) { + startDate := "2024-10-01" + endDate := "2024-10-31" + user := domain.User{ID: 1} + query := domain.GetAccInfoQuery{StartDate: &startDate, EndDate: &endDate} + mockErr := errors.New("get monthly trans fail") + + startTime, err := time.Parse(time.DateOnly, startDate) + s.Require().NoError(err) + + // set now to target month + setNow(startTime) + + s.mockMonthlyTransRepo.On("GetByUserIDAndMonthDate", mockCtx, user.ID, startTime). + Return(domain.AccInfo{}, mockErr).Once() + + result, err := s.uc.GetAccInfo(mockCtx, user, query, domain.TimeRangeTypeOneMonth) + s.Require().ErrorIs(err, mockErr, desc) + s.Require().Empty(result, desc) + + // reset now + resetNow() +} + +func getAccInfo_GetMonthlyTransNotFound_ReturnAccInfo(s *TransactionSuite, desc string) { + startDate := "2024-10-01" + endDate := "2024-10-31" + user := domain.User{ID: 1} + query := domain.GetAccInfoQuery{StartDate: &startDate, EndDate: &endDate} + accInfo := domain.AccInfo{ + TotalIncome: 100, + TotalExpense: 200, + TotalBalance: -100, + } + + startTime, err := time.Parse(time.DateOnly, startDate) + s.Require().NoError(err) + + // set now to target month + setNow(startTime) + + s.mockMonthlyTransRepo.On("GetByUserIDAndMonthDate", mockCtx, user.ID, startTime). + Return(domain.AccInfo{}, domain.ErrDataNotFound).Once() + s.mockTransactionRepo.On("GetAccInfo", mockCtx, query, user.ID). + Return(accInfo, nil).Once() + + result, err := s.uc.GetAccInfo(mockCtx, user, query, domain.TimeRangeTypeOneMonth) + s.Require().NoError(err, desc) + s.Require().Equal(accInfo, result, desc) + + // reset now + resetNow() } func (s *TransactionSuite) TestGetBarChartData() { @@ -971,15 +1087,16 @@ func getBarChartData_GetChartDataFail_ReturnError(s *TransactionSuite, desc stri } mainCategIDs := []int64{1} + mockErr := errors.New("error") s.mockTransactionRepo.On("GetDailyBarChartData", mockCtx, chartDataRange, domain.TransactionTypeExpense, mainCategIDs, int64(1)). - Return(domain.DateToChartData{}, errors.New("error")).Once() + Return(domain.DateToChartData{}, mockErr).Once() // prepare expected result expResult := domain.ChartData{} result, err := s.uc.GetBarChartData(mockCtx, chartDataRange, domain.TimeRangeTypeOneWeekDay, domain.TransactionTypeExpense, mainCategIDs, domain.User{ID: 1}) - s.Require().Equal(errors.New("error"), err, desc) + s.Require().ErrorIs(err, mockErr, desc) s.Require().Equal(expResult, result, desc) } @@ -1030,13 +1147,14 @@ func getPieChartData_GetChartDataFail_ReturnError(s *TransactionSuite, desc stri Start: start, End: end, } + mockErr := errors.New("error") s.mockTransactionRepo.On("GetPieChartData", mockCtx, chartDataRange, domain.TransactionTypeExpense, int64(1)). - Return(domain.ChartData{}, errors.New("error")).Once() + Return(domain.ChartData{}, mockErr).Once() result, err := s.uc.GetPieChartData(mockCtx, chartDataRange, domain.TransactionTypeExpense, domain.User{ID: 1}) - s.Require().EqualError(err, "error", desc) - s.Require().Equal(domain.ChartData{}, result, desc) + s.Require().ErrorIs(err, mockErr, desc) + s.Require().Empty(result, desc) } func (s *TransactionSuite) TestGetLineChartData() { @@ -1326,15 +1444,14 @@ func getLineChartData_GetChartDataFail_ReturnError(s *TransactionSuite, desc str Start: start, End: end, } + mockErr := errors.New("error") s.mockTransactionRepo.On("GetDailyLineChartData", mockCtx, chartDataRange, int64(1)). - Return(domain.DateToChartData{}, errors.New("error")).Once() - - expResult := domain.ChartData{} + Return(domain.DateToChartData{}, mockErr).Once() result, err := s.uc.GetLineChartData(mockCtx, chartDataRange, domain.TimeRangeTypeOneWeekDay, domain.User{ID: 1}) - s.Require().Equal(errors.New("error"), err, desc) - s.Require().Equal(expResult, result, desc) + s.Require().ErrorIs(err, mockErr, desc) + s.Require().Empty(result, desc) } func (s *TransactionSuite) TestGetMonthlyData() { @@ -1477,11 +1594,12 @@ func getMonthlyData_GetMonthlyDataFail_ReturnError(s *TransactionSuite, desc str StartDate: startDate, EndDate: endDate, } + mockErr := errors.New("error") s.mockTransactionRepo.On("GetMonthlyData", mockCtx, dateRange, int64(1)). - Return(nil, errors.New("error")).Once() + Return(nil, mockErr).Once() result, err := s.uc.GetMonthlyData(mockCtx, dateRange, domain.User{ID: 1}) - s.Require().Equal(errors.New("error"), err, desc) - s.Require().Equal([]domain.TransactionType{}, result, desc) + s.Require().ErrorIs(err, mockErr, desc) + s.Require().Empty(result, desc) } diff --git a/internal/usecase/usecase.go b/internal/usecase/usecase.go index 095e73d..0eb3866 100644 --- a/internal/usecase/usecase.go +++ b/internal/usecase/usecase.go @@ -5,6 +5,7 @@ import ( "github.com/eyo-chen/expense-tracker-go/internal/usecase/initdata" "github.com/eyo-chen/expense-tracker-go/internal/usecase/interfaces" "github.com/eyo-chen/expense-tracker-go/internal/usecase/maincateg" + "github.com/eyo-chen/expense-tracker-go/internal/usecase/monthlytrans" "github.com/eyo-chen/expense-tracker-go/internal/usecase/subcateg" "github.com/eyo-chen/expense-tracker-go/internal/usecase/transaction" "github.com/eyo-chen/expense-tracker-go/internal/usecase/user" @@ -12,13 +13,14 @@ import ( ) type Usecase struct { - User *user.UC - MainCateg *maincateg.UC - SubCateg *subcateg.UC - Transaction *transaction.UC - Icon *icon.UC - UserIcon *usericon.UC - InitData *initdata.UC + User *user.UC + MainCateg *maincateg.UC + SubCateg *subcateg.UC + Transaction *transaction.UC + MonthlyTrans *monthlytrans.UC + Icon *icon.UC + UserIcon *usericon.UC + InitData *initdata.UC } func New(u interfaces.UserRepo, @@ -26,6 +28,7 @@ func New(u interfaces.UserRepo, s interfaces.SubCategRepo, i interfaces.IconRepo, t interfaces.TransactionRepo, + mt interfaces.MonthlyTransRepo, r interfaces.RedisService, ui interfaces.UserIconRepo, s3 interfaces.S3Service, @@ -34,7 +37,7 @@ func New(u interfaces.UserRepo, User: user.New(u, r), MainCateg: maincateg.New(m, i, ui, r, s3), SubCateg: subcateg.New(s, m), - Transaction: transaction.New(t, m, s, r, s3), + Transaction: transaction.New(t, m, s, mt, r, s3), Icon: icon.New(i, ui, r, s3), UserIcon: usericon.New(s3, ui), InitData: initdata.New(i, m, s, u), diff --git a/mocks/MonthlyTransRepo.go b/mocks/MonthlyTransRepo.go index 86f96c4..48e5ae7 100644 --- a/mocks/MonthlyTransRepo.go +++ b/mocks/MonthlyTransRepo.go @@ -35,6 +35,34 @@ func (_m *MonthlyTransRepo) Create(ctx context.Context, date time.Time, trans [] return r0 } +// GetByUserIDAndMonthDate provides a mock function with given fields: ctx, userID, monthDate +func (_m *MonthlyTransRepo) GetByUserIDAndMonthDate(ctx context.Context, userID int64, monthDate time.Time) (domain.AccInfo, error) { + ret := _m.Called(ctx, userID, monthDate) + + if len(ret) == 0 { + panic("no return value specified for GetByUserIDAndMonthDate") + } + + var r0 domain.AccInfo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64, time.Time) (domain.AccInfo, error)); ok { + return rf(ctx, userID, monthDate) + } + if rf, ok := ret.Get(0).(func(context.Context, int64, time.Time) domain.AccInfo); ok { + r0 = rf(ctx, userID, monthDate) + } else { + r0 = ret.Get(0).(domain.AccInfo) + } + + if rf, ok := ret.Get(1).(func(context.Context, int64, time.Time) error); ok { + r1 = rf(ctx, userID, monthDate) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // NewMonthlyTransRepo creates a new instance of MonthlyTransRepo. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMonthlyTransRepo(t interface { diff --git a/mocks/TransactionUC.go b/mocks/TransactionUC.go index 7efa1c1..80d9f69 100644 --- a/mocks/TransactionUC.go +++ b/mocks/TransactionUC.go @@ -51,9 +51,9 @@ func (_m *TransactionUC) Delete(ctx context.Context, id int64, user domain.User) return r0 } -// GetAccInfo provides a mock function with given fields: ctx, query, user -func (_m *TransactionUC) GetAccInfo(ctx context.Context, query domain.GetAccInfoQuery, user domain.User) (domain.AccInfo, error) { - ret := _m.Called(ctx, query, user) +// GetAccInfo provides a mock function with given fields: ctx, user, query, timeRange +func (_m *TransactionUC) GetAccInfo(ctx context.Context, user domain.User, query domain.GetAccInfoQuery, timeRange domain.TimeRangeType) (domain.AccInfo, error) { + ret := _m.Called(ctx, user, query, timeRange) if len(ret) == 0 { panic("no return value specified for GetAccInfo") @@ -61,17 +61,17 @@ func (_m *TransactionUC) GetAccInfo(ctx context.Context, query domain.GetAccInfo var r0 domain.AccInfo var r1 error - if rf, ok := ret.Get(0).(func(context.Context, domain.GetAccInfoQuery, domain.User) (domain.AccInfo, error)); ok { - return rf(ctx, query, user) + if rf, ok := ret.Get(0).(func(context.Context, domain.User, domain.GetAccInfoQuery, domain.TimeRangeType) (domain.AccInfo, error)); ok { + return rf(ctx, user, query, timeRange) } - if rf, ok := ret.Get(0).(func(context.Context, domain.GetAccInfoQuery, domain.User) domain.AccInfo); ok { - r0 = rf(ctx, query, user) + if rf, ok := ret.Get(0).(func(context.Context, domain.User, domain.GetAccInfoQuery, domain.TimeRangeType) domain.AccInfo); ok { + r0 = rf(ctx, user, query, timeRange) } else { r0 = ret.Get(0).(domain.AccInfo) } - if rf, ok := ret.Get(1).(func(context.Context, domain.GetAccInfoQuery, domain.User) error); ok { - r1 = rf(ctx, query, user) + if rf, ok := ret.Get(1).(func(context.Context, domain.User, domain.GetAccInfoQuery, domain.TimeRangeType) error); ok { + r1 = rf(ctx, user, query, timeRange) } else { r1 = ret.Error(1) } diff --git a/pkg/validator/transaction.go b/pkg/validator/transaction.go index 7b6d28c..8e43b18 100644 --- a/pkg/validator/transaction.go +++ b/pkg/validator/transaction.go @@ -47,7 +47,6 @@ func (v *Validator) GetAccInfo(q domain.GetAccInfoQuery) bool { v.Check(isValidDateFormat(q.StartDate), "startDate", "Start date must be in YYYY-MM-DD format") v.Check(isValidDateFormat(q.EndDate), "endDate", "End date must be in YYYY-MM-DD format") v.Check(checkStartDateBeforeEndDate(q.StartDate, q.EndDate), "startDate", "Start date must be before end date") - return v.Valid() }