Skip to content

Commit

Permalink
Feat/get acc info from monthly trans (#184)
Browse files Browse the repository at this point in the history
* feat: add interface and mock

* feat: add IsSameMonth

* feat: get acc info from monthly trans

* test: modify testing

* fix: remove checking time range

* feat: get acc info when no monthly trans

---------

Co-authored-by: Eyo Chen <eyo.chen@amazingtalker.com>
  • Loading branch information
eyo-chen and Eyo Chen authored Dec 1, 2024
1 parent acc51a9 commit 927e71d
Show file tree
Hide file tree
Showing 11 changed files with 286 additions and 84 deletions.
2 changes: 1 addition & 1 deletion cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func main() {

// Setup adapter, usecase, and handler
adapter := adapter.New(mysqlDB, redisClient, s3Client, presignClient, os.Getenv("AWS_BUCKET"))
usecase := usecase.New(adapter.User, adapter.MainCateg, adapter.SubCateg, adapter.Icon, adapter.Transaction, adapter.RedisService, adapter.UserIcon, adapter.S3Service)
usecase := usecase.New(adapter.User, adapter.MainCateg, adapter.SubCateg, adapter.Icon, adapter.Transaction, adapter.MonthlyTrans, adapter.RedisService, adapter.UserIcon, adapter.S3Service)
handler := handler.New(usecase.User, usecase.MainCateg, usecase.SubCateg, usecase.Transaction, usecase.Icon, usecase.UserIcon, usecase.InitData)
if err := initServe(handler); err != nil {
logger.Fatal("Unable to start server", "error", err)
Expand Down
18 changes: 18 additions & 0 deletions internal/domain/time.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package domain

import "time"

// IsSameMonth returns true if the two given times are in the same month.
func IsSameMonth(t1, t2 string) bool {
date1, err := time.Parse(time.DateOnly, t1)
if err != nil {
return false
}

date2, err := time.Parse(time.DateOnly, t2)
if err != nil {
return false
}

return date1.Year() == date2.Year() && date1.Month() == date2.Month()
}
2 changes: 1 addition & 1 deletion internal/handler/interfaces/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type TransactionUC interface {
Delete(ctx context.Context, id int64, user domain.User) error

// GetAccInfo returns the accunulated information by user id.
GetAccInfo(ctx context.Context, query domain.GetAccInfoQuery, user domain.User) (domain.AccInfo, error)
GetAccInfo(ctx context.Context, user domain.User, query domain.GetAccInfoQuery, timeRange domain.TimeRangeType) (domain.AccInfo, error)

// GetBarChartData returns bar chart data.
GetBarChartData(ctx context.Context, chartDateRange domain.ChartDateRange, timeRangeType domain.TimeRangeType, transactionType domain.TransactionType, mainCategIDs []int64, user domain.User) (domain.ChartData, error)
Expand Down
5 changes: 4 additions & 1 deletion internal/handler/transaction/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ func (h *Hlr) Delete(w http.ResponseWriter, r *http.Request) {

func (h *Hlr) GetAccInfo(w http.ResponseWriter, r *http.Request) {
query := genGetAccInfoQuery(r)
rawTimeRangeType := r.URL.Query().Get("time_range")
timeRangeType := domain.CvtToTimeRangeType(rawTimeRangeType)

v := validator.New()
if !v.GetAccInfo(query) {
errutil.VildateErrorResponse(w, r, v.Error)
Expand All @@ -203,7 +206,7 @@ func (h *Hlr) GetAccInfo(w http.ResponseWriter, r *http.Request) {

user := ctxutil.GetUser(r)
ctx := r.Context()
info, err := h.transaction.GetAccInfo(ctx, query, *user)
info, err := h.transaction.GetAccInfo(ctx, *user, query, timeRangeType)
if err != nil {
errutil.ServerErrorResponse(w, r, err)
return
Expand Down
3 changes: 3 additions & 0 deletions internal/usecase/interfaces/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ type TransactionRepo interface {
type MonthlyTransRepo interface {
// Create inserts new monthly transactions into the database.
Create(ctx context.Context, date time.Time, trans []domain.MonthlyAggregatedData) error

// GetByUserIDAndMonthDate returns monthly aggregated data by user id and month date.
GetByUserIDAndMonthDate(ctx context.Context, userID int64, monthDate time.Time) (domain.AccInfo, error)
}

// RedisService is the interface that wraps the basic methods for redis service.
Expand Down
54 changes: 42 additions & 12 deletions internal/usecase/transaction/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package transaction

import (
"context"
"errors"
"time"

"github.com/eyo-chen/expense-tracker-go/internal/domain"
Expand All @@ -14,25 +15,34 @@ const (
PackageName = "usecase/transaction"
)

var (
now = func() time.Time {
return time.Now()
}
)

type UC struct {
Transaction interfaces.TransactionRepo
MainCateg interfaces.MainCategRepo
SubCateg interfaces.SubCategRepo
Redis interfaces.RedisService
S3 interfaces.S3Service
Transaction interfaces.TransactionRepo
MainCateg interfaces.MainCategRepo
SubCateg interfaces.SubCategRepo
MonthlyTrans interfaces.MonthlyTransRepo
Redis interfaces.RedisService
S3 interfaces.S3Service
}

func New(t interfaces.TransactionRepo,
m interfaces.MainCategRepo,
s interfaces.SubCategRepo,
mt interfaces.MonthlyTransRepo,
r interfaces.RedisService,
s3 interfaces.S3Service) *UC {
return &UC{
Transaction: t,
MainCateg: m,
SubCateg: s,
Redis: r,
S3: s3,
Transaction: t,
MainCateg: m,
SubCateg: s,
MonthlyTrans: mt,
Redis: r,
S3: s3,
}
}

Expand Down Expand Up @@ -174,8 +184,28 @@ func (u *UC) Delete(ctx context.Context, id int64, user domain.User) error {
return u.Transaction.Delete(ctx, id)
}

func (u *UC) GetAccInfo(ctx context.Context, query domain.GetAccInfoQuery, user domain.User) (domain.AccInfo, error) {
return u.Transaction.GetAccInfo(ctx, query, user.ID)
func (u *UC) GetAccInfo(ctx context.Context, user domain.User, query domain.GetAccInfoQuery, timeRange domain.TimeRangeType) (domain.AccInfo, error) {
if timeRange != domain.TimeRangeTypeOneMonth ||
query.StartDate == nil || !domain.IsSameMonth(now().Format(time.DateOnly), *query.StartDate) {
return u.Transaction.GetAccInfo(ctx, query, user.ID)
}

t, err := time.Parse(time.DateOnly, *query.StartDate)
if err != nil {
logger.Error("time.Parse failed", "package", PackageName, "err", err)
return domain.AccInfo{}, err
}

accInfo, err := u.MonthlyTrans.GetByUserIDAndMonthDate(ctx, user.ID, t)
if errors.Is(err, domain.ErrDataNotFound) {
return u.Transaction.GetAccInfo(ctx, query, user.ID)
}
if err != nil {
return domain.AccInfo{}, err
}

return accInfo, nil

}

func (u *UC) GetBarChartData(ctx context.Context, chartDateRange domain.ChartDateRange, timeRangeType domain.TimeRangeType, transactionType domain.TransactionType, mainCategIDs []int64, user domain.User) (domain.ChartData, error) {
Expand Down
Loading

0 comments on commit 927e71d

Please sign in to comment.