Skip to content

Commit

Permalink
sstable: fix restarts integer overflow
Browse files Browse the repository at this point in the history
Fix bug with integer overflows while indexing into blocks with large KVs
in SeekGE() and SeekLT() in rowblk_iter. Updated members in blockEntry
that represent  offsets in blocks to be type offsetInBlock (alias for int64).

Added check in rowblk_writer to ensure that block sizes do not exceed
rowblk.MaximumSize before writing more data to the block.

Wrote unit tests to verify correct behavior for SeekGE() and SeekLT() with large
blocks >2GB.
  • Loading branch information
EdwardX29 committed Jan 21, 2025
1 parent ac677f8 commit a59316b
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 51 deletions.
122 changes: 71 additions & 51 deletions sstable/rowblk/rowblk_iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ type Iter struct {

// offset is the byte index that marks where the current key/value is
// encoded in the block.
offset int32
offset offsetInBlock
// nextOffset is the byte index where the next key/value is encoded in the
// block.
nextOffset int32
nextOffset offsetInBlock
// A "restart point" in a block is a point where the full key is encoded,
// instead of just having a suffix of the key encoded. See readEntry() for
// how prefix compression of keys works. Keys in between two restart points
Expand All @@ -130,7 +130,10 @@ type Iter struct {
// 4 bytes of the block as a uint32 (i.ptr[len(block)-4:]). i.restarts can
// therefore be seen as the point where data in the block ends, and a list
// of offsets of all restart points begins.
restarts int32
//
// int64 is used to prevent overflow and preserve signedness for binary
// search invariants.
restarts offsetInBlock
// Number of restart points in this block. Encoded at the end of the block
// as a uint32.
numRestarts int32
Expand Down Expand Up @@ -193,12 +196,25 @@ type Iter struct {
firstUserKeyWithPrefixBuf []byte
}

// offsetInBlock represents an offset in a block
//
// While restart points are serialized as uint32's, it is possible for offsets to
// be greater than math.MaxUint32 since they may point to an offset after the KVs.
//
// Previously, offsets were represented as int32, which causes problems with
// integer overflows while indexing into blocks (i.data) with large KVs in SeekGE()
// and SeekLT(). Using an int64 solves the problem of overflows as wraparounds will
// be prevented. Additionally, the signedness of int64 allows repsentation of
// iterators that have conducted backward interation and allows for binary search
// invariants in SeekGE() and SeekLT() to be preserved.
type offsetInBlock int64

type blockEntry struct {
offset int32
keyStart int32
keyEnd int32
valStart int32
valSize int32
offset offsetInBlock
keyStart offsetInBlock
keyEnd offsetInBlock
valStart offsetInBlock
valSize uint32
}

// *Iter implements the block.DataBlockIterator interface.
Expand Down Expand Up @@ -237,7 +253,7 @@ func (i *Iter) Init(
i.synthSuffixBuf = i.synthSuffixBuf[:0]
i.split = split
i.cmp = cmp
i.restarts = int32(len(blk)) - 4*(1+numRestarts)
i.restarts = offsetInBlock(len(blk)) - 4*(1+offsetInBlock(numRestarts))
i.numRestarts = numRestarts
i.ptr = unsafe.Pointer(&blk[0])
i.data = blk
Expand Down Expand Up @@ -396,7 +412,7 @@ func (i *Iter) readEntry() {
}
ptr = unsafe.Pointer(uintptr(ptr) + uintptr(unshared))
i.val = getBytes(ptr, int(value))
i.nextOffset = int32(uintptr(ptr)-uintptr(i.ptr)) + int32(value)
i.nextOffset = offsetInBlock(uintptr(ptr)-uintptr(i.ptr)) + offsetInBlock(value)
}

func (i *Iter) readFirstKey() error {
Expand Down Expand Up @@ -506,16 +522,16 @@ func (i *Iter) clearCache() {
}

func (i *Iter) cacheEntry() {
var valStart int32
valSize := int32(len(i.val))
var valStart offsetInBlock
valSize := uint32(len(i.val))
if valSize > 0 {
valStart = int32(uintptr(unsafe.Pointer(&i.val[0])) - uintptr(i.ptr))
valStart = offsetInBlock(uintptr(unsafe.Pointer(&i.val[0])) - uintptr(i.ptr))
}

i.cached = append(i.cached, blockEntry{
offset: i.offset,
keyStart: int32(len(i.cachedBuf)),
keyEnd: int32(len(i.cachedBuf) + len(i.key)),
keyStart: offsetInBlock(len(i.cachedBuf)),
keyEnd: offsetInBlock(len(i.cachedBuf) + len(i.key)),
valStart: valStart,
valSize: valSize,
})
Expand Down Expand Up @@ -569,8 +585,9 @@ func (i *Iter) SeekGE(key []byte, flags base.SeekGEFlags) *base.InternalKV {
upper := i.numRestarts
for index < upper {
h := int32(uint(index+upper) >> 1) // avoid overflow when computing h

// index ≤ h < upper
offset := decodeRestart(i.data[i.restarts+4*h:])
offset := decodeRestart(i.data[i.restarts+4*offsetInBlock(h):])
// For a restart point, there are 0 bytes shared with the previous key.
// The varint encoding of 0 occupies 1 byte.
ptr := unsafe.Pointer(uintptr(i.ptr) + uintptr(offset+1))
Expand Down Expand Up @@ -642,7 +659,7 @@ func (i *Iter) SeekGE(key []byte, flags base.SeekGEFlags) *base.InternalKV {
// could be equal to the search key. If index == 0, then all keys in this
// block are larger than the key sought, and offset remains at zero.
if index > 0 {
i.offset = decodeRestart(i.data[i.restarts+4*(index-1):])
i.offset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index-1):])
}
i.readEntry()
hiddenPoint := i.decodeInternalKey(i.key)
Expand Down Expand Up @@ -751,8 +768,9 @@ func (i *Iter) SeekLT(key []byte, flags base.SeekLTFlags) *base.InternalKV {
upper := i.numRestarts
for index < upper {
h := int32(uint(index+upper) >> 1) // avoid overflow when computing h

// index ≤ h < upper
offset := decodeRestart(i.data[i.restarts+4*h:])
offset := decodeRestart(i.data[i.restarts+4*offsetInBlock(h):])
// For a restart point, there are 0 bytes shared with the previous key.
// The varint encoding of 0 occupies 1 byte.
ptr := unsafe.Pointer(uintptr(i.ptr) + uintptr(offset+1))
Expand Down Expand Up @@ -853,9 +871,9 @@ func (i *Iter) SeekLT(key []byte, flags base.SeekLTFlags) *base.InternalKV {
// i.Prev(). We need to know when we have hit the offset for index, since then
// we can stop searching. targetOffset encodes that offset for index.
targetOffset := i.restarts
i.offset = decodeRestart(i.data[i.restarts+4*(index-1):])
i.offset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index-1):])
if index < i.numRestarts {
targetOffset = decodeRestart(i.data[i.restarts+4*(index):])
targetOffset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index):])

if i.transforms.HasSyntheticSuffix() {
// The binary search was conducted on keys without suffix replacement,
Expand Down Expand Up @@ -909,7 +927,7 @@ func (i *Iter) SeekLT(key []byte, flags base.SeekLTFlags) *base.InternalKV {
if index+1 < i.numRestarts {
// if index+1 is within the i.data bounds, use it to find the target
// offset.
targetOffset = decodeRestart(i.data[i.restarts+4*(index+1):])
targetOffset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index+1):])
} else {
targetOffset = i.restarts
}
Expand Down Expand Up @@ -1012,9 +1030,9 @@ func (i *Iter) First() *base.InternalKV {
const restartMaskLittleEndianHighByteWithoutSetHasSamePrefix byte = 0b0111_1111
const restartMaskLittleEndianHighByteOnlySetHasSamePrefix byte = 0b1000_0000

func decodeRestart(b []byte) int32 {
func decodeRestart(b []byte) offsetInBlock {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return int32(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 |
return offsetInBlock(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 |
uint32(b[3]&restartMaskLittleEndianHighByteWithoutSetHasSamePrefix)<<24)
}

Expand All @@ -1025,7 +1043,7 @@ func (i *Iter) Last() *base.InternalKV {
}

// Seek forward from the last restart point.
i.offset = decodeRestart(i.data[i.restarts+4*(i.numRestarts-1):])
i.offset = decodeRestart(i.data[i.restarts+4*offsetInBlock(i.numRestarts-1):])
if !i.Valid() {
return nil
}
Expand Down Expand Up @@ -1231,7 +1249,7 @@ func (i *Iter) nextPrefixV3(succKey []byte) *base.InternalKV {
shared += i.transforms.SyntheticPrefixAndSuffix.PrefixLen()
// The starting position of the value.
valuePtr := unsafe.Pointer(uintptr(ptr) + uintptr(unshared))
i.nextOffset = int32(uintptr(valuePtr)-uintptr(i.ptr)) + int32(value)
i.nextOffset = offsetInBlock(uintptr(valuePtr)-uintptr(i.ptr)) + offsetInBlock(value)
if invariants.Enabled && unshared < 8 {
// This should not happen since only the key prefix is shared, so even
// if the prefix length is the same as the user key length, the unshared
Expand Down Expand Up @@ -1277,8 +1295,9 @@ func (i *Iter) nextPrefixV3(succKey []byte) *base.InternalKV {
upper := i.numRestarts
for index < upper {
h := int32(uint(index+upper) >> 1) // avoid overflow when computing h

// index ≤ h < upper
offset := decodeRestart(i.data[i.restarts+4*h:])
offset := decodeRestart(i.data[i.restarts+4*offsetInBlock(h):])
if offset < targetOffset {
index = h + 1 // preserves f(index-1) == false
} else {
Expand All @@ -1305,7 +1324,7 @@ func (i *Iter) nextPrefixV3(succKey []byte) *base.InternalKV {
// most significant bit of the 3rd byte is what we use for
// encoding the set-has-same-prefix information, the indexing
// below has +3.
i.data[i.restarts+4*index+3]&restartMaskLittleEndianHighByteOnlySetHasSamePrefix != 0 {
i.data[i.restarts+4*offsetInBlock(index)+3]&restartMaskLittleEndianHighByteOnlySetHasSamePrefix != 0 {
// We still have the same prefix, so move to the next restart.
index++
}
Expand All @@ -1314,7 +1333,7 @@ func (i *Iter) nextPrefixV3(succKey []byte) *base.InternalKV {
// Managed to skip past at least one restart. Resume iteration
// from index-1. Since nextFastCount has been reset to 0, we
// should be able to iterate to the next prefix.
i.offset = decodeRestart(i.data[i.restarts+4*(index-1):])
i.offset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index-1):])
i.readEntry()
}
// Else, unable to skip past any restart. Resume iteration. Since
Expand Down Expand Up @@ -1479,8 +1498,9 @@ start:
upper := i.numRestarts
for index < upper {
h := int32(uint(index+upper) >> 1) // avoid overflow when computing h

// index ≤ h < upper
offset := decodeRestart(i.data[i.restarts+4*h:])
offset := decodeRestart(i.data[i.restarts+4*offsetInBlock(h):])
if offset < targetOffset {
// Looking for the first restart that has offset >= targetOffset, so
// ignore h and earlier.
Expand All @@ -1501,7 +1521,7 @@ start:
// as the index).
i.offset = 0
if index > 0 {
i.offset = decodeRestart(i.data[i.restarts+4*(index-1):])
i.offset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index-1):])
}
// TODO(sumeer): why is the else case not an error given targetOffset is a
// valid offset.
Expand Down Expand Up @@ -1600,8 +1620,8 @@ func (i *Iter) DebugTree(tp treeprinter.Node) {
tp.Childf("%T(%p)", i, i)
}

func (i *Iter) getRestart(idx int) int32 {
return int32(binary.LittleEndian.Uint32(i.data[i.restarts+4*int32(idx):]))
func (i *Iter) getRestart(idx int) offsetInBlock {
return offsetInBlock(binary.LittleEndian.Uint32(i.data[i.restarts+4*offsetInBlock(idx):]))
}

func (i *Iter) isRestartPoint() bool {
Expand All @@ -1621,7 +1641,7 @@ type KVEncoding struct {
IsRestart bool
// Offset is the position within the block at which the key-value pair is
// encoded.
Offset int32
Offset offsetInBlock
// Length is the total length of the KV pair as it is encoded in the block
// format.
Length int32
Expand Down Expand Up @@ -1657,7 +1677,7 @@ func (i *Iter) Describe(tp treeprinter.Node, fmtKV DescribeKV) {
// Format the restart points.
for j := 0; j < int(i.numRestarts); j++ {
offset := i.getRestart(j)
n.Childf("%05d [restart %d]", uint64(i.restarts+4*int32(j)), offset)
n.Childf("%05d [restart %d]", uint64(i.restarts+4*offsetInBlock(j)), offset)
}
}

Expand All @@ -1669,9 +1689,9 @@ func (i *Iter) Describe(tp treeprinter.Node, fmtKV DescribeKV) {
// stored together with the key.
type RawIter struct {
cmp base.Compare
offset int32
nextOffset int32
restarts int32
offset offsetInBlock
nextOffset offsetInBlock
restarts offsetInBlock
numRestarts int32
ptr unsafe.Pointer
data []byte
Expand All @@ -1694,7 +1714,7 @@ func (i *RawIter) Init(cmp base.Compare, blk []byte) error {
return base.CorruptionErrorf("pebble/table: invalid table (block has no restart points)")
}
i.cmp = cmp
i.restarts = int32(len(blk)) - 4*(1+numRestarts)
i.restarts = offsetInBlock(len(blk)) - 4*(1+offsetInBlock(numRestarts))
i.numRestarts = numRestarts
i.ptr = unsafe.Pointer(&blk[0])
i.data = blk
Expand All @@ -1717,7 +1737,7 @@ func (i *RawIter) readEntry() {
i.key = i.key[:len(i.key):len(i.key)]
ptr = unsafe.Pointer(uintptr(ptr) + uintptr(unshared))
i.val = getBytes(ptr, int(value))
i.nextOffset = int32(uintptr(ptr)-uintptr(i.ptr)) + int32(value)
i.nextOffset = offsetInBlock(uintptr(ptr)-uintptr(i.ptr)) + offsetInBlock(value)
}

func (i *RawIter) loadEntry() {
Expand All @@ -1731,16 +1751,16 @@ func (i *RawIter) clearCache() {
}

func (i *RawIter) cacheEntry() {
var valStart int32
valSize := int32(len(i.val))
var valStart offsetInBlock
valSize := uint32(len(i.val))
if valSize > 0 {
valStart = int32(uintptr(unsafe.Pointer(&i.val[0])) - uintptr(i.ptr))
valStart = offsetInBlock(uintptr(unsafe.Pointer(&i.val[0])) - uintptr(i.ptr))
}

i.cached = append(i.cached, blockEntry{
offset: i.offset,
keyStart: int32(len(i.cachedBuf)),
keyEnd: int32(len(i.cachedBuf) + len(i.key)),
keyStart: offsetInBlock(len(i.cachedBuf)),
keyEnd: offsetInBlock(len(i.cachedBuf) + len(i.key)),
valStart: valStart,
valSize: valSize,
})
Expand Down Expand Up @@ -1770,7 +1790,7 @@ func (i *RawIter) SeekGE(key []byte) bool {
// 0, then all keys in this block are larger than the key sought, and offset
// remains at zero.
if index > 0 {
i.offset = int32(binary.LittleEndian.Uint32(i.data[int(i.restarts)+4*(index-1):]))
i.offset = offsetInBlock(binary.LittleEndian.Uint32(i.data[int(i.restarts)+4*(index-1):]))
}
i.loadEntry()

Expand All @@ -1794,7 +1814,7 @@ func (i *RawIter) First() bool {
// Last implements internalIterator.Last, as documented in the pebble package.
func (i *RawIter) Last() bool {
// Seek forward from the last restart point.
i.offset = int32(binary.LittleEndian.Uint32(i.data[i.restarts+4*(i.numRestarts-1):]))
i.offset = offsetInBlock(binary.LittleEndian.Uint32(i.data[i.restarts+4*offsetInBlock(i.numRestarts-1):]))

i.readEntry()
i.clearCache()
Expand Down Expand Up @@ -1842,12 +1862,12 @@ func (i *RawIter) Prev() bool {

targetOffset := i.offset
index := sort.Search(int(i.numRestarts), func(j int) bool {
offset := int32(binary.LittleEndian.Uint32(i.data[int(i.restarts)+4*j:]))
offset := offsetInBlock(binary.LittleEndian.Uint32(i.data[int(i.restarts)+4*j:]))
return offset >= targetOffset
})
i.offset = 0
if index > 0 {
i.offset = int32(binary.LittleEndian.Uint32(i.data[int(i.restarts)+4*(index-1):]))
i.offset = offsetInBlock(binary.LittleEndian.Uint32(i.data[int(i.restarts)+4*(index-1):]))
}

i.readEntry()
Expand Down Expand Up @@ -1899,8 +1919,8 @@ func (i *RawIter) DebugTree(tp treeprinter.Node) {
tp.Childf("%T(%p)", i, i)
}

func (i *RawIter) getRestart(idx int) int32 {
return int32(binary.LittleEndian.Uint32(i.data[i.restarts+4*int32(idx):]))
func (i *RawIter) getRestart(idx int) offsetInBlock {
return offsetInBlock(binary.LittleEndian.Uint32(i.data[i.restarts+4*offsetInBlock(idx):]))
}

func (i *RawIter) isRestartPoint() bool {
Expand Down Expand Up @@ -1935,7 +1955,7 @@ func (i *RawIter) Describe(tp treeprinter.Node, fmtKV DescribeKV) {
// Format the restart points.
for j := 0; j < int(i.numRestarts); j++ {
offset := i.getRestart(j)
n.Childf("%05d [restart %d]", uint64(i.restarts+4*int32(j)), offset)
n.Childf("%05d [restart %d]", uint64(i.restarts+4*offsetInBlock(j)), offset)
}
}

Expand Down
Loading

0 comments on commit a59316b

Please sign in to comment.