diff --git a/go/vt/sqlparser/tracked_buffer.go b/go/vt/sqlparser/tracked_buffer.go index aec206f3b3d..48efe9547af 100644 --- a/go/vt/sqlparser/tracked_buffer.go +++ b/go/vt/sqlparser/tracked_buffer.go @@ -18,6 +18,7 @@ package sqlparser import ( "fmt" + "strconv" "strings" ) @@ -211,7 +212,32 @@ func (buf *TrackedBuffer) astPrintf(currentNode SQLNode, format string, values . } } case 'd': - buf.WriteString(fmt.Sprintf("%d", values[fieldnum])) + switch v := values[fieldnum].(type) { + case int: + buf.WriteInt(int64(v)) + case int8: + buf.WriteInt(int64(v)) + case int16: + buf.WriteInt(int64(v)) + case int32: + buf.WriteInt(int64(v)) + case int64: + buf.WriteInt(v) + case uint: + buf.WriteUint(uint64(v)) + case uint8: + buf.WriteUint(uint64(v)) + case uint16: + buf.WriteUint(uint64(v)) + case uint32: + buf.WriteUint(uint64(v)) + case uint64: + buf.WriteUint(v) + case uintptr: + buf.WriteUint(uint64(v)) + default: + panic(fmt.Sprintf("unexepcted TrackedBuffer type %T", v)) + } case 'a': buf.WriteArg("", values[fieldnum].(string)) default: @@ -288,14 +314,26 @@ func areBothISExpr(op Expr, val Expr) bool { // WriteArg writes a value argument into the buffer along with // tracking information for future substitutions. func (buf *TrackedBuffer) WriteArg(prefix, arg string) { + length := len(prefix) + len(arg) buf.bindLocations = append(buf.bindLocations, BindLocation{ Offset: buf.Len(), - Length: len(prefix) + len(arg), + Length: length, }) + buf.Grow(length) buf.WriteString(prefix) buf.WriteString(arg) } +// WriteInt writes a signed integer into the buffer. +func (buf *TrackedBuffer) WriteInt(v int64) { + buf.WriteString(strconv.FormatInt(v, 10)) +} + +// WriteUint writes an unsigned integer into the buffer. +func (buf *TrackedBuffer) WriteUint(v uint64) { + buf.WriteString(strconv.FormatUint(v, 10)) +} + // ParsedQuery returns a ParsedQuery that contains bind // locations for easy substitution. func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery { @@ -335,7 +373,6 @@ func UnescapedString(node SQLNode) string { buf.SetEscapeNoIdentifier() node.Format(buf) return buf.String() - } // CanonicalString returns a canonical string representation of an SQLNode where all identifiers diff --git a/go/vt/sqlparser/tracked_buffer_test.go b/go/vt/sqlparser/tracked_buffer_test.go index 4dff65634e8..96f2174481e 100644 --- a/go/vt/sqlparser/tracked_buffer_test.go +++ b/go/vt/sqlparser/tracked_buffer_test.go @@ -295,3 +295,35 @@ func TestCanonicalOutput(t *testing.T) { }) } } + +func TestTrackedBufferMyprintf(t *testing.T) { + testcases := []struct { + input string + output string + args []any + }{ + { + input: "nothing", + output: "nothing", + args: []any{}, + }, + { + input: "my name is %s", + output: "my name is Homer", + args: []any{"Homer"}, + }, + { + input: "%d %d %d %d %d %d %d %d %d %d %d", + output: "1 2 3 4 5 6 7 8 9 10 11", + args: []any{int(1), int8(2), int16(3), int32(4), int64(5), uint(6), uint8(7), uint16(8), uint32(9), uint64(10), uintptr(11)}, + }, + } + for _, tc := range testcases { + t.Run(tc.input, func(t *testing.T) { + buf := NewTrackedBuffer(nil) + buf.Myprintf(tc.input, tc.args...) + got := buf.String() + assert.Equal(t, tc.output, got) + }) + } +}