Skip to content

Commit

Permalink
update "readBytes" to accept byte slice.
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Aug 12, 2024
1 parent 12c237f commit fe3e31f
Showing 1 changed file with 32 additions and 41 deletions.
73 changes: 32 additions & 41 deletions bson/value_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,19 +345,16 @@ func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) {
}
}

b, err = vr.readBytes(length)
b = make([]byte, length)
err = vr.read(b)
if err != nil {
return nil, 0, err
}
// Make a copy of the returned byte slice because it's just a subslice from the valueReader's
// buffer and is not safe to return in the unmarshaled value.
cp := make([]byte, len(b))
copy(cp, b)

if err := vr.pop(); err != nil {
return nil, 0, err
}
return cp, btype, nil
return b, btype, nil
}

func (vr *valueReader) ReadBoolean() (bool, error) {
Expand Down Expand Up @@ -425,7 +422,8 @@ func (vr *valueReader) ReadCodeWithScope() (code string, dr DocumentReader, err
if strLength <= 0 {
return "", nil, fmt.Errorf("invalid string length: %d", strLength)
}
strBytes, err := vr.readBytes(strLength)
strBytes := make([]byte, strLength)
err = vr.read(strBytes)
if err != nil {
return "", nil, err
}
Expand Down Expand Up @@ -458,13 +456,11 @@ func (vr *valueReader) ReadDBPointer() (ns string, oid ObjectID, err error) {
return "", oid, err
}

oidbytes, err := vr.readBytes(12)
err = vr.read(oid[:])
if err != nil {
return "", ObjectID{}, err
}

copy(oid[:], oidbytes)

if err := vr.pop(); err != nil {
return "", ObjectID{}, err
}
Expand Down Expand Up @@ -492,7 +488,8 @@ func (vr *valueReader) ReadDecimal128() (Decimal128, error) {
return Decimal128{}, err
}

b, err := vr.readBytes(16)
var b [16]byte
err := vr.read(b[:])
if err != nil {
return Decimal128{}, err
}
Expand Down Expand Up @@ -584,14 +581,12 @@ func (vr *valueReader) ReadObjectID() (ObjectID, error) {
return ObjectID{}, err
}

oidbytes, err := vr.readBytes(12)
var oid ObjectID
err := vr.read(oid[:])
if err != nil {
return ObjectID{}, err
}

var oid ObjectID
copy(oid[:], oidbytes)

if err := vr.pop(); err != nil {
return ObjectID{}, err
}
Expand Down Expand Up @@ -729,29 +724,20 @@ func (vr *valueReader) ReadValue() (ValueReader, error) {
return vr, nil
}

// readBytes reads length bytes from the valueReader starting at the current offset. Note that the
// returned byte slice is a subslice from the valueReader buffer and must be converted or copied
// before returning in an unmarshaled value.
func (vr *valueReader) readBytes(length int32) ([]byte, error) {
if length < 0 {
return nil, fmt.Errorf("invalid length: %d", length)
}

buf := make([]byte, length)
_, err := io.ReadFull(vr.r, buf)
func (vr *valueReader) read(p []byte) error {
n, err := io.ReadFull(vr.r, p)
if errors.Is(err, io.ErrUnexpectedEOF) {
return nil, io.EOF
return io.EOF
} else if err != nil {
return nil, err
return err
}

vr.offset += int64(length)

return buf, nil
vr.offset += int64(n)
return nil
}

func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) {
buf, err := vr.readBytes(length)
buf := make([]byte, length)
err := vr.read(buf)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -786,7 +772,8 @@ func (vr *valueReader) readString() (string, error) {
return "", fmt.Errorf("invalid string length: %d", length)
}

buf, err := vr.readBytes(length)
buf := make([]byte, length)
err = vr.read(buf)
if err != nil {
return "", err
}
Expand All @@ -810,37 +797,41 @@ func (vr *valueReader) peekLength() (int32, error) {
func (vr *valueReader) readLength() (int32, error) { return vr.readi32() }

func (vr *valueReader) readi32() (int32, error) {
buf, err := vr.readBytes(4)
var buf [4]byte
err := vr.read(buf[:])
if err != nil {
return 0, err
}

return int32(binary.LittleEndian.Uint32(buf)), nil
return int32(binary.LittleEndian.Uint32(buf[:])), nil
}

func (vr *valueReader) readu32() (uint32, error) {
buf, err := vr.readBytes(4)
var buf [4]byte
err := vr.read(buf[:])
if err != nil {
return 0, err
}

return binary.LittleEndian.Uint32(buf), nil
return binary.LittleEndian.Uint32(buf[:]), nil
}

func (vr *valueReader) readi64() (int64, error) {
buf, err := vr.readBytes(8)
var buf [8]byte
err := vr.read(buf[:])
if err != nil {
return 0, err
}

return int64(binary.LittleEndian.Uint64(buf)), nil
return int64(binary.LittleEndian.Uint64(buf[:])), nil
}

func (vr *valueReader) readu64() (uint64, error) {
buf, err := vr.readBytes(8)
var buf [8]byte
err := vr.read(buf[:])
if err != nil {
return 0, err
}

return binary.LittleEndian.Uint64(buf), nil
return binary.LittleEndian.Uint64(buf[:]), nil
}

0 comments on commit fe3e31f

Please sign in to comment.