Skip to content

Commit

Permalink
Add Cursor.Err() to check for cursor errors and fix data races (#27)
Browse files Browse the repository at this point in the history
* add Cursor.Err()

* fix data race by using atomics for record.Next accesses
  • Loading branch information
Preetam authored Apr 9, 2017
1 parent 8f9b8f1 commit 138f7c6
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 14 deletions.
52 changes: 46 additions & 6 deletions cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ type Cursor struct {
current *record
first bool
snapshot int64
err error
}

// NewCursor returns a new cursor with a snapshot view of the
Expand Down Expand Up @@ -43,10 +44,11 @@ func (c *Collection) NewCursor() (*Cursor, error) {
cur.current.lock.RLock()
for (cur.current.Deleted != 0 && cur.current.Deleted <= cur.snapshot) ||
(cur.current.Offset >= cur.snapshot) {
rec, err = cur.collection.readRecord(cur.current.Next)
rec, err = cur.collection.readRecord(atomic.LoadInt64(&cur.current.Next))
if err != nil {
cur.current.lock.RUnlock()
cur.current = nil
cur.err = err
return cur, nil
}
cur.current.lock.RUnlock()
Expand Down Expand Up @@ -83,9 +85,12 @@ func (c *Cursor) Next() bool {
}

c.current.lock.RLock()
rec, err := c.collection.readRecord(c.current.Next)
rec, err := c.collection.readRecord(atomic.LoadInt64(&c.current.Next))
if err != nil {
c.current.lock.RUnlock()
if atomic.LoadInt64(&c.current.Next) != 0 {
c.err = err
}
c.current = nil
return false
}
Expand All @@ -95,9 +100,12 @@ func (c *Cursor) Next() bool {
c.current.lock.RLock()
for (c.current.Deleted != 0 && c.current.Deleted <= c.snapshot) ||
(c.current.Offset >= c.snapshot) {
rec, err = c.collection.readRecord(c.current.Next)
rec, err = c.collection.readRecord(atomic.LoadInt64(&c.current.Next))
if err != nil {
c.current.lock.RUnlock()
if atomic.LoadInt64(&c.current.Next) != 0 {
c.err = err
}
c.current = nil
return false
}
Expand Down Expand Up @@ -143,12 +151,18 @@ func (c *Cursor) Seek(key string) {
rec, err = c.collection.readRecord(c.collection.Head)
c.collection.metaLock.RUnlock()
if err != nil {
if atomic.LoadInt64(&c.current.Next) != 0 {
c.err = err
}
c.current = nil
return
}
} else {
rec, err = c.collection.readRecord(offset)
if err != nil {
if atomic.LoadInt64(&c.current.Next) != 0 {
c.err = err
}
c.current = nil
return
}
Expand All @@ -160,7 +174,14 @@ func (c *Cursor) Seek(key string) {
if rec.Key >= key {
if (rec.Deleted > 0 && rec.Deleted <= c.snapshot) || (rec.Offset >= c.snapshot) {
oldRec := rec
rec = c.collection.nextRecord(rec)
rec, err = c.collection.nextRecord(rec)
if err != nil {
if atomic.LoadInt64(&c.current.Next) != 0 {
c.err = err
}
c.current = nil
return
}
oldRec.lock.RUnlock()
c.current = rec
continue
Expand All @@ -170,15 +191,34 @@ func (c *Cursor) Seek(key string) {
}
if (rec.Deleted > 0 && rec.Deleted <= c.snapshot) || (rec.Offset >= c.snapshot) {
oldRec := rec
rec = c.collection.nextRecord(rec)
rec, err = c.collection.nextRecord(rec)
if err != nil {
if atomic.LoadInt64(&c.current.Next) != 0 {
c.err = err
}
c.current = nil
return
}
oldRec.lock.RUnlock()
continue
}
if rec.Key < key {
c.current = rec
}
oldRec := rec
rec = c.collection.nextRecord(rec)
rec, err = c.collection.nextRecord(rec)
if err != nil {
if atomic.LoadInt64(&c.current.Next) != 0 {
c.err = err
}
c.current = nil
return
}
oldRec.lock.RUnlock()
}
}

// Err returns the error encountered during iteration, if any.
func (c *Cursor) Err() error {
return c.err
}
23 changes: 15 additions & 8 deletions lm2.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,15 +315,19 @@ func (c *Collection) readRecord(offset int64) (*record, error) {
return rec, nil
}

func (c *Collection) nextRecord(rec *record) *record {
func (c *Collection) nextRecord(rec *record) (*record, error) {
if rec == nil {
return nil
return nil, errors.New("lm2: invalid record")
}
nextRec, err := c.readRecord(rec.Next)
if atomic.LoadInt64(&rec.Next) == 0 {
// There's no next record.
return nil, nil
}
nextRec, err := c.readRecord(atomic.LoadInt64(&rec.Next))
if err != nil {
return nil
return nil, err
}
return nextRec
return nextRec, nil
}

func writeRecord(rec *record, currentOffset int64, buf *bytes.Buffer) error {
Expand Down Expand Up @@ -409,7 +413,10 @@ func (c *Collection) findLastLessThanOrEqual(key string, startingOffset int64) (
}
offset = rec.Offset
oldRec := rec
rec = c.nextRecord(oldRec)
rec, err = c.nextRecord(oldRec)
if err != nil {
return 0, err
}
oldRec.lock.RUnlock()
}

Expand Down Expand Up @@ -570,7 +577,7 @@ func (c *Collection) Update(wb *WriteBatch) (int64, error) {
}
rec := &record{
recordHeader: recordHeader{
Next: prevRec.Next,
Next: atomic.LoadInt64(&prevRec.Next),
},
Key: key,
Value: value,
Expand All @@ -583,7 +590,7 @@ func (c *Collection) Update(wb *WriteBatch) (int64, error) {
}
newlyInserted[key] = newRecordOffset
c.cache.forcePush(rec)
prevRec.Next = newRecordOffset
atomic.StoreInt64(&prevRec.Next, newRecordOffset)
walEntry.Push(newWALRecord(prevRec.Offset, prevRec.recordHeader.bytes()))
if prevRec.Key == key {
overwrittenRecords = append(overwrittenRecords, prevRec.Offset)
Expand Down
37 changes: 37 additions & 0 deletions lm2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ func verifyOrder(t *testing.T, c *Collection) int {
t.Errorf("key %v greater than previous key %v", cur.Key(), prev)
}
}
if err = cur.Err(); err != nil {
t.Fatal(err)
}
return count
}

Expand Down Expand Up @@ -63,6 +66,9 @@ func TestCopy(t *testing.T) {
for cur.Next() {
work <- [2]string{cur.Key(), cur.Value()}
}
if err = cur.Err(); err != nil {
t.Fatal(err)
}
close(work)
}()

Expand Down Expand Up @@ -152,6 +158,9 @@ func TestWriteBatch(t *testing.T) {
}
i++
}
if err = cur.Err(); err != nil {
t.Fatal(err)
}

expected = [][2]string{
{"key1", "5"},
Expand Down Expand Up @@ -193,6 +202,9 @@ func TestWriteBatch(t *testing.T) {
}
i++
}
if err = cur.Err(); err != nil {
t.Fatal(err)
}

// Check if cursor can be reset
cur.Seek("")
Expand All @@ -211,6 +223,9 @@ func TestWriteBatch(t *testing.T) {
}
i++
}
if err = cur.Err(); err != nil {
t.Fatal(err)
}
}

func TestWriteBatch1(t *testing.T) {
Expand Down Expand Up @@ -352,6 +367,9 @@ func TestWriteBatch2(t *testing.T) {
}
i++
}
if err = cur.Err(); err != nil {
t.Fatal(err)
}
t.Logf("%+v", c.Stats())
}

Expand Down Expand Up @@ -422,6 +440,9 @@ func TestWriteCloseOpen(t *testing.T) {
}
i++
}
if err = cur.Err(); err != nil {
t.Fatal(err)
}
t.Logf("%+v", c.Stats())

err = c.Destroy()
Expand Down Expand Up @@ -497,6 +518,10 @@ func TestSeekToFirstKey(t *testing.T) {
if cur.Key() != "a" {
t.Fatalf("expected cursor key to be 'a', got %v", cur.Key())
}

if err = cur.Err(); err != nil {
t.Fatal(err)
}
}

func TestOverwriteFirstKey(t *testing.T) {
Expand Down Expand Up @@ -548,6 +573,10 @@ func TestOverwriteFirstKey(t *testing.T) {
if cur.Key() != "b" {
t.Fatalf("expected cursor key to be 'b', got %v", cur.Key())
}

if err = cur.Err(); err != nil {
t.Fatal(err)
}
}

func TestOverwriteFirstKeyOnly(t *testing.T) {
Expand Down Expand Up @@ -594,6 +623,10 @@ func TestOverwriteFirstKeyOnly(t *testing.T) {
t.Error("expected Next() to return false")
t.Log(cur.Key(), "=>", cur.Value())
}

if err = cur.Err(); err != nil {
t.Fatal(err)
}
}

func TestDeleteInFirstUpdate(t *testing.T) {
Expand Down Expand Up @@ -655,4 +688,8 @@ func TestSeekOverwrittenKey(t *testing.T) {
if cur.Key() != "committed" {
t.Fatalf("expected cur.Key() to be %s, got %s", "committed", cur.Key())
}

if err = cur.Err(); err != nil {
t.Fatal(err)
}
}

0 comments on commit 138f7c6

Please sign in to comment.