Skip to content

Commit

Permalink
evalengine: Implement REPLACE (#15274)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
dbussink authored Feb 19, 2024
1 parent 9a78e7d commit d131230
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 22 deletions.
12 changes: 12 additions & 0 deletions go/vt/vtgate/evalengine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3022,6 +3022,19 @@ func (asm *assembler) Locate2(collation colldata.Collation) {
}, "LOCATE VARCHAR(SP-2), VARCHAR(SP-1) COLLATE '%s'", collation.Name())
}

// Replace emits the VM instruction for a three-operand REPLACE. The emitted
// closure type-asserts the top three stack slots to *evalBytes (subject,
// search string, replacement, from deepest to topmost), drops the top two
// slots, and stores the result of replace() into the subject's bytes, so the
// result is left in the slot that previously held the subject.
func (asm *assembler) Replace() {
	asm.adjustStack(-2)

	asm.emit(func(env *ExpressionEnv) int {
		top := env.vm.sp
		subject := env.vm.stack[top-3].(*evalBytes)
		search := env.vm.stack[top-2].(*evalBytes)
		replacement := env.vm.stack[top-1].(*evalBytes)

		// Two operands are consumed; the subject slot becomes the result.
		env.vm.sp = top - 2
		subject.bytes = replace(subject.bytes, search.bytes, replacement.bytes)
		return 1
	}, "REPLACE VARCHAR(SP-3), VARCHAR(SP-2) VARCHAR(SP-1)")
}

