diff --git a/document.go b/document.go
index 4b3f23f..2ce4002 100644
--- a/document.go
+++ b/document.go
@@ -26,11 +26,11 @@ func (d *Document) GetElementsByClassName(name string) spec.ElementCollection {
}
func (d *Document) QuerySelector(query string) spec.Element {
- return querySelector(d.node, query)
+ return querySelector(d.node, query, false)
}
func (d *Document) QuerySelectorAll(query string) spec.NodeList[spec.Element] {
- return querySelectorAll(d.node, query)
+ return querySelectorAll(d.node, query, false)
}
func (d *Document) Contains(other spec.Node) bool { return contains(d.node, other) }
diff --git a/element.go b/element.go
index abea6e2..0185eee 100644
--- a/element.go
+++ b/element.go
@@ -52,9 +52,11 @@ func (e *Element) GetElementsByClassName(name string) spec.ElementCollection {
return getElementsByClassName(e.node, name)
}
-func (e *Element) QuerySelector(query string) spec.Element { return querySelector(e.node, query) }
+func (e *Element) QuerySelector(query string) spec.Element {
+ return querySelector(e.node, query, false)
+}
func (e *Element) QuerySelectorAll(query string) spec.NodeList[spec.Element] {
- return querySelectorAll(e.node, query)
+ return querySelectorAll(e.node, query, false)
}
func (e *Element) Closest(selector string) spec.Element { return closest(e.node, selector) }
func (e *Element) Matches(selector string) bool { return matches(e.node, selector) }
diff --git a/element_test.go b/element_test.go
index 29af72f..415d21a 100644
--- a/element_test.go
+++ b/element_test.go
@@ -1,12 +1,14 @@
package dom
import (
+ "strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/html"
+ "golang.org/x/net/html/atom"
"github.com/crhntr/dom/spec"
)
@@ -243,3 +245,38 @@ func TestElement_IsSameNode(t *testing.T) {
assert.True(t, e1.IsSameNode(e2))
})
}
+
+func TestElement_QuerySelector(t *testing.T) {
+ parseFirstElement := func(t *testing.T, fragment string) *Element {
+ nodes, err := html.ParseFragment(strings.NewReader(strings.TrimSpace(fragment)), &html.Node{
+ Type: html.ElementNode,
+ Data: atom.Div.String(),
+ DataAtom: atom.Div,
+ })
+ require.NoError(t, err)
+ node := NewNode(nodes[0])
+ require.NotNil(t, node)
+ el, ok := node.(*Element)
+ require.True(t, ok)
+ return el
+ }
+ t.Run("tree of nested results", func(t *testing.T) {
+ element := parseFirstElement(t,
+ /* language=html */ `
`)
+ results := element.QuerySelectorAll("div")
+ require.NotNil(t, results)
+ assert.Equal(t, results.Length(), 3)
+
+ for i := 0; i < results.Length(); i++ {
+ result := results.Item(i)
+ require.NotNil(t, result)
+ assert.Equal(t, "DIV", result.TagName())
+ assert.Equal(t, "n"+strconv.Itoa(i+1), result.ID())
+ }
+ })
+}
diff --git a/fragment.go b/fragment.go
index b1ea354..ebcbcaa 100644
--- a/fragment.go
+++ b/fragment.go
@@ -85,14 +85,14 @@ func (d *DocumentFragment) ChildElementCount() int {
return count
}
-func (d *DocumentFragment) Append(nodes ...spec.ChildNode) {
+func (d *DocumentFragment) Append(nodes ...spec.Node) {
d.nodes = slices.Grow(d.nodes, len(nodes))
for _, node := range nodes {
d.nodes = append(d.nodes, domNodeToHTMLNode(node))
}
}
-func (d *DocumentFragment) Prepend(nodes ...spec.ChildNode) {
+func (d *DocumentFragment) Prepend(nodes ...spec.Node) {
children := make([]*html.Node, 0, len(d.nodes)+len(nodes))
for _, node := range nodes {
children = append(children, domNodeToHTMLNode(node))
@@ -100,7 +100,7 @@ func (d *DocumentFragment) Prepend(nodes ...spec.ChildNode) {
d.nodes = append(children, d.nodes...)
}
-func (d *DocumentFragment) ReplaceChildren(nodes ...spec.ChildNode) {
+func (d *DocumentFragment) ReplaceChildren(nodes ...spec.Node) {
list := make([]*html.Node, 0, len(nodes))
for _, node := range nodes {
list = append(list, domNodeToHTMLNode(node))
@@ -110,7 +110,7 @@ func (d *DocumentFragment) ReplaceChildren(nodes ...spec.ChildNode) {
func (d *DocumentFragment) QuerySelector(query string) spec.Element {
for _, n := range d.nodes {
- el := querySelector(n, query)
+ el := querySelector(n, query, true)
if el != nil {
return el
}
@@ -121,7 +121,7 @@ func (d *DocumentFragment) QuerySelector(query string) spec.Element {
func (d *DocumentFragment) QuerySelectorAll(query string) spec.NodeList[spec.Element] {
var list nodeListHTMLElements
for _, n := range d.nodes {
- list = append(list, querySelectorAll(n, query)...)
+ list = append(list, querySelectorAll(n, query, true)...)
}
return slices.Clip(list)
}
diff --git a/fragment_test.go b/fragment_test.go
index 756816b..29cfcc7 100644
--- a/fragment_test.go
+++ b/fragment_test.go
@@ -1,6 +1,7 @@
package dom_test
import (
+ "strconv"
"strings"
"testing"
@@ -313,6 +314,37 @@ func TestDocumentFragment_QuerySelector(t *testing.T) {
fragment := parseDocumentFragment(t, ``)
require.Nil(t, fragment.QuerySelector("#not-found"))
})
+ t.Run("direct descendant", func(t *testing.T) {
+ fragment := parseDocumentFragment(t, `Peach
`)
+ result := fragment.QuerySelector("p")
+ require.NotNil(t, result)
+ assert.Equal(t, "Peach", result.TextContent())
+ })
+ t.Run("tree of nested results", func(t *testing.T) {
+ fragment := parseDocumentFragment(t,
+ /* language=html */ `
+`)
+ results := fragment.QuerySelectorAll("div")
+ require.NotNil(t, results)
+ assert.Equal(t, results.Length(), 8)
+
+ for i := 0; i < results.Length(); i++ {
+ result := results.Item(i)
+ require.NotNil(t, result)
+ assert.Equal(t, "DIV", result.TagName())
+ assert.Equal(t, "n"+strconv.Itoa(i), result.ID())
+ }
+ })
}
func TestDocumentFragment_QuerySelectorAll(t *testing.T) {
diff --git a/node.go b/node.go
index 0fa5643..9391d31 100644
--- a/node.go
+++ b/node.go
@@ -3,6 +3,7 @@ package dom
import (
"bytes"
"io"
+ "slices"
"strings"
"github.com/andybalholm/cascadia"
@@ -421,16 +422,24 @@ func hasClasses(elementClassesStr, classesStr string) bool {
return len(set) == 0
}
-func querySelector(node *html.Node, query string) spec.Element {
- result := cascadia.Query(node, cascadia.MustCompile(query))
- if result == nil {
- return nil
+func querySelector(node *html.Node, query string, includeParent bool) spec.Element {
+ q := cascadia.MustCompile(query)
+ if includeParent && q.Match(node) {
+ return &Element{node: node}
+ }
+ if result := cascadia.Query(node, q); result != nil {
+ return &Element{node: result}
}
- return &Element{node: result}
+ return nil
}
-func querySelectorAll(node *html.Node, query string) nodeListHTMLElements {
- return cascadia.QueryAll(node, cascadia.MustCompile(query))
+func querySelectorAll(node *html.Node, query string, includeParent bool) nodeListHTMLElements {
+ q := cascadia.MustCompile(query)
+ results := cascadia.QueryAll(node, cascadia.MustCompile(query))
+ if includeParent && node.Type == html.ElementNode && q.Match(node) {
+ results = slices.Insert(results, 0, node)
+ }
+ return results
}
var _ spec.NodeList[spec.Element] = nodeListHTMLElements(nil)