Skip to content

Commit

Permalink
Merge pull request #1898 from winfredLIN/issue1888
Browse files Browse the repository at this point in the history
modified: correct spelling of word column
  • Loading branch information
ColdWaterLW authored Oct 9, 2023
2 parents a7e2523 + 4f8d7ab commit 0bbd551
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
16 changes: 8 additions & 8 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -2687,27 +2687,27 @@ func getOnConditionLeftAndRightType(onCondition *ast.OnCondition, createTableStm
leftType = node.Tp.Tp
default:
// 默认获取子树的所有列 对应等号一侧 一般连接键只会有一个 不支持多个列的组合
lVisitor := util.ColumeNameVisitor{}
lVisitor := util.ColumnNameVisitor{}
binaryOperation.L.Accept(&lVisitor)
if len(lVisitor.ColumeNameList) > 1 {
if len(lVisitor.ColumnNameList) > 1 {
log.Logger().Warn("规则:建议JOIN字段类型保持一致,连接键不支持多个列的组合")
}
if len(lVisitor.ColumeNameList) == 1 {
leftType = getColumnType(lVisitor.ColumeNameList[0], createTableStmtMap)
if len(lVisitor.ColumnNameList) == 1 {
leftType = getColumnType(lVisitor.ColumnNameList[0], createTableStmtMap)
}
}

switch node := binaryOperation.R.(type) {
case *ast.FuncCastExpr:
rightType = node.Tp.Tp
default:
rVisitor := util.ColumeNameVisitor{}
rVisitor := util.ColumnNameVisitor{}
binaryOperation.R.Accept(&rVisitor)
if len(rVisitor.ColumeNameList) > 1 {
if len(rVisitor.ColumnNameList) > 1 {
log.Logger().Warn("规则:建议JOIN字段类型保持一致,连接键不支持多个列的组合")
}
if len(rVisitor.ColumeNameList) > 0 {
rightType = getColumnType(rVisitor.ColumeNameList[0], createTableStmtMap)
if len(rVisitor.ColumnNameList) > 0 {
rightType = getColumnType(rVisitor.ColumnNameList[0], createTableStmtMap)
}
}

Expand Down
10 changes: 5 additions & 5 deletions sqle/driver/mysql/util/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,18 +228,18 @@ func (v *SelectVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}

type ColumeNameVisitor struct {
ColumeNameList []*ast.ColumnNameExpr
type ColumnNameVisitor struct {
ColumnNameList []*ast.ColumnNameExpr
}

func (v *ColumeNameVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
func (v *ColumnNameVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
switch stmt := in.(type) {
case *ast.ColumnNameExpr:
v.ColumeNameList = append(v.ColumeNameList, stmt)
v.ColumnNameList = append(v.ColumnNameList, stmt)
}
return in, false
}

func (v *ColumeNameVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
func (v *ColumnNameVisitor) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}
6 changes: 3 additions & 3 deletions sqle/driver/mysql/util/visitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestSelectFieldExtractor(t *testing.T) {
}
}

func TestColumeNameVisitor(t *testing.T) {
func TestColumnNameVisitor(t *testing.T) {
tests := []struct {
input string
columnCount uint
Expand All @@ -130,10 +130,10 @@ func TestColumeNameVisitor(t *testing.T) {
stmt, err := parser.New().ParseOneStmt(tt.input, "", "")
assert.NoError(t, err)

visitor := &ColumeNameVisitor{}
visitor := &ColumnNameVisitor{}
stmt.Accept(visitor)

assert.Equal(t, tt.columnCount, uint(len(visitor.ColumeNameList)))
assert.Equal(t, tt.columnCount, uint(len(visitor.ColumnNameList)))
})
}
}

0 comments on commit 0bbd551

Please sign in to comment.