Skip to content

Commit

Permalink
Merge pull request #123 from tylertreat/fixes
Browse files Browse the repository at this point in the history
Fix Ctrie snapshotting
  • Loading branch information
tylertreat-wf committed Jan 4, 2016
2 parents 074c32b + 38e136e commit 371ee25
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 17 deletions.
37 changes: 20 additions & 17 deletions trie/ctrie/ctrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ type Ctrie struct {
}

// generation demarcates Ctrie snapshots. We use a heap-allocated reference
// instead of an integer to avoid integer overflows.
type generation struct{}
// instead of an integer to avoid integer overflows. Struct must have a field
// on it since two distinct zero-size variables may have the same address in
// memory.
type generation struct{ _ int }

// iNode is an indirection node. I-nodes remain present in the Ctrie even as
// nodes above and below change. Thread-safety is achieved in part by
Expand Down Expand Up @@ -320,7 +322,7 @@ func (c *Ctrie) Snapshot() *Ctrie {
root := c.readRoot()
main := gcasRead(root, c)
if c.rdcssRoot(root, main, root.copyToGen(&generation{}, c)) {
return newCtrie(root.copyToGen(&generation{}, c), c.hashFactory, c.readOnly)
return newCtrie(c.readRoot().copyToGen(&generation{}, c), c.hashFactory, c.readOnly)
}
}
}
Expand All @@ -335,7 +337,7 @@ func (c *Ctrie) ReadOnlySnapshot() *Ctrie {
root := c.readRoot()
main := gcasRead(root, c)
if c.rdcssRoot(root, main, root.copyToGen(&generation{}, c)) {
return newCtrie(root, c.hashFactory, true)
return newCtrie(c.readRoot(), c.hashFactory, true)
}
}
}
Expand Down Expand Up @@ -363,7 +365,7 @@ func (c *Ctrie) Iterator(cancel <-chan struct{}) <-chan *Entry {
ch := make(chan *Entry)
snapshot := c.ReadOnlySnapshot()
go func() {
traverse(snapshot.root, ch, cancel)
snapshot.traverse(snapshot.readRoot(), ch, cancel)
close(ch)
}()
return ch
Expand All @@ -385,13 +387,14 @@ func (c *Ctrie) Size() uint {

var errCanceled = errors.New("canceled")

func traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error {
func (c *Ctrie) traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error {
main := gcasRead(i, c)
switch {
case i.main.cNode != nil:
for _, br := range i.main.cNode.array {
case main.cNode != nil:
for _, br := range main.cNode.array {
switch b := br.(type) {
case *iNode:
if err := traverse(b, ch, cancel); err != nil {
if err := c.traverse(b, ch, cancel); err != nil {
return err
}
case *sNode:
Expand All @@ -402,8 +405,8 @@ func traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error {
}
}
}
case i.main.lNode != nil:
for _, e := range i.main.lNode.Map(func(sn interface{}) interface{} {
case main.lNode != nil:
for _, e := range main.lNode.Map(func(sn interface{}) interface{} {
return sn.(*sNode).Entry
}) {
select {
Expand Down Expand Up @@ -485,7 +488,7 @@ func (c *Ctrie) iinsert(i *iNode, entry *Entry, lev uint, parent *iNode, startGe
// If the branch is an I-node, then iinsert is called recursively.
in := branch.(*iNode)
if startGen == in.gen {
return c.iinsert(in, entry, lev+w, i, i.gen)
return c.iinsert(in, entry, lev+w, i, startGen)
}
if gcas(i, main, &mainNode{cNode: cn.renewed(startGen, c)}, c) {
return c.iinsert(i, entry, lev, parent, startGen)
Expand Down Expand Up @@ -810,8 +813,8 @@ func gcasComplete(i *iNode, m *mainNode, ctrie *Ctrie) *mainNode {
// Signals GCAS failure. Swap old value back into I-node.
fn := prev.failed
if atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&i.main)),
unsafe.Pointer(m), unsafe.Pointer(fn.prev)) {
return fn.prev
unsafe.Pointer(m), unsafe.Pointer(fn)) {
return fn
}
m = (*mainNode)(atomic.LoadPointer(
(*unsafe.Pointer)(unsafe.Pointer(&i.main))))
Expand Down Expand Up @@ -845,7 +848,7 @@ type rdcssDescriptor struct {
old *iNode
expected *mainNode
nv *iNode
committed bool
committed int32
}

// readRoot performs a linearizable read of the Ctrie root. This operation is
Expand Down Expand Up @@ -878,7 +881,7 @@ func (c *Ctrie) rdcssRoot(old *iNode, expected *mainNode, nv *iNode) bool {
}
if c.casRoot(old, desc) {
c.rdcssComplete(false)
return desc.rdcss.committed
return atomic.LoadInt32(&desc.rdcss.committed) == 1
}
return false
}
Expand Down Expand Up @@ -909,7 +912,7 @@ func (c *Ctrie) rdcssComplete(abort bool) *iNode {
if oldeMain == exp {
// Commit the RDCSS.
if c.casRoot(r, nv) {
desc.committed = true
atomic.StoreInt32(&desc.committed, 1)
return nv
}
continue
Expand Down
41 changes: 41 additions & 0 deletions trie/ctrie/ctrie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,47 @@ func TestConcurrency(t *testing.T) {
wg.Wait()
}

func TestConcurrency2(t *testing.T) {
assert := assert.New(t)
ctrie := New(nil)
var wg sync.WaitGroup
wg.Add(4)

go func() {
for i := 0; i < 10000; i++ {
ctrie.Insert([]byte(strconv.Itoa(i)), i)
}
wg.Done()
}()

go func() {
for i := 0; i < 10000; i++ {
val, ok := ctrie.Lookup([]byte(strconv.Itoa(i)))
if ok {
assert.Equal(i, val)
}
}
wg.Done()
}()

go func() {
for i := 0; i < 10000; i++ {
ctrie.Snapshot()
}
wg.Done()
}()

go func() {
for i := 0; i < 10000; i++ {
ctrie.ReadOnlySnapshot()
}
wg.Done()
}()

wg.Wait()
assert.Equal(uint(10000), ctrie.Size())
}

func TestSnapshot(t *testing.T) {
assert := assert.New(t)
ctrie := New(nil)
Expand Down

0 comments on commit 371ee25

Please sign in to comment.