diff --git a/decode.go b/decode.go index a04981f..04e0eff 100644 --- a/decode.go +++ b/decode.go @@ -2,8 +2,10 @@ package logfmt import ( "bufio" + "bytes" "fmt" "io" + "unicode/utf8" ) // A Decoder reads and decodes logfmt records from an input stream. @@ -68,13 +70,19 @@ func (dec *Decoder) ScanKeyval() bool { return false key: - start := dec.pos + const invalidKeyError = "invalid key" + + start, multibyte := dec.pos, false for p, c := range line[dec.pos:] { switch { case c == '=': dec.pos += p if dec.pos > start { dec.key = line[start:dec.pos] + if multibyte && bytes.IndexRune(dec.key, utf8.RuneError) != -1 { + dec.syntaxError(invalidKeyError) + return false + } } if dec.key == nil { dec.unexpectedByte(c) @@ -89,13 +97,23 @@ key: dec.pos += p if dec.pos > start { dec.key = line[start:dec.pos] + if multibyte && bytes.IndexRune(dec.key, utf8.RuneError) != -1 { + dec.syntaxError(invalidKeyError) + return false + } } return true + case c >= utf8.RuneSelf: + multibyte = true } } dec.pos = len(line) if dec.pos > start { dec.key = line[start:dec.pos] + if multibyte && bytes.IndexRune(dec.key, utf8.RuneError) != -1 { + dec.syntaxError(invalidKeyError) + return false + } } return true @@ -186,9 +204,6 @@ func (dec *Decoder) Value() []byte { return dec.value } -// func (dec *Decoder) DecodeValue() ([]byte, error) { -// } - // Err returns the first non-EOF error that was encountered by the Scanner. func (dec *Decoder) Err() error { return dec.err diff --git a/decode_test.go b/decode_test.go index d261880..363152d 100644 --- a/decode_test.go +++ b/decode_test.go @@ -118,6 +118,9 @@ func TestDecoder_errors(t *testing.T) { {"a=\"1\\", &SyntaxError{Msg: "unterminated quoted value", Line: 1, Pos: 6}}, {"a=\"\\t1", &SyntaxError{Msg: "unterminated quoted value", Line: 1, Pos: 7}}, {"a=\"\\u1\"", &SyntaxError{Msg: "invalid quoted value", Line: 1, Pos: 8}}, + {"a\ufffd=bar", &SyntaxError{Msg: "invalid key", Line: 1, Pos: 5}}, + {"\x80=bar", &SyntaxError{Msg: "invalid key", Line: 1, Pos: 2}}, + {"\x80", &SyntaxError{Msg: "invalid key", Line: 1, Pos: 2}}, } for _, test := range tests { diff --git a/encode.go b/encode.go index 4d0fa23..55f1603 100644 --- a/encode.go +++ b/encode.go @@ -8,6 +8,7 @@ import ( "io" "reflect" "strings" + "unicode/utf8" ) // MarshalKeyvals returns the logfmt encoding of keyvals, a variadic sequence @@ -165,11 +166,19 @@ func writeKey(w io.Writer, key interface{}) error { } func invalidKeyRune(r rune) bool { - return r <= ' ' || r == '=' || r == '"' + return r <= ' ' || r == '=' || r == '"' || r == utf8.RuneError +} + +func invalidKeyString(key string) bool { + return len(key) == 0 || strings.IndexFunc(key, invalidKeyRune) != -1 +} + +func invalidKey(key []byte) bool { + return len(key) == 0 || bytes.IndexFunc(key, invalidKeyRune) != -1 } func writeStringKey(w io.Writer, key string) error { - if len(key) == 0 || strings.IndexFunc(key, invalidKeyRune) != -1 { + if invalidKeyString(key) { return ErrInvalidKey } _, err := io.WriteString(w, key) @@ -177,7 +186,7 @@ func writeStringKey(w io.Writer, key string) error { } func writeBytesKey(w io.Writer, key []byte) error { - if len(key) == 0 || bytes.IndexFunc(key, invalidKeyRune) != -1 { + if invalidKey(key) { return ErrInvalidKey } _, err := w.Write(key) @@ -223,7 +232,7 @@ func writeValue(w io.Writer, value interface{}) error { } func needsQuotedValueRune(r rune) bool { - return r <= ' ' || r == '=' || r == '"' + return r <= ' ' || r == '=' || r == '"' || r == utf8.RuneError } func writeStringValue(w io.Writer, value string, ok bool) error { @@ -240,7 +249,7 @@ func writeStringValue(w io.Writer, value string, ok bool) error { func writeBytesValue(w io.Writer, value []byte) error { var err error - if bytes.IndexFunc(value, needsQuotedValueRune) >= 0 { + if bytes.IndexFunc(value, needsQuotedValueRune) != -1 { _, err = writeQuotedBytes(w, value) } else { _, err = w.Write(value) diff --git a/encode_test.go b/encode_test.go index c7aa1d5..ebebaae 100644 --- a/encode_test.go +++ b/encode_test.go @@ -53,6 +53,11 @@ func TestEncodeKeyval(t *testing.T) { {key: decimalStringer{5, 9}, value: "v", want: "5.9=v"}, {key: (*decimalStringer)(nil), value: "v", err: logfmt.ErrNilKey}, {key: marshalerStringer{5, 9}, value: "v", want: "5.9=v"}, + {key: "k", value: "\xbd", want: `k="\ufffd"`}, + {key: "k", value: "\ufffd\x00", want: `k="\ufffd\u0000"`}, + {key: "k", value: "\ufffd", want: `k="\ufffd"`}, + {key: "k", value: []byte("\ufffd\x00"), want: `k="\ufffd\u0000"`}, + {key: "k", value: []byte("\ufffd"), want: `k="\ufffd"`}, } for _, d := range data { @@ -82,6 +87,8 @@ func TestMarshalKeyvals(t *testing.T) { {in: kv(), want: nil}, {in: kv(nil, "v"), err: logfmt.ErrNilKey}, {in: kv(nilPtr, "v"), err: logfmt.ErrNilKey}, + {in: kv("\ufffd"), err: logfmt.ErrInvalidKey}, + {in: kv("\xbd"), err: logfmt.ErrInvalidKey}, {in: kv("k"), want: []byte("k=null")}, {in: kv("k", nil), want: []byte("k=null")}, {in: kv("k", ""), want: []byte("k=")}, @@ -125,8 +132,8 @@ func TestMarshalKeyvals(t *testing.T) { if err != d.err { t.Errorf("%#v: got error: %v, want error: %v", d.in, err, d.err) } - if got, want := got, d.want; !reflect.DeepEqual(got, want) { - t.Errorf("%#v: got '%s', want '%s'", d.in, got, want) + if !reflect.DeepEqual(got, d.want) { + t.Errorf("%#v: got '%s', want '%s'", d.in, got, d.want) } } } @@ -181,12 +188,12 @@ func (t marshalerStringer) String() string { return fmt.Sprint(t.a + t.b) } -var marshalError = errors.New("marshal error") +var errMarshal = errors.New("marshal error") type errorMarshaler struct{} func (errorMarshaler) MarshalText() ([]byte, error) { - return nil, marshalError + return nil, errMarshal } func BenchmarkEncodeKeyval(b *testing.B) { diff --git a/fuzz.go b/fuzz.go index ab916a9..6553b35 100644 --- a/fuzz.go +++ b/fuzz.go @@ -12,21 +12,22 @@ import ( kr "github.com/kr/logfmt" ) +// Fuzz checks reserialized data matches func Fuzz(data []byte) int { parsed, err := parse(data) if err != nil { return 0 } var w1 bytes.Buffer - if err := write(parsed, &w1); err != nil { + if err = write(parsed, &w1); err != nil { panic(err) } - parsed, err = parse(data) + parsed, err = parse(w1.Bytes()) if err != nil { panic(err) } var w2 bytes.Buffer - if err := write(parsed, &w2); err != nil { + if err = write(parsed, &w2); err != nil { panic(err) } if !bytes.Equal(w1.Bytes(), w2.Bytes()) { @@ -35,6 +36,7 @@ func Fuzz(data []byte) int { return 1 } +// FuzzVsKR checks go-logfmt/logfmt against kr/logfmt func FuzzVsKR(data []byte) int { parsed, err := parse(data) parsedKR, errKR := parseKR(data) @@ -71,7 +73,6 @@ func parse(data []byte) ([][]kv, error) { kvs = append(kvs, kv{dec.Key(), dec.Value()}) } got = append(got, kvs) - kvs = nil } return got, dec.Err() } diff --git a/jsonstring.go b/jsonstring.go index 53b6532..030ac85 100644 --- a/jsonstring.go +++ b/jsonstring.go @@ -71,7 +71,7 @@ func writeQuotedString(w io.Writer, s string) (int, error) { continue } c, size := utf8.DecodeRuneInString(s[i:]) - if c == utf8.RuneError && size == 1 { + if c == utf8.RuneError { if start < i { buf.WriteString(s[start:i]) } @@ -129,7 +129,7 @@ func writeQuotedBytes(w io.Writer, s []byte) (int, error) { continue } c, size := utf8.DecodeRune(s[i:]) - if c == utf8.RuneError && size == 1 { + if c == utf8.RuneError { if start < i { buf.Write(s[start:i]) } @@ -182,7 +182,7 @@ func unquoteBytes(s []byte) (t []byte, ok bool) { continue } rr, size := utf8.DecodeRune(s[r:]) - if rr == utf8.RuneError && size == 1 { + if rr == utf8.RuneError { break } r += size