diff --git a/builq.go b/builq.go index a457981..bc16773 100644 --- a/builq.go +++ b/builq.go @@ -140,17 +140,17 @@ var ( // errUnsupportedVerb when %X is found and X isn't supported. errUnsupportedVerb = errors.New("unsupported verb") - // errLonelyVerb when % is found without a verb. - errLonelyVerb = errors.New("lonely verb") + // errLonelyModifier when % is found without a verb. + errLonelyModifier = errors.New("lonely modifier without a verb") // errIncorrectVerb is passed like `%+`. errIncorrectVerb = errors.New("incorrect verb") // errMixedPlaceholders when $ AND ? are mixed in 1 query. - errMixedPlaceholders = errors.New("mixed placeholders must not be used in a single query") + errMixedPlaceholders = errors.New("mixed placeholders in a single query") // errNonSliceArgument when a non-slice argument passed to placeholder with `+` or `#`. - errNonSliceArgument = errors.New("non-slice arguments must not be used with slice modifiers") + errNonSliceArgument = errors.New("non-slice arguments with slice modifiers") // errNonNumericArg expected number for %d but got something else. errNonNumericArg = errors.New("expected numeric argument") diff --git a/builq_test.go b/builq_test.go index 11442ca..6dd2cb4 100644 --- a/builq_test.go +++ b/builq_test.go @@ -38,17 +38,18 @@ func TestBuilder(t *testing.T) { }) } - test("bad verb", errUnsupportedVerb, "SELECT * FROM %+ slice", 1) - test("bad verb", errUnsupportedVerb, "SELECT * FROM %+ slice", 1) + test("incorrect verb", errIncorrectVerb, "SELECT * FROM %+ slice", 1) + test("incorrect verb", errIncorrectVerb, "SELECT * FROM %+ slice", 1) test("incorrect verb 1", errIncorrectVerb, "SELECT * FROM %+", 1) test("incorrect verb 2", errIncorrectVerb, "SELECT * FROM %#", 1) - test("lonely verb", errLonelyVerb, "SELECT * FROM % super") - test("lonely verb", errLonelyVerb, "SELECT foo FROM bar%") + test("lonely verb", errLonelyModifier, "SELECT * FROM % super") + test("lonely verb", errLonelyModifier, "SELECT foo FROM bar%") test("too few arguments", errTooFewArguments, "SELECT foo FROM bar%d") test("too few arguments", errTooFewArguments, "SELECT * FROM %s") test("too few arguments", errTooFewArguments, "SELECT * FROM %$") test("too few arguments", errTooFewArguments, "SELECT * FROM %+?") test("too many arguments", errTooManyArguments, "SELECT * FROM %s", "users", "users") + test("unsupported verb", errUnsupportedVerb, "SELECT * FROM %+x") test("unsupported verb", errUnsupportedVerb, "SELECT * FROM %v", "users") test("mixed placeholders", errMixedPlaceholders, "WHERE foo = %$ AND bar = %?", 1, 2) test("non-slice argument", errNonSliceArgument, "WHERE foo = %+$", 1) @@ -79,7 +80,7 @@ func FuzzBuilder(f *testing.F) { if errors.Is(err, errTooFewArguments) || errors.Is(err, errTooManyArguments) || errors.Is(err, errUnsupportedVerb) || - errors.Is(err, errLonelyVerb) || + errors.Is(err, errLonelyModifier) || errors.Is(err, errIncorrectVerb) || errors.Is(err, errMixedPlaceholders) || errors.Is(err, errNonSliceArgument) || diff --git a/write.go b/write.go index 86f341e..7fbbaca 100644 --- a/write.go +++ b/write.go @@ -12,26 +12,27 @@ func (b *Builder) write(sb *strings.Builder, resArgs *[]any, s string, args ...a for argID := 0; ; argID++ { idx := strings.IndexByte(s, '%') if idx == -1 { - if argID != len(args) { - b.setErr(errTooManyArguments) + var err error + if len(args) != argID { + err = fmt.Errorf("%w: have %d args, expected %d", errTooManyArguments, len(args), argID) } sb.WriteString(s) sb.WriteByte(b.sep) - return nil + return err } sb.WriteString(s[:idx]) s = s[idx+1:] // skip '%' if len(s) == 0 { - return errLonelyVerb + return errLonelyModifier } switch verb := s[0]; verb { case '$', '?', 's', 'd': if argID >= len(args) { - return errTooFewArguments + return fmt.Errorf("%w: have %d args, want %d", errTooFewArguments, len(args), argID+1) } arg := args[argID] @@ -41,15 +42,14 @@ func (b *Builder) write(sb *strings.Builder, resArgs *[]any, s string, args ...a case '+', '#': isBatch := verb == '#' s = s[1:] - if len(s) < 1 { - b.setErr(errIncorrectVerb) - continue + if len(s) < 1 || s[0] == ' ' { + return fmt.Errorf("%w: '%c' requires additional '$' or '?'", errIncorrectVerb, verb) } switch verb := s[0]; verb { case '$', '?': if argID >= len(args) { - return errTooFewArguments + return fmt.Errorf("%w: have %d args, want %d", errTooFewArguments, len(args), argID+1) } arg := args[argID] @@ -60,8 +60,9 @@ func (b *Builder) write(sb *strings.Builder, resArgs *[]any, s string, args ...a } else { b.writeSlice(sb, resArgs, verb, arg) } + default: - b.setErr(errUnsupportedVerb) + return fmt.Errorf("%w: '%c' is not supported", errUnsupportedVerb, verb) } case '%': @@ -70,10 +71,10 @@ func (b *Builder) write(sb *strings.Builder, resArgs *[]any, s string, args ...a sb.WriteByte('%') case ' ': - b.setErr(errLonelyVerb) + return errLonelyModifier default: - b.setErr(errUnsupportedVerb) + return fmt.Errorf("%w: '%c' is not supported", errUnsupportedVerb, verb) } } } @@ -136,12 +137,10 @@ func (b *Builder) writeArg(sb *strings.Builder, resArgs *[]any, verb byte, arg a return } - // store the first placeholder used in the query - // to check for mixed placeholders later - if b.placeholder == 0 { + switch { + case b.placeholder == 0: b.placeholder = verb - } - if b.placeholder != verb { + case b.placeholder != verb: b.setErr(errMixedPlaceholders) } } @@ -188,12 +187,6 @@ func (b *Builder) asSlice(v any) []any { return res } -func (b *Builder) setErr(err error) { - if b.err == nil { - b.err = err - } -} - func (b *Builder) assertNumber(v any) { switch v.(type) { case int, int8, int16, int32, int64, @@ -203,3 +196,9 @@ func (b *Builder) assertNumber(v any) { b.setErr(errNonNumericArg) } } + +func (b *Builder) setErr(err error) { + if b.err == nil { + b.err = err + } +}