Skip to content

Commit

Permalink
add option to prevent overwriting keys; #41 (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
Preetam committed Jan 29, 2018
1 parent 9ff6c4c commit a033aaf
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 11 deletions.
23 changes: 21 additions & 2 deletions lm2.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,30 @@ var (
// is invalid. The collection should be closed and reopened.
ErrInternal = errors.New("lm2: internal error")

ErrRolledBack = errors.New("lm2: rolled back")

fileVersion = [8]byte{'l', 'm', '2', '_', '0', '0', '1', '\n'}
)

// RollbackError is the error type returned after rollbacks.
type RollbackError struct {
DuplicateKey bool
ConflictedKey string
Err error
}

func (e RollbackError) Error() string {
if e.DuplicateKey {
return fmt.Sprintf("lm2: rolled back due to duplicate key (conflicted key: `%s`)",
e.ConflictedKey)
}
return fmt.Sprintf("lm2: rolled back (%s)", e.Err.Error())
}

// IsRollbackError returns true if err is a RollbackError.
func IsRollbackError(err error) bool {
_, ok := err.(RollbackError)
return ok
}

// Collection represents an ordered linked list map.
type Collection struct {
fileHeader
Expand Down
49 changes: 46 additions & 3 deletions lm2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestCopy(t *testing.T) {
wb.Set(key, val)
RETRY:
if _, err := c.Update(wb); err != nil {
if err == ErrRolledBack {
if IsRollbackError(err) {
t.Log("rollback")
goto RETRY
}
Expand Down Expand Up @@ -112,7 +112,7 @@ func TestCopy(t *testing.T) {
RETRY2:
_, err := c2.Update(wb)
if err != nil {
if err == ErrRolledBack {
if IsRollbackError(err) {
t.Log("rollback")
goto RETRY2
}
Expand All @@ -127,7 +127,7 @@ func TestCopy(t *testing.T) {
RETRY3:
_, err := c2.Update(wb)
if err != nil {
if err == ErrRolledBack {
if IsRollbackError(err) {
t.Log("rollback")
goto RETRY3
}
Expand Down Expand Up @@ -1133,3 +1133,46 @@ func TestOK(t *testing.T) {
t.Error("expected OK() to return false")
}
}

func TestConflictRollback(t *testing.T) {
c, err := NewCollection("/tmp/test_conflictrollback.lm2", 100)
if err != nil {
t.Fatal(err)
}
defer c.Destroy()

wb := NewWriteBatch()
wb.Set("key1", "1")
t.Log("Set", "key1", "1")
_, err = c.Update(wb)
if err != nil {
t.Fatal(err)
}

wb = NewWriteBatch()
wb.AllowOverwrite(false)
wb.Set("key1", "2")
t.Log("Set", "key1", "2")
_, err = c.Update(wb)
if err == nil {
t.Fatal("expected a rollback")
}

if !IsRollbackError(err) {
t.Fatal("expected a rollback error")
}

rollbackErr := err.(RollbackError)
if !rollbackErr.DuplicateKey {
t.Error("expected DuplicateKey to be true")
}

if rollbackErr.ConflictedKey != "key1" {
t.Errorf("expected ConflictedKey to be `%s`, got `%s`",
"key1", rollbackErr.ConflictedKey)
}

if !c.OK() {
t.Error("expected OK() to return true")
}
}
2 changes: 1 addition & 1 deletion tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func TestTransactionalSquares(t *testing.T) {
return nil
})
if err != nil {
if err != ErrRolledBack && err.Error() != "lm2: partial read (random failure)" {
if !IsRollbackError(err) && err.Error() != "lm2: partial read (random failure)" {
t.Fatal(err)
} else {
atomic.AddUint32(&expectedWriteFailures, 1)
Expand Down
15 changes: 14 additions & 1 deletion update.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func (c *Collection) findLastLessThanOrEqual(key string, startingOffset int64, l

// Update atomically and durably applies a WriteBatch (a set of updates) to the collection.
// It returns the new version (on success) and an error.
// The error may be a RollbackError; use IsRollbackError to check.
func (c *Collection) Update(wb *WriteBatch) (int64, error) {
c.writeLock.Lock()
defer c.writeLock.Unlock()
Expand Down Expand Up @@ -219,6 +220,13 @@ KEYS_LOOP:
walEntry.Push(newWALRecord(prevRec.Offset, prevRec.recordHeader.bytes()))

if prevRec.Key == key && prevRec.Deleted == 0 {
if !wb.allowOverwrite {
rollbackErr = RollbackError{
DuplicateKey: true,
ConflictedKey: key,
}
goto ROLLBACK
}
overwrittenRecords = append(overwrittenRecords, prevRec.Offset)
}

Expand Down Expand Up @@ -340,7 +348,12 @@ ROLLBACK:
c.cache.cache = map[int64]*record{}
c.cache.lock.Unlock()

return 0, ErrRolledBack
if IsRollbackError(rollbackErr) {
return 0, rollbackErr
}
return 0, RollbackError{
Err: rollbackErr,
}
}

// Update + fsync data file header.
Expand Down
17 changes: 13 additions & 4 deletions write_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@ package lm2

// WriteBatch represents a set of modifications.
type WriteBatch struct {
sets map[string]string
deletes map[string]struct{}
sets map[string]string
deletes map[string]struct{}
allowOverwrite bool
}

// NewWriteBatch returns a new WriteBatch.
func NewWriteBatch() *WriteBatch {
return &WriteBatch{
sets: map[string]string{},
deletes: map[string]struct{}{},
sets: map[string]string{},
deletes: map[string]struct{}{},
allowOverwrite: true,
}
}

Expand All @@ -26,6 +28,13 @@ func (wb *WriteBatch) Delete(key string) {
wb.deletes[key] = struct{}{}
}

// AllowOverwrite determines whether keys will be overwritten.
// If allow is false and an existing key is being
// set, updates will be rolled back.
func (wb *WriteBatch) AllowOverwrite(allow bool) {
wb.allowOverwrite = allow
}

func (wb *WriteBatch) cleanup() {
for key := range wb.deletes {
delete(wb.sets, key)
Expand Down

0 comments on commit a033aaf

Please sign in to comment.