diff --git a/trie/ctrie/ctrie.go b/trie/ctrie/ctrie.go index f2f5764..119dffc 100644 --- a/trie/ctrie/ctrie.go +++ b/trie/ctrie/ctrie.go @@ -63,8 +63,10 @@ type Ctrie struct { } // generation demarcates Ctrie snapshots. We use a heap-allocated reference -// instead of an integer to avoid integer overflows. -type generation struct{} +// instead of an integer to avoid integer overflows. Struct must have a field +// on it since two distinct zero-size variables may have the same address in +// memory. +type generation struct{ _ int } // iNode is an indirection node. I-nodes remain present in the Ctrie even as // nodes above and below change. Thread-safety is achieved in part by @@ -320,7 +322,7 @@ func (c *Ctrie) Snapshot() *Ctrie { root := c.readRoot() main := gcasRead(root, c) if c.rdcssRoot(root, main, root.copyToGen(&generation{}, c)) { - return newCtrie(root.copyToGen(&generation{}, c), c.hashFactory, c.readOnly) + return newCtrie(c.readRoot().copyToGen(&generation{}, c), c.hashFactory, c.readOnly) } } } @@ -335,7 +337,7 @@ func (c *Ctrie) ReadOnlySnapshot() *Ctrie { root := c.readRoot() main := gcasRead(root, c) if c.rdcssRoot(root, main, root.copyToGen(&generation{}, c)) { - return newCtrie(root, c.hashFactory, true) + return newCtrie(c.readRoot(), c.hashFactory, true) } } } @@ -363,7 +365,7 @@ func (c *Ctrie) Iterator(cancel <-chan struct{}) <-chan *Entry { ch := make(chan *Entry) snapshot := c.ReadOnlySnapshot() go func() { - traverse(snapshot.root, ch, cancel) + snapshot.traverse(snapshot.readRoot(), ch, cancel) close(ch) }() return ch @@ -385,13 +387,14 @@ func (c *Ctrie) Size() uint { var errCanceled = errors.New("canceled") -func traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error { +func (c *Ctrie) traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error { + main := gcasRead(i, c) switch { - case i.main.cNode != nil: - for _, br := range i.main.cNode.array { + case main.cNode != nil: + for _, br := range main.cNode.array { switch b := br.(type) { case *iNode: - if err := traverse(b, ch, cancel); err != nil { + if err := c.traverse(b, ch, cancel); err != nil { return err } case *sNode: @@ -402,8 +405,8 @@ func traverse(i *iNode, ch chan<- *Entry, cancel <-chan struct{}) error { } } } - case i.main.lNode != nil: - for _, e := range i.main.lNode.Map(func(sn interface{}) interface{} { + case main.lNode != nil: + for _, e := range main.lNode.Map(func(sn interface{}) interface{} { return sn.(*sNode).Entry }) { select { @@ -485,7 +488,7 @@ func (c *Ctrie) iinsert(i *iNode, entry *Entry, lev uint, parent *iNode, startGe // If the branch is an I-node, then iinsert is called recursively. in := branch.(*iNode) if startGen == in.gen { - return c.iinsert(in, entry, lev+w, i, i.gen) + return c.iinsert(in, entry, lev+w, i, startGen) } if gcas(i, main, &mainNode{cNode: cn.renewed(startGen, c)}, c) { return c.iinsert(i, entry, lev, parent, startGen) @@ -810,8 +813,8 @@ func gcasComplete(i *iNode, m *mainNode, ctrie *Ctrie) *mainNode { // Signals GCAS failure. Swap old value back into I-node. fn := prev.failed if atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&i.main)), - unsafe.Pointer(m), unsafe.Pointer(fn.prev)) { - return fn.prev + unsafe.Pointer(m), unsafe.Pointer(fn)) { + return fn } m = (*mainNode)(atomic.LoadPointer( (*unsafe.Pointer)(unsafe.Pointer(&i.main)))) @@ -845,7 +848,7 @@ type rdcssDescriptor struct { old *iNode expected *mainNode nv *iNode - committed bool + committed int32 } // readRoot performs a linearizable read of the Ctrie root. This operation is @@ -878,7 +881,7 @@ func (c *Ctrie) rdcssRoot(old *iNode, expected *mainNode, nv *iNode) bool { } if c.casRoot(old, desc) { c.rdcssComplete(false) - return desc.rdcss.committed + return atomic.LoadInt32(&desc.rdcss.committed) == 1 } return false } @@ -909,7 +912,7 @@ func (c *Ctrie) rdcssComplete(abort bool) *iNode { if oldeMain == exp { // Commit the RDCSS. if c.casRoot(r, nv) { - desc.committed = true + atomic.StoreInt32(&desc.committed, 1) return nv } continue diff --git a/trie/ctrie/ctrie_test.go b/trie/ctrie/ctrie_test.go index 52d4462..a3fa561 100644 --- a/trie/ctrie/ctrie_test.go +++ b/trie/ctrie/ctrie_test.go @@ -169,6 +169,47 @@ func TestConcurrency(t *testing.T) { wg.Wait() } +func TestConcurrency2(t *testing.T) { + assert := assert.New(t) + ctrie := New(nil) + var wg sync.WaitGroup + wg.Add(4) + + go func() { + for i := 0; i < 10000; i++ { + ctrie.Insert([]byte(strconv.Itoa(i)), i) + } + wg.Done() + }() + + go func() { + for i := 0; i < 10000; i++ { + val, ok := ctrie.Lookup([]byte(strconv.Itoa(i))) + if ok { + assert.Equal(i, val) + } + } + wg.Done() + }() + + go func() { + for i := 0; i < 10000; i++ { + ctrie.Snapshot() + } + wg.Done() + }() + + go func() { + for i := 0; i < 10000; i++ { + ctrie.ReadOnlySnapshot() + } + wg.Done() + }() + + wg.Wait() + assert.Equal(uint(10000), ctrie.Size()) +} + func TestSnapshot(t *testing.T) { assert := assert.New(t) ctrie := New(nil)