diff --git a/bson/value_reader.go b/bson/value_reader.go index 195b30add8..678c47b106 100644 --- a/bson/value_reader.go +++ b/bson/value_reader.go @@ -7,7 +7,7 @@ package bson import ( - "bytes" + "bufio" "encoding/binary" "errors" "fmt" @@ -31,11 +31,8 @@ type vrState struct { // valueReader is for reading BSON values. type valueReader struct { + r *bufio.Reader offset int64 - d []byte - - readerErr error - r io.Reader stack []vrState frame int64 @@ -65,59 +62,11 @@ func newDocumentReader(r io.Reader) *valueReader { mode: mTopLevel, } return &valueReader{ - r: r, + r: bufio.NewReader(r), stack: stack, } } -func (vr *valueReader) prepload(length int32) (int32, error) { - const chunkSize = 512 - - if vr.offset+int64(length) <= int64(len(vr.d)) { - return length, nil - } - - if vr.readerErr != nil { - return 0, vr.readerErr - } - if vr.r == nil { - vr.readerErr = io.EOF - return 0, vr.readerErr - } - - size := len(vr.d) - var need int64 = chunkSize - if l := int64(length) + vr.offset - int64(size); l > need { - need = l - } - buf := make([]byte, need) - n, err := vr.r.Read(buf) - if err != nil { - vr.readerErr = err - } - vr.d = append(vr.d, buf[0:n]...) - if l := int64(n+size) - vr.offset; l < int64(length) { - length = int32(l) - } - return length, err -} - -func (vr *valueReader) indexByteAfter(offset int64, c byte) (int64, error) { - const chunkSize = 512 - for idx := -1; idx < 0; { - idx = bytes.IndexByte(vr.d[offset:], c) - if idx < 0 { - n, err := vr.prepload(chunkSize) - if n == 0 && err != nil { - return 0, err - } - } else { - offset += int64(idx) - } - } - return offset, nil -} - func (vr *valueReader) advanceFrame() { if vr.frame+1 >= int64(len(vr.stack)) { // We need to grow the stack length := len(vr.stack) @@ -142,11 +91,11 @@ func (vr *valueReader) pushDocument() error { vr.stack[vr.frame].mode = mDocument - size, err := vr.readLength() + length, err := vr.readLength() if err != nil { return err } - vr.stack[vr.frame].end = int64(size) + vr.offset - 4 + vr.stack[vr.frame].end = int64(length) + vr.offset - 4 return nil } @@ -156,11 +105,11 @@ func (vr *valueReader) pushArray() error { vr.stack[vr.frame].mode = mArray - size, err := vr.readLength() + length, err := vr.readLength() if err != nil { return err } - vr.stack[vr.frame].end = int64(size) + vr.offset - 4 + vr.stack[vr.frame].end = int64(length) + vr.offset - 4 return nil } @@ -184,30 +133,41 @@ func (vr *valueReader) pushCodeWithScope() (int64, error) { vr.stack[vr.frame].mode = mCodeWithScope - size, err := vr.readLength() + length, err := vr.readLength() if err != nil { return 0, err } - vr.stack[vr.frame].end = int64(size) + vr.offset - 4 + vr.stack[vr.frame].end = int64(length) + vr.offset - 4 - return int64(size), nil + return int64(length), nil } -func (vr *valueReader) pop() { +func (vr *valueReader) pop() error { + var cnt int switch vr.stack[vr.frame].mode { case mElement, mValue: - vr.frame-- + cnt = 1 case mDocument, mArray, mCodeWithScope: - vr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc... + cnt = 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc... } - if vr.frame < 0 { - vr.frame = 0 + for i := 0; i < cnt && vr.frame > 0; i++ { + if vr.offset < vr.stack[vr.frame].end { + _, err := vr.r.Discard(int(vr.stack[vr.frame].end - vr.offset)) + if err != nil { + return err + } + } + vr.frame-- } - if vr.frame == 0 && (vr.stack[vr.frame].end <= vr.offset) { - vr.d = vr.d[vr.stack[vr.frame].end:] - vr.offset -= vr.stack[vr.frame].end - vr.stack[vr.frame].end = 0 + if vr.frame == 0 { + if vr.stack[0].end > vr.offset { + vr.stack[0].end -= vr.offset + } else { + vr.stack[0].end = 0 + } + vr.offset = 0 } + return nil } func (vr *valueReader) invalidTransitionErr(destination mode, name string, modes []mode) error { @@ -249,7 +209,7 @@ func (vr *valueReader) Type() Type { return vr.stack[vr.frame].vType } -func (vr *valueReader) nextElementLength() (int32, error) { +func (vr *valueReader) appendNextElement(dst []byte) ([]byte, error) { var length int32 var err error switch vr.stack[vr.frame].vType { @@ -277,21 +237,30 @@ func (vr *valueReader) nextElementLength() (int32, error) { case TypeObjectID: length = 12 case TypeRegex: - offset := vr.offset for n := 0; n < 2; n++ { // Read two C strings. - var err error - offset, err = vr.indexByteAfter(offset, 0x00) + str, err := vr.r.ReadBytes(0x00) if err != nil { - return 0, err + return nil, err } - offset++ // add 0x00 + dst = append(dst, str...) + vr.offset += int64(len(str)) } - length = int32(offset - vr.offset) + return dst, nil default: - return 0, fmt.Errorf("attempted to read bytes of unknown BSON type %v", vr.stack[vr.frame].vType) + return nil, fmt.Errorf("attempted to read bytes of unknown BSON type %v", vr.stack[vr.frame].vType) + } + if err != nil { + return nil, err } - return length, err + buf := make([]byte, length) + _, err = io.ReadFull(vr.r, buf) + if err != nil { + return nil, err + } + dst = append(dst, buf...) + vr.offset += int64(len(buf)) + return dst, err } func (vr *valueReader) readValueBytes(dst []byte) (Type, []byte, error) { @@ -307,15 +276,17 @@ func (vr *valueReader) readValueBytes(dst []byte) (Type, []byte, error) { } return Type(0), dst, nil case mElement, mValue: - length, err := vr.nextElementLength() + dst, err := vr.appendNextElement(dst) if err != nil { return Type(0), dst, err } - dst, err = vr.appendBytes(dst, length) t := vr.stack[vr.frame].vType - vr.pop() - return t, dst, err + err = vr.pop() + if err != nil { + return Type(0), nil, err + } + return t, dst, nil default: return Type(0), nil, vr.invalidTransitionErr(0, "ReadValueBytes", []mode{mElement, mValue}) } @@ -328,14 +299,12 @@ func (vr *valueReader) Skip() error { return vr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue}) } - length, err := vr.nextElementLength() + _, err := vr.appendNextElement(nil) if err != nil { return err } - err = vr.skipBytes(length) - vr.pop() - return err + return vr.pop() } func (vr *valueReader) ReadArray() (ArrayReader, error) { @@ -374,17 +343,16 @@ func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) { } } - b, err = vr.readBytes(length) + b = make([]byte, length) + err = vr.read(b) if err != nil { return nil, 0, err } - // Make a copy of the returned byte slice because it's just a subslice from the valueReader's - // buffer and is not safe to return in the unmarshaled value. - cp := make([]byte, len(b)) - copy(cp, b) - vr.pop() - return cp, btype, nil + if err := vr.pop(); err != nil { + return nil, 0, err + } + return b, btype, nil } func (vr *valueReader) ReadBoolean() (bool, error) { @@ -401,7 +369,9 @@ func (vr *valueReader) ReadBoolean() (bool, error) { return false, fmt.Errorf("invalid byte for boolean, %b", b) } - vr.pop() + if err := vr.pop(); err != nil { + return false, err + } return b == 1, nil } @@ -415,16 +385,8 @@ func (vr *valueReader) ReadDocument() (DocumentReader, error) { if length <= 4 { return nil, fmt.Errorf("invalid string length: %d", length) } - length -= 4 - n, err := vr.prepload(length) - if n < length || err != nil { - if err == nil { - err = io.EOF - } - return nil, err - } - vr.stack[vr.frame].end = int64(length) + vr.offset + vr.stack[vr.frame].end = int64(length) + vr.offset - 4 return vr, nil case mElement, mValue: if vr.stack[vr.frame].vType != TypeEmbeddedDocument { @@ -458,7 +420,8 @@ func (vr *valueReader) ReadCodeWithScope() (code string, dr DocumentReader, err if strLength <= 0 { return "", nil, fmt.Errorf("invalid string length: %d", strLength) } - strBytes, err := vr.readBytes(strLength) + strBytes := make([]byte, strLength) + err = vr.read(strBytes) if err != nil { return "", nil, err } @@ -491,14 +454,14 @@ func (vr *valueReader) ReadDBPointer() (ns string, oid ObjectID, err error) { return "", oid, err } - oidbytes, err := vr.readBytes(12) + err = vr.read(oid[:]) if err != nil { - return "", oid, err + return "", ObjectID{}, err } - copy(oid[:], oidbytes) - - vr.pop() + if err := vr.pop(); err != nil { + return "", ObjectID{}, err + } return ns, oid, nil } @@ -512,7 +475,9 @@ func (vr *valueReader) ReadDateTime() (int64, error) { return 0, err } - vr.pop() + if err := vr.pop(); err != nil { + return 0, err + } return i, nil } @@ -521,7 +486,8 @@ func (vr *valueReader) ReadDecimal128() (Decimal128, error) { return Decimal128{}, err } - b, err := vr.readBytes(16) + var b [16]byte + err := vr.read(b[:]) if err != nil { return Decimal128{}, err } @@ -529,7 +495,9 @@ func (vr *valueReader) ReadDecimal128() (Decimal128, error) { l := binary.LittleEndian.Uint64(b[0:8]) h := binary.LittleEndian.Uint64(b[8:16]) - vr.pop() + if err := vr.pop(); err != nil { + return Decimal128{}, err + } return NewDecimal128(h, l), nil } @@ -543,7 +511,9 @@ func (vr *valueReader) ReadDouble() (float64, error) { return 0, err } - vr.pop() + if err := vr.pop(); err != nil { + return 0, err + } return math.Float64frombits(u), nil } @@ -552,7 +522,9 @@ func (vr *valueReader) ReadInt32() (int32, error) { return 0, err } - vr.pop() + if err := vr.pop(); err != nil { + return 0, err + } return vr.readi32() } @@ -561,16 +533,20 @@ func (vr *valueReader) ReadInt64() (int64, error) { return 0, err } - vr.pop() + if err := vr.pop(); err != nil { + return 0, err + } return vr.readi64() } -func (vr *valueReader) ReadJavascript() (code string, err error) { +func (vr *valueReader) ReadJavascript() (string, error) { if err := vr.ensureElementValue(TypeJavaScript, 0, "ReadJavascript"); err != nil { return "", err } - vr.pop() + if err := vr.pop(); err != nil { + return "", err + } return vr.readString() } @@ -579,8 +555,7 @@ func (vr *valueReader) ReadMaxKey() error { return err } - vr.pop() - return nil + return vr.pop() } func (vr *valueReader) ReadMinKey() error { @@ -588,8 +563,7 @@ func (vr *valueReader) ReadMinKey() error { return err } - vr.pop() - return nil + return vr.pop() } func (vr *valueReader) ReadNull() error { @@ -597,8 +571,7 @@ func (vr *valueReader) ReadNull() error { return err } - vr.pop() - return nil + return vr.pop() } func (vr *valueReader) ReadObjectID() (ObjectID, error) { @@ -606,15 +579,15 @@ func (vr *valueReader) ReadObjectID() (ObjectID, error) { return ObjectID{}, err } - oidbytes, err := vr.readBytes(12) + var oid ObjectID + err := vr.read(oid[:]) if err != nil { return ObjectID{}, err } - var oid ObjectID - copy(oid[:], oidbytes) - - vr.pop() + if err := vr.pop(); err != nil { + return ObjectID{}, err + } return oid, nil } @@ -633,7 +606,9 @@ func (vr *valueReader) ReadRegex() (string, string, error) { return "", "", err } - vr.pop() + if err := vr.pop(); err != nil { + return "", "", err + } return pattern, options, nil } @@ -642,16 +617,20 @@ func (vr *valueReader) ReadString() (string, error) { return "", err } - vr.pop() + if err := vr.pop(); err != nil { + return "", err + } return vr.readString() } -func (vr *valueReader) ReadSymbol() (symbol string, err error) { +func (vr *valueReader) ReadSymbol() (string, error) { if err := vr.ensureElementValue(TypeSymbol, 0, "ReadSymbol"); err != nil { return "", err } - vr.pop() + if err := vr.pop(); err != nil { + return "", err + } return vr.readString() } @@ -670,7 +649,9 @@ func (vr *valueReader) ReadTimestamp() (t uint32, i uint32, err error) { return 0, 0, err } - vr.pop() + if err := vr.pop(); err != nil { + return 0, 0, err + } return t, i, nil } @@ -679,8 +660,7 @@ func (vr *valueReader) ReadUndefined() error { return err } - vr.pop() - return nil + return vr.pop() } func (vr *valueReader) ReadElement() (string, ValueReader, error) { @@ -700,7 +680,7 @@ func (vr *valueReader) ReadElement() (string, ValueReader, error) { return "", nil, vr.invalidDocumentLengthError() } - vr.pop() + _ = vr.pop() // Ignore the error because the call here never reads from the underlying reader. return "", nil, ErrEOD } @@ -730,11 +710,11 @@ func (vr *valueReader) ReadValue() (ValueReader, error) { return nil, vr.invalidDocumentLengthError() } - vr.pop() + _ = vr.pop() // Ignore the error because the call here never reads from the underlying reader. return nil, ErrEOA } - if err := vr.skipCString(); err != nil { + if _, err := vr.readCString(); err != nil { return nil, err } @@ -742,86 +722,41 @@ func (vr *valueReader) ReadValue() (ValueReader, error) { return vr, nil } -// readBytes reads length bytes from the valueReader starting at the current offset. Note that the -// returned byte slice is a subslice from the valueReader buffer and must be converted or copied -// before returning in an unmarshaled value. -func (vr *valueReader) readBytes(length int32) ([]byte, error) { - if length < 0 { - return nil, fmt.Errorf("invalid length: %d", length) - } - - n, err := vr.prepload(length) - if n < length || err != nil { - if err == nil { - err = io.EOF - } - return nil, err +func (vr *valueReader) read(p []byte) error { + n, err := io.ReadFull(vr.r, p) + if err != nil { + return err } - - start := vr.offset - vr.offset += int64(length) - - return vr.d[start : start+int64(length)], nil + vr.offset += int64(n) + return nil } func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) { - n, err := vr.prepload(length) - if n < length || err != nil { - if err == nil { - err = io.EOF - } + buf := make([]byte, length) + err := vr.read(buf) + if err != nil { return nil, err } - start := vr.offset - vr.offset += int64(length) - return append(dst, vr.d[start:start+int64(length)]...), nil -} - -func (vr *valueReader) skipBytes(length int32) error { - if length < 0 { - return fmt.Errorf("invalid length: %d", length) - } - - n, err := vr.prepload(length) - if n < length || err != nil { - if err == nil { - err = io.EOF - } - return err - } - vr.offset += int64(length) - return nil + return append(dst, buf...), nil } func (vr *valueReader) readByte() (byte, error) { - n, err := vr.prepload(1) - if n < 1 || err != nil { - if err == nil { - err = io.EOF - } + b, err := vr.r.ReadByte() + if err != nil { return 0x0, err } vr.offset++ - return vr.d[vr.offset-1], nil -} - -func (vr *valueReader) skipCString() error { - offset, err := vr.indexByteAfter(vr.offset, 0x00) - if err != nil { - return err - } - vr.offset = offset + 1 - return nil + return b, nil } func (vr *valueReader) readCString() (string, error) { - offset, err := vr.indexByteAfter(vr.offset, 0x00) + str, err := vr.r.ReadString(0x00) if err != nil { return "", err } - start := vr.offset - vr.offset = offset + 1 - return string(vr.d[start : vr.offset-1]), nil + l := len(str) + vr.offset += int64(l) + return str[:l-1], nil } func (vr *valueReader) readString() (string, error) { @@ -833,90 +768,75 @@ func (vr *valueReader) readString() (string, error) { return "", fmt.Errorf("invalid string length: %d", length) } - n, err := vr.prepload(length) - if n < length || err != nil { - if err == nil { - err = io.EOF - } + buf := make([]byte, length) + err = vr.read(buf) + if err != nil { return "", err } - if vr.d[vr.offset+int64(length)-1] != 0x00 { - return "", fmt.Errorf("string does not end with null byte, but with %v", vr.d[vr.offset+int64(length)-1]) + if buf[length-1] != 0x00 { + return "", fmt.Errorf("string does not end with null byte, but with %v", buf[length-1]) } - start := vr.offset - vr.offset += int64(length) - return string(vr.d[start : start+int64(length)-1]), nil + return string(buf[:length-1]), nil } func (vr *valueReader) peekLength() (int32, error) { - n, err := vr.prepload(4) - if n < 4 || err != nil { - if err == nil { - err = io.EOF - } + buf, err := vr.r.Peek(4) + if err != nil { return 0, err } - idx := vr.offset - return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil + return int32(binary.LittleEndian.Uint32(buf)), nil } -func (vr *valueReader) readLength() (int32, error) { return vr.readi32() } +func (vr *valueReader) readLength() (int32, error) { + l, err := vr.readi32() + if err != nil { + return 0, err + } + if l < 0 { + return 0, fmt.Errorf("invalid negative length: %d", l) + } + return l, nil +} func (vr *valueReader) readi32() (int32, error) { - n, err := vr.prepload(4) - if n < 4 || err != nil { - if err == nil { - err = io.EOF - } + var buf [4]byte + err := vr.read(buf[:]) + if err != nil { return 0, err } - idx := vr.offset - vr.offset += 4 - return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil + return int32(binary.LittleEndian.Uint32(buf[:])), nil } func (vr *valueReader) readu32() (uint32, error) { - n, err := vr.prepload(4) - if n < 4 || err != nil { - if err == nil { - err = io.EOF - } + var buf [4]byte + err := vr.read(buf[:]) + if err != nil { return 0, err } - idx := vr.offset - vr.offset += 4 - return binary.LittleEndian.Uint32(vr.d[idx:]), nil + return binary.LittleEndian.Uint32(buf[:]), nil } func (vr *valueReader) readi64() (int64, error) { - n, err := vr.prepload(8) - if n < 8 || err != nil { - if err == nil { - err = io.EOF - } + var buf [8]byte + err := vr.read(buf[:]) + if err != nil { return 0, err } - idx := vr.offset - vr.offset += 8 - return int64(binary.LittleEndian.Uint64(vr.d[idx:])), nil + return int64(binary.LittleEndian.Uint64(buf[:])), nil } func (vr *valueReader) readu64() (uint64, error) { - n, err := vr.prepload(8) - if n < 8 || err != nil { - if err == nil { - err = io.EOF - } + var buf [8]byte + err := vr.read(buf[:]) + if err != nil { return 0, err } - idx := vr.offset - vr.offset += 8 - return binary.LittleEndian.Uint64(vr.d[idx:]), nil + return binary.LittleEndian.Uint64(buf[:]), nil } diff --git a/bson/value_reader_test.go b/bson/value_reader_test.go index f05b675a32..852a291dd4 100644 --- a/bson/value_reader_test.go +++ b/bson/value_reader_test.go @@ -7,6 +7,7 @@ package bson import ( + "bufio" "bytes" _ "embed" "errors" @@ -80,7 +81,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -145,7 +146,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -175,18 +176,16 @@ func TestValueReader(t *testing.T) { } // invalid length - vr.r = bytes.NewReader([]byte{0x00, 0x00}) + vr.r = bufio.NewReader(bytes.NewReader([]byte{0x00, 0x00})) _, err := vr.ReadDocument() - if !errors.Is(err, io.EOF) { - t.Errorf("Expected io.EOF with document length too small. got %v; want %v", err, io.EOF) + if !errors.Is(err, io.ErrUnexpectedEOF) { + t.Errorf("Expected io.ErrUnexpectedEOF with document length too small. got %v; want %v", err, io.EOF) } if vr.offset != 0 { t.Errorf("Expected 0 offset. got %d", vr.offset) } - vr.r = bytes.NewReader(doc) - vr.d = vr.d[:0] - vr.readerErr = nil + vr.r = bufio.NewReader(bytes.NewReader(doc)) _, err = vr.ReadDocument() noerr(t, err) if vr.stack[vr.frame].end != 5 { @@ -216,9 +215,8 @@ func TestValueReader(t *testing.T) { } vr.stack[1].mode, vr.stack[1].vType = mElement, TypeEmbeddedDocument - vr.d = []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00} vr.offset = 4 - vr.r = bytes.NewReader([]byte{}) + vr.r = bufio.NewReader(bytes.NewReader([]byte{0x05, 0x00, 0x00, 0x00, 0x00, 0x00})) _, err = vr.ReadDocument() noerr(t, err) if len(vr.stack) != 3 { @@ -236,12 +234,12 @@ func TestValueReader(t *testing.T) { vr.frame-- _, err = vr.ReadDocument() - if !errors.Is(err, io.EOF) { + if !errors.Is(err, io.ErrUnexpectedEOF) { t.Errorf("Should return error when attempting to read length with not enough bytes. got %v; want %v", err, io.EOF) } }) }) - t.Run("ReadBinary", func(t *testing.T) { + t.Run("ReadCodeWithScope", func(t *testing.T) { codeWithScope := []byte{ 0x11, 0x00, 0x00, 0x00, // total length 0x4, 0x00, 0x00, 0x00, // string length @@ -314,7 +312,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -333,12 +331,9 @@ func TestValueReader(t *testing.T) { } t.Run("success", func(t *testing.T) { - doc := []byte{0x00, 0x00, 0x00, 0x00} - doc = append(doc, codeWithScope...) - doc = append(doc, 0x00) vr := &valueReader{ offset: 4, - d: doc, + r: bufio.NewReader(bytes.NewReader(codeWithScope)), stack: []vrState{ {mode: mTopLevel}, {mode: mElement, vType: TypeCodeWithScope}, @@ -422,7 +417,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -480,7 +475,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -538,7 +533,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -598,7 +593,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -653,7 +648,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -708,7 +703,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -783,7 +778,7 @@ func TestValueReader(t *testing.T) { append([]byte{0x40, 0x27, 0x00, 0x00}, testcstring...), (*valueReader).ReadString, "", - io.EOF, + io.ErrUnexpectedEOF, TypeString, }, { @@ -863,7 +858,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -995,7 +990,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -1062,7 +1057,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -1132,7 +1127,7 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data), + r: bufio.NewReader(bytes.NewReader(tc.data)), stack: []vrState{ {mode: mTopLevel}, { @@ -1349,12 +1344,12 @@ func TestValueReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + const startingEnd = 64 t.Run("Skip", func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader(tc.data[tc.offset:]), - d: tc.data[:tc.offset], + r: bufio.NewReader(bytes.NewReader(tc.data[tc.startingOffset:tc.offset])), stack: []vrState{ - {mode: mTopLevel}, + {mode: mTopLevel, end: startingEnd}, {mode: mElement, vType: tc.t}, }, frame: 1, @@ -1365,16 +1360,18 @@ func TestValueReader(t *testing.T) { if !errequal(t, err, tc.err) { t.Errorf("Did not receive expected error; got %v; want %v", err, tc.err) } - if tc.err == nil && vr.offset != tc.offset { - t.Errorf("Offset not set at correct position; got %d; want %d", vr.offset, tc.offset) + if tc.err == nil { + offset := startingEnd - vr.stack[0].end + if offset != tc.offset { + t.Errorf("Offset not set at correct position; got %d; want %d", offset, tc.offset) + } } }) t.Run("ReadBytes", func(t *testing.T) { vr := &valueReader{ - r: bytes.NewReader([]byte{}), - d: tc.data, + r: bufio.NewReader(bytes.NewReader(tc.data[tc.startingOffset:tc.offset])), stack: []vrState{ - {mode: mTopLevel}, + {mode: mTopLevel, end: startingEnd}, {mode: mElement, vType: tc.t}, }, frame: 1, @@ -1385,8 +1382,11 @@ func TestValueReader(t *testing.T) { if !errequal(t, err, tc.err) { t.Errorf("Did not receive expected error; got %v; want %v", err, tc.err) } - if tc.err == nil && vr.offset != tc.offset { - t.Errorf("Offset not set at correct position; got %d; want %d", vr.offset, tc.offset) + if tc.err == nil { + offset := startingEnd - vr.stack[0].end + if offset != tc.offset { + t.Errorf("Offset not set at correct position; got %d; want %d", vr.offset, tc.offset) + } } if tc.err == nil && !bytes.Equal(got, tc.data[tc.startingOffset:]) { t.Errorf("Did not receive expected bytes. got %v; want %v", got, tc.data[tc.startingOffset:]) @@ -1417,7 +1417,7 @@ func TestValueReader(t *testing.T) { "append bytes", []byte{0x01, 0x02, 0x03, 0x04}, Type(0), - io.EOF, + io.ErrUnexpectedEOF, }, } @@ -1426,7 +1426,7 @@ func TestValueReader(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() vr := &valueReader{ - r: bytes.NewReader(tc.want), + r: bufio.NewReader(bytes.NewReader(tc.want)), stack: []vrState{ {mode: mTopLevel}, },