Added ANSI_QUOTES mode
sananguliyev committed Dec 4, 2020
1 parent 24fd111 commit 4b9ef3c
Showing 7 changed files with 117 additions and 26 deletions.
28 changes: 28 additions & 0 deletions README.md
@@ -68,6 +68,34 @@ func main() {
}
```

+Parsing with the `ANSI_QUOTES` SQL mode:
+
+With `ANSI_QUOTES` enabled, `"` is treated as an identifier quote character (like the \` quote character) rather than as a string quote character. You can still use \` to quote identifiers in this mode, but you can no longer use double quotation marks to quote literal strings, since they are interpreted as identifiers.
+
+```go
+package main
+
+import (
+	"github.com/SananGuliyev/sqlparser"
+)
+
+func main() {
+	sql := `SELECT * FROM "table" WHERE a = 'abc'`
+	sqlparser.SQLMode = sqlparser.SQLModeANSIQuotes
+	stmt, err := sqlparser.Parse(sql)
+	if err != nil {
+		// Do something with the err
+	}
+
+	// Otherwise do something with stmt
+	switch stmt := stmt.(type) {
+	case *sqlparser.Select:
+		_ = stmt
+	case *sqlparser.Insert:
+	}
+}
+```
+
See [parse_test.go](https://github.com/SananGuliyev/sqlparser/blob/master/parse_test.go) for more examples, or read the [godoc](https://godoc.org/github.com/SananGuliyev/sqlparser).


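For a quick look at what the new mode actually changes, here is a minimal sketch contrasting the two modes on the same statement. It is not part of this commit; it assumes only the exported `SQLMode`, `Parse`, and `String` names visible in this diff, and the expected output in the comments follows the test cases added below.

```go
package main

import (
	"fmt"

	"github.com/SananGuliyev/sqlparser"
)

func main() {
	sql := `SELECT "a" FROM t`

	// Default mode: "a" is scanned as a string literal.
	sqlparser.SQLMode = sqlparser.SQLModeStandard
	stmt, err := sqlparser.Parse(sql)
	if err != nil {
		panic(err)
	}
	fmt.Println(sqlparser.String(stmt)) // select 'a' from t

	// ANSI_QUOTES mode: the same "a" is scanned as an identifier.
	sqlparser.SQLMode = sqlparser.SQLModeANSIQuotes
	stmt, err = sqlparser.Parse(sql)
	if err != nil {
		panic(err)
	}
	fmt.Println(sqlparser.String(stmt)) // select a from t
}
```

Note that `SQLMode` is a package-level variable and the tokenizer captures it at construction time, so flipping it affects every subsequent `Parse` call in the process and is not safe to do concurrently with parsing.
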
2 changes: 1 addition & 1 deletion analyzer.go
@@ -300,7 +300,7 @@ func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope strin
if setStmt.Scope != "" && scope != "" {
return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
}
-		_, out := NewStringTokenizer(key).Scan()
+		_, out := NewStringTokenizer(key, SQLMode).Scan()
key = string(out)
}

31 changes: 22 additions & 9 deletions ast.go
@@ -29,6 +29,13 @@ import (
"github.com/SananGuliyev/sqlparser/dependency/sqltypes"
)

+const (
+	SQLModeStandard = iota
+	SQLModeANSIQuotes
+)
+
+var SQLMode = SQLModeStandard
+
// Instructions for creating new types: If a type
// needs to satisfy an interface, declare that function
// along with that interface. This will help users
@@ -46,7 +53,7 @@ import (
// is partially parsed but still contains a syntax error, the
// error is ignored and the DDL is returned anyway.
func Parse(sql string) (Statement, error) {
-	tokenizer := NewStringTokenizer(sql)
+	tokenizer := NewStringTokenizer(sql, SQLMode)
if yyParse(tokenizer) != 0 {
if tokenizer.partialDDL != nil {
log.Printf("ignoring error parsing DDL '%s': %v", sql, tokenizer.LastError)
@@ -61,7 +68,7 @@ func Parse(sql string) (Statement, error) {
// ParseStrictDDL is the same as Parse except it errors on
// partially parsed DDL statements.
func ParseStrictDDL(sql string) (Statement, error) {
-	tokenizer := NewStringTokenizer(sql)
+	tokenizer := NewStringTokenizer(sql, SQLMode)
if yyParse(tokenizer) != 0 {
return nil, tokenizer.LastError
}
@@ -97,7 +104,7 @@ func ParseNext(tokenizer *Tokenizer) (Statement, error) {
// SplitStatement returns the first sql statement up to either a ; or EOF
// and the remainder from the given buffer
func SplitStatement(blob string) (string, string, error) {
-	tokenizer := NewStringTokenizer(blob)
+	tokenizer := NewStringTokenizer(blob, SQLMode)
tkn := 0
for {
tkn, _ = tokenizer.Scan()
@@ -118,7 +125,7 @@ func SplitStatement(blob string) (string, string, error) {
// returns the sql pieces blob contains; or error if sql cannot be parsed
func SplitStatementToPieces(blob string) (pieces []string, err error) {
pieces = make([]string, 0, 16)
-	tokenizer := NewStringTokenizer(blob)
+	tokenizer := NewStringTokenizer(blob, SQLMode)

tkn := 0
var stmt string
@@ -3430,6 +3437,12 @@ func Backtick(in string) string {
}

func formatID(buf *TrackedBuffer, original, lowered string) {
+	var identChar rune
+	if SQLMode == SQLModeANSIQuotes {
+		identChar = '"'
+	} else {
+		identChar = '`'
+	}
isDbSystemVariable := false
if len(original) > 1 && original[:2] == "@@" {
isDbSystemVariable = true
@@ -3449,14 +3462,14 @@ func formatID(buf *TrackedBuffer, original, lowered string) {
return

mustEscape:
-	buf.WriteByte('`')
+	_, _ = buf.WriteRune(identChar)
for _, c := range original {
-		buf.WriteRune(c)
-		if c == '`' {
-			buf.WriteByte('`')
+		_, _ = buf.WriteRune(c)
+		if c == identChar {
+			_, _ = buf.WriteRune(identChar)
}
}
-	buf.WriteByte('`')
+	_, _ = buf.WriteRune(identChar)
}

func compliantName(in string) string {
5 changes: 2 additions & 3 deletions parse_next_test.go
@@ -67,7 +67,7 @@ func TestParseNextErrors(t *testing.T) {
}

sql := tcase.input + "; select 1 from t"
-		tokens := NewStringTokenizer(sql)
+		tokens := NewStringTokenizer(sql, SQLMode)

// The first statement should be an error
_, err := ParseNext(tokens)
@@ -136,13 +136,12 @@ func TestParseNextEdgeCases(t *testing.T) {
}}

for _, test := range tests {
-		tokens := NewStringTokenizer(test.input)
+		tokens := NewStringTokenizer(test.input, SQLMode)

for i, want := range test.want {
tree, err := ParseNext(tokens)
if err != nil {
t.Fatalf("[%d] ParseNext(%q) err = %q, want nil", i, test.input, err)
-				continue
}

if got := String(tree); got != want {
33 changes: 28 additions & 5 deletions token.go
@@ -44,6 +44,7 @@ type Tokenizer struct {
posVarIndex int
ParseTree Statement
partialDDL *DDL
+	sqlMode int
nesting int
multi bool
specialComment *Tokenizer
@@ -55,11 +56,12 @@

// NewStringTokenizer creates a new Tokenizer for the
// sql string.
-func NewStringTokenizer(sql string) *Tokenizer {
+func NewStringTokenizer(sql string, sqlMode int) *Tokenizer {
buf := []byte(sql)
return &Tokenizer{
buf: buf,
bufSize: len(buf),
+		sqlMode: sqlMode,
}
}

@@ -595,7 +597,12 @@ func (tkn *Tokenizer) Scan() (int, []byte) {
return NE, nil
}
return int(ch), nil
-	case '\'', '"':
+	case '\'':
+		return tkn.scanString(ch, STRING)
+	case '"':
+		if tkn.sqlMode == SQLModeANSIQuotes {
+			return tkn.scanLiteralIdentifier()
+		}
return tkn.scanString(ch, STRING)
case '`':
return tkn.scanLiteralIdentifier()
@@ -667,25 +674,41 @@ func (tkn *Tokenizer) scanBitLiteral() (int, []byte) {
func (tkn *Tokenizer) scanLiteralIdentifier() (int, []byte) {
buffer := &bytes2.Buffer{}
backTickSeen := false
+	quoteSeen := false
for {
if backTickSeen {
if tkn.lastChar != '`' {
break
}
backTickSeen = false
-			buffer.WriteByte('`')
+			_ = buffer.WriteByte('`')
tkn.next()
continue
}
+		if quoteSeen {
+			if tkn.lastChar != '"' {
+				break
+			}
+			quoteSeen = false
+			_ = buffer.WriteByte('"')
+			tkn.next()
+			continue
+		}
		// The previous char was not an identifier quote character.
switch tkn.lastChar {
case '`':
backTickSeen = true
case '"':
if tkn.sqlMode == SQLModeANSIQuotes {
quoteSeen = true
} else {
_ = buffer.WriteByte(byte(tkn.lastChar))
}
case eofChar:
// Premature EOF.
return LEX_ERROR, buffer.Bytes()
default:
-			buffer.WriteByte(byte(tkn.lastChar))
+			_ = buffer.WriteByte(byte(tkn.lastChar))
}
tkn.next()
}
@@ -880,7 +903,7 @@ func (tkn *Tokenizer) scanMySQLSpecificComment() (int, []byte) {
tkn.consumeNext(buffer)
}
_, sql := ExtractMysqlComment(buffer.String())
-	tkn.specialComment = NewStringTokenizer(sql)
+	tkn.specialComment = NewStringTokenizer(sql, SQLMode)
return tkn.Scan()
}

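One detail of `scanLiteralIdentifier` worth spelling out: in `ANSI_QUOTES` mode a doubled `""` inside a double-quoted identifier is collapsed to a single `"` while scanning, and `formatID` doubles it again on output, so such identifiers round-trip. A minimal sketch of that behavior (not part of this commit; it assumes the same exported API as the sketch above):

```go
package main

import (
	"fmt"

	"github.com/SananGuliyev/sqlparser"
)

func main() {
	sqlparser.SQLMode = sqlparser.SQLModeANSIQuotes

	// "weird""name" carries an escaped double quote: the scanner
	// stores the identifier as weird"name, and formatID re-escapes
	// it when the statement is serialized back to SQL.
	stmt, err := sqlparser.Parse(`select * from "weird""name"`)
	if err != nil {
		panic(err)
	}
	fmt.Println(sqlparser.String(stmt)) // select * from "weird""name"
}
```
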
32 changes: 30 additions & 2 deletions token_test.go
@@ -57,7 +57,7 @@ func TestLiteralID(t *testing.T) {
}}

for _, tcase := range testcases {
-		tkn := NewStringTokenizer(tcase.in)
+		tkn := NewStringTokenizer(tcase.in, SQLMode)
id, out := tkn.Scan()
if tcase.id != id || string(out) != tcase.out {
t.Errorf("Scan(%s): %d, %s, want %d, %s", tcase.in, id, out, tcase.id, tcase.out)
@@ -130,7 +130,7 @@ func TestString(t *testing.T) {
}}

for _, tcase := range testcases {
-		id, got := NewStringTokenizer(tcase.in).Scan()
+		id, got := NewStringTokenizer(tcase.in, SQLMode).Scan()
if tcase.id != id || string(got) != tcase.want {
t.Errorf("Scan(%q) = (%s, %q), want (%s, %q)", tcase.in, tokenName(id), got, tokenName(tcase.id), tcase.want)
}
@@ -189,3 +189,31 @@ func TestSplitStatement(t *testing.T) {
}
}
}

+func TestParseANSIQuotesMode(t *testing.T) {
+	testcases := []struct {
+		in  string
+		out string
+	}{{
+		in:  `select * from "table"`,
+		out: `select * from "table"`,
+	}, {
+		in:  `select * from "tbl"`,
+		out: `select * from tbl`,
+	}}
+
+	SQLMode = SQLModeANSIQuotes
+	for _, tcase := range testcases {
+		stmt, err := Parse(tcase.in)
+		if err != nil {
+			t.Errorf("Parse(%s): ERROR: %v", tcase.in, err)
+			continue
+		}
+
+		finalSQL := String(stmt)
+		if tcase.out != finalSQL {
+			t.Errorf("Parse(%s) got sql %q, want %q", tcase.in, finalSQL, tcase.out)
+		}
+	}
+	SQLMode = SQLModeStandard
+}
12 changes: 6 additions & 6 deletions tracked_buffer.go
@@ -68,7 +68,7 @@ func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
i++
}
if i > lasti {
-			buf.WriteString(format[lasti:i])
+			_, _ = buf.WriteString(format[lasti:i])
}
if i >= end {
break
@@ -78,18 +78,18 @@ func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
case 'c':
switch v := values[fieldnum].(type) {
case byte:
-				buf.WriteByte(v)
+				_ = buf.WriteByte(v)
case rune:
-				buf.WriteRune(v)
+				_, _ = buf.WriteRune(v)
default:
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
}
case 's':
switch v := values[fieldnum].(type) {
case []byte:
-				buf.Write(v)
+				_, _ = buf.Write(v)
case string:
-				buf.WriteString(v)
+				_, _ = buf.WriteString(v)
default:
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
}
@@ -118,7 +118,7 @@ func (buf *TrackedBuffer) WriteArg(arg string) {
offset: buf.Len(),
length: len(arg),
})
-	buf.WriteString(arg)
+	_, _ = buf.WriteString(arg)
}

// ParsedQuery returns a ParsedQuery that contains bind
Expand Down