func (asm *assembler) Strcmp(collation collations.TypedCollation) {
asm.adjustStack(-1)

Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,10 @@ func TestCompilerSingle(t *testing.T) {
expression: `locate("", "😊😂🤢", 3)`,
result: `INT64(3)`,
},
{
expression: `REPLACE('www.mysql.com', '', 'Ww')`,
result: `VARCHAR("www.mysql.com")`,
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down
171 changes: 149 additions & 22 deletions go/vt/vtgate/evalengine/fn_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,31 @@ type (
CallExpr
collate collations.ID
}

// builtinChar is a call-expression node; its eval and compile methods are
// defined further down in this file.
// NOTE(review): presumably backs the SQL CHAR() function — confirm against
// the translation code.
builtinChar struct {
CallExpr
// collate is a collation ID associated with this call.
// NOTE(review): presumably the configured default collation — confirm.
collate collations.ID
}

// builtinRepeat is a call-expression node; its eval method (which takes two
// arguments via call.arg2) is defined further down in this file.
// NOTE(review): presumably backs the SQL REPEAT() function — confirm against
// the translation code.
builtinRepeat struct {
CallExpr
// collate is a collation ID associated with this call.
// NOTE(review): presumably the configured default collation — confirm.
collate collations.ID
}

// builtinConcat is a call-expression node; its compile method is defined
// further down in this file.
// NOTE(review): presumably backs the SQL CONCAT() function — confirm against
// the translation code.
builtinConcat struct {
CallExpr
// collate is a collation ID associated with this call.
// NOTE(review): presumably the configured default collation — confirm.
collate collations.ID
}

// builtinConcatWs is a call-expression node; its eval and compile methods are
// defined further down in this file.
// NOTE(review): presumably backs the SQL CONCAT_WS() function — confirm
// against the translation code.
builtinConcatWs struct {
CallExpr
// collate is a collation ID associated with this call.
// NOTE(review): presumably the configured default collation — confirm.
collate collations.ID
}

// builtinReplace is the node created by translateFuncExpr for a three-argument
// "replace" function call.
builtinReplace struct {
CallExpr
// collate is filled from the translation config's Collation; eval uses it
// when the first argument is not already an *evalBytes and has to be
// converted to a varchar.
collate collations.ID
}
)

var _ IR = (*builtinInsert)(nil)
Expand All @@ -129,7 +154,15 @@ var _ IR = (*builtinCollation)(nil)
// Compile-time assertions that each of these builtin node types satisfies
// the IR interface.
var _ IR = (*builtinWeightString)(nil)
var _ IR = (*builtinLeftRight)(nil)
var _ IR = (*builtinPad)(nil)
var _ IR = (*builtinStrcmp)(nil)
var _ IR = (*builtinTrim)(nil)
var _ IR = (*builtinSubstring)(nil)
var _ IR = (*builtinLocate)(nil)
var _ IR = (*builtinChar)(nil)
var _ IR = (*builtinRepeat)(nil)
var _ IR = (*builtinConcat)(nil)
var _ IR = (*builtinConcatWs)(nil)
var _ IR = (*builtinReplace)(nil)

func insert(str, newstr *evalBytes, pos, l int) []byte {
pos--
Expand Down Expand Up @@ -555,11 +588,6 @@ func (call *builtinOrd) compile(c *compiler) (ctype, error) {
// - `> max_allowed_packet`, no error and returns `NULL`.
const maxRepeatLength = 1073741824

type builtinRepeat struct {
CallExpr
collate collations.ID
}

func (call *builtinRepeat) eval(env *ExpressionEnv) (eval, error) {
arg1, arg2, err := call.arg2(env)
if err != nil {
Expand Down Expand Up @@ -1374,11 +1402,6 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) {
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}, nil
}

type builtinConcat struct {
CallExpr
collate collations.ID
}

func concatSQLType(arg sqltypes.Type, tt sqltypes.Type) sqltypes.Type {
if arg == sqltypes.TypeJSON {
return sqltypes.Blob
Expand Down Expand Up @@ -1507,11 +1530,6 @@ func (call *builtinConcat) compile(c *compiler) (ctype, error) {
return ctype{Type: tt, Flag: f, Col: tc}, nil
}

type builtinConcatWs struct {
CallExpr
collate collations.ID
}

func (call *builtinConcatWs) eval(env *ExpressionEnv) (eval, error) {
var ca collationAggregation
tt := sqltypes.VarChar
Expand Down Expand Up @@ -1643,13 +1661,6 @@ func (call *builtinConcatWs) compile(c *compiler) (ctype, error) {
return ctype{Type: tt, Flag: args[0].Flag, Col: tc}, nil
}

type builtinChar struct {
CallExpr
collate collations.ID
}

var _ IR = (*builtinChar)(nil)

func (call *builtinChar) eval(env *ExpressionEnv) (eval, error) {
vals := make([]eval, 0, len(call.Arguments))
for _, arg := range call.Arguments {
Expand Down Expand Up @@ -1726,3 +1737,119 @@ func encodeChar(buf []byte, i uint32) []byte {
}
return buf
}

// eval evaluates the three arguments in order (subject, search string,
// replacement). It returns a nil value as soon as one argument evaluates to
// nil, and returns the error as soon as one fails. A subject that is not
// already an *evalBytes is converted to a varchar using call.collate; the
// search string and replacement are then converted to the subject's
// collation. The result carries the subject's SQL type and collation.
func (call *builtinReplace) eval(env *ExpressionEnv) (eval, error) {
	var args [3]eval
	for i := range args {
		arg, err := call.Arguments[i].eval(env)
		if err != nil || arg == nil {
			return nil, err
		}
		args[i] = arg
	}

	if _, isBytes := args[0].(*evalBytes); !isBytes {
		converted, err := evalToVarchar(args[0], call.collate, true)
		if err != nil {
			return nil, err
		}
		args[0] = converted
	}

	subject := args[0]
	col := subject.(*evalBytes).col

	// Bring the search string and the replacement into the subject's
	// collation, in that order.
	for i := 1; i < len(args); i++ {
		converted, err := evalToVarchar(args[i], col.Collation, true)
		if err != nil {
			return nil, err
		}
		args[i] = converted
	}

	out := replace(
		subject.(*evalBytes).bytes,
		args[1].(*evalBytes).bytes,
		args[2].(*evalBytes).bytes,
	)
	return newEvalRaw(subject.SQLType(), out, col), nil
}

// compile compiles the three arguments in order, emits a null check over all
// of them, and then emits conversions so that the operands reach the Replace
// instruction as varchars: a non-textual subject is converted to the
// compiler's collation, and the search string / replacement are converted to
// the subject's collation when they are non-textual or when their charset is
// neither equal to nor contained in the subject's charset. The returned type
// is a nullable VarChar with the subject's collation.
func (call *builtinReplace) compile(c *compiler) (ctype, error) {
	var args [3]ctype
	for i := range args {
		compiled, err := call.Arguments[i].compile(c)
		if err != nil {
			return ctype{}, err
		}
		args[i] = compiled
	}
	str, fromStr, toStr := args[0], args[1], args[2]

	skip := c.compileNullCheck3(str, fromStr, toStr)

	// Stack offset 3 is the subject, 2 the search string, 1 the replacement.
	if !str.isTextual() {
		c.asm.Convert_xce(3, sqltypes.VarChar, c.collation)
		str.Col = collations.TypedCollation{
			Collation:    c.collation,
			Coercibility: collations.CoerceCoercible,
			Repertoire:   collations.RepertoireASCII,
		}
	}

	// Collation assigned to an operand once it has been converted to the
	// subject's collation.
	converted := collations.TypedCollation{
		Collation:    str.Col.Collation,
		Coercibility: collations.CoerceCoercible,
		Repertoire:   collations.RepertoireASCII,
	}

	fromCharset := colldata.Lookup(fromStr.Col.Collation).Charset()
	toCharset := colldata.Lookup(toStr.Col.Collation).Charset()
	strCharset := colldata.Lookup(str.Col.Collation).Charset()

	if !(fromStr.isTextual() && (fromCharset == strCharset || strCharset.IsSuperset(fromCharset))) {
		c.asm.Convert_xce(2, sqltypes.VarChar, str.Col.Collation)
		fromStr.Col = converted
	}

	if !(toStr.isTextual() && (toCharset == strCharset || strCharset.IsSuperset(toCharset))) {
		c.asm.Convert_xce(1, sqltypes.VarChar, str.Col.Collation)
		toStr.Col = converted
	}

	c.asm.Replace()
	c.asm.jumpDestination(skip)
	return ctype{Type: sqltypes.VarChar, Col: str.Col, Flag: flagNullable}, nil
}

// replace returns str with every non-overlapping occurrence of from, scanned
// left to right, substituted by to. Matching is purely byte-wise. When from is
// empty or does not occur in str, str itself is returned (not a copy);
// otherwise a newly allocated slice of exactly the required size is returned.
func replace(str, from, to []byte) []byte {
	if len(from) == 0 {
		return str
	}
	matches := bytes.Count(str, from)
	if matches == 0 {
		return str
	}

	// Each match changes the length by len(to)-len(from), so the final size
	// is known up front and the appends below never reallocate.
	result := make([]byte, 0, len(str)+matches*(len(to)-len(from)))
	rest := str
	for ; matches > 0; matches-- {
		idx := bytes.Index(rest, from)
		result = append(result, rest[:idx]...)
		result = append(result, to...)
		rest = rest[idx+len(from):]
	}
	return append(result, rest...)
}
28 changes: 28 additions & 0 deletions go/vt/vtgate/evalengine/testcases/cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ var Cases = []TestCase{
{Run: FnTrim},
{Run: FnSubstr},
{Run: FnLocate},
{Run: FnReplace},
{Run: FnConcat},
{Run: FnConcatWs},
{Run: FnChar},
Expand Down Expand Up @@ -1577,6 +1578,33 @@ func FnLocate(yield Query) {
}
}

// FnReplace yields the REPLACE() test queries: a handful of hand-written
// cases first, followed by REPLACE applied to every ordered triple drawn
// from inputStrings.
func FnReplace(yield Query) {
	fixed := []string{
		`REPLACE('www.mysql.com', 'w', 'Ww')`,
		// MySQL matches on byte equivalence for REPLACE rather than on
		// collation equivalence; these two cases verify that.
		`REPLACE('straße', 'ss', 'b')`,
		`REPLACE('straße', 'ß', 'b')`,
		// The search and replacement strings are converted into the
		// collation of the input string.
		`REPLACE('fooÿbar', _latin1 0xFF, _latin1 0xFE)`,
		// With overlapping candidates, the first occurrence is the one
		// that gets replaced.
		`replace('fff', 'ff', 'gg')`,
	}
	for _, query := range fixed {
		yield(query, nil)
	}

	for _, subject := range inputStrings {
		for _, search := range inputStrings {
			for _, replacement := range inputStrings {
				yield(fmt.Sprintf("REPLACE(%s, %s, %s)", subject, search, replacement), nil)
			}
		}
	}
}

func FnConcat(yield Query) {
for _, str := range inputStrings {
yield(fmt.Sprintf("CONCAT(%s)", str), nil)
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/evalengine/translate_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,11 @@ func (ast *astCompiler) translateFuncExpr(fn *sqlparser.FuncExpr) (IR, error) {
}
call = CallExpr{Arguments: []IR{call.Arguments[1], call.Arguments[0]}, Method: method}
return &builtinLocate{CallExpr: call, collate: ast.cfg.Collation}, nil
case "replace":
if len(args) != 3 {
return nil, argError(method)
}
return &builtinReplace{CallExpr: call, collate: ast.cfg.Collation}, nil
default:
return nil, translateExprNotSupported(fn)
}
Expand Down

0 comments on commit d131230

Please sign in to comment.