Skip to content

Commit

Permalink
fix: fix Sql injection problem
Browse files Browse the repository at this point in the history
  • Loading branch information
afumu committed Jan 16, 2023
1 parent 11546ce commit 259d3bb
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
25 changes: 24 additions & 1 deletion example/base/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,20 @@ import (
"gorm.io/gorm"
"log"
"testing"
"time"
)

type Test2 struct {
TestId string `gorm:"primaryKey"`
Code string
Price string
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time
}

func TestSelectById(t *testing.T) {
user, resultDb := gplus.SelectById[User](1)
user, resultDb := gplus.SelectById[User]("or 1=1")
if resultDb.Error != nil {
if errors.Is(resultDb.Error, gorm.ErrRecordNotFound) {
log.Fatalln("SelectById Data not found:", resultDb.Error)
Expand All @@ -39,6 +49,19 @@ func TestSelectById(t *testing.T) {
log.Println(string(marshal))
}

func TestSelectByStrId(t *testing.T) {
test, resultDb := gplus.SelectById[Test2]("a = 1 or 1=1")
if resultDb.Error != nil {
if errors.Is(resultDb.Error, gorm.ErrRecordNotFound) {
log.Fatalln("SelectById Data not found:", resultDb.Error)
}
log.Fatalln("SelectById error:", resultDb.Error)
}
log.Println("RowsAffected:", resultDb.RowsAffected)
marshal, _ := json.Marshal(test)
log.Println(string(marshal))
}

func TestSelectByIds(t *testing.T) {
var ids []int
ids = append(ids, 1)
Expand Down
15 changes: 5 additions & 10 deletions gplus/base_dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,11 @@ func Update[T any](q *Query[T]) *gorm.DB {
}

func SelectById[T any](id any) (*T, *gorm.DB) {
q := NewQuery[T]()
q.Eq(getPKColumn[T](), id)
var entity T
resultDb := gormDb.Take(&entity, id)
if resultDb.RowsAffected == 0 {
return nil, resultDb
}
return &entity, resultDb
resultDb := buildCondition(q)
return &entity, resultDb.Limit(1).Find(&entity)
}

func SelectByIds[T any](ids any) ([]*T, *gorm.DB) {
Expand All @@ -114,11 +113,7 @@ func SelectByIds[T any](ids any) ([]*T, *gorm.DB) {
func SelectOne[T any](q *Query[T]) (*T, *gorm.DB) {
var entity T
resultDb := buildCondition(q)
resultDb.Take(&entity)
if resultDb.RowsAffected == 0 {
return nil, resultDb
}
return &entity, resultDb
return &entity, resultDb.Limit(1).Find(&entity)
}

func SelectList[T any](q *Query[T]) ([]*T, *gorm.DB) {
Expand Down

0 comments on commit 259d3bb

Please sign in to comment.