diff --git a/query.go b/query.go index 9735ef0..333fe09 100644 --- a/query.go +++ b/query.go @@ -1,6 +1,9 @@ package xpath import ( + "bytes" + "fmt" + "hash/fnv" "reflect" ) @@ -728,16 +731,17 @@ type unionQuery struct { func (u *unionQuery) Select(t iterator) NodeNavigator { if u.iterator == nil { - var list []NodeNavigator - var i int + var m = make(map[uint64]NodeNavigator) root := t.Current().Copy() for { node := u.Left.Select(t) if node == nil { break } - node = node.Copy() - list = append(list, node) + code := getHashCode(node.Copy()) + if _, ok := m[code]; !ok { + m[code] = node.Copy() + } } t.Current().MoveTo(root) for { @@ -745,18 +749,18 @@ func (u *unionQuery) Select(t iterator) NodeNavigator { if node == nil { break } - node = node.Copy() - var exists bool - for _, x := range list { - if reflect.DeepEqual(x, node) { - exists = true - break - } - } - if !exists { - list = append(list, node) + code := getHashCode(node.Copy()) + if _, ok := m[code]; !ok { + m[code] = node.Copy() } } + list := make([]NodeNavigator, len(m)) + var i int + for _, v := range m { + list[i] = v + i++ + } + i = 0 u.iterator = func() NodeNavigator { if i >= len(list) { return nil @@ -780,6 +784,35 @@ func (u *unionQuery) Clone() query { return &unionQuery{Left: u.Left.Clone(), Right: u.Right.Clone()} } +func getHashCode(n NodeNavigator) uint64 { + var sb bytes.Buffer + switch n.NodeType() { + case AttributeNode, TextNode, CommentNode: + sb.WriteString(fmt.Sprintf("%s=%s", n.LocalName(), n.Value())) + if n.MoveToParent() { + sb.WriteString(n.LocalName()) + } + case ElementNode: + sb.WriteString(n.Prefix() + n.LocalName()) + d := 1 + for n.MoveToPrevious() { + d++ + } + sb.WriteString(fmt.Sprintf("-%d", d)) + + for n.MoveToParent() { + d = 1 + for n.MoveToPrevious() { + d++ + } + sb.WriteString(fmt.Sprintf("-%d", d)) + } + } + h := fnv.New64a() + h.Write([]byte(sb.String())) + return h.Sum64() +} + func getNodePosition(q query) int { type Position interface { position() int