diff --git a/document.go b/document.go index 4b3f23f..2ce4002 100644 --- a/document.go +++ b/document.go @@ -26,11 +26,11 @@ func (d *Document) GetElementsByClassName(name string) spec.ElementCollection { } func (d *Document) QuerySelector(query string) spec.Element { - return querySelector(d.node, query) + return querySelector(d.node, query, false) } func (d *Document) QuerySelectorAll(query string) spec.NodeList[spec.Element] { - return querySelectorAll(d.node, query) + return querySelectorAll(d.node, query, false) } func (d *Document) Contains(other spec.Node) bool { return contains(d.node, other) } diff --git a/element.go b/element.go index abea6e2..0185eee 100644 --- a/element.go +++ b/element.go @@ -52,9 +52,11 @@ func (e *Element) GetElementsByClassName(name string) spec.ElementCollection { return getElementsByClassName(e.node, name) } -func (e *Element) QuerySelector(query string) spec.Element { return querySelector(e.node, query) } +func (e *Element) QuerySelector(query string) spec.Element { + return querySelector(e.node, query, false) +} func (e *Element) QuerySelectorAll(query string) spec.NodeList[spec.Element] { - return querySelectorAll(e.node, query) + return querySelectorAll(e.node, query, false) } func (e *Element) Closest(selector string) spec.Element { return closest(e.node, selector) } func (e *Element) Matches(selector string) bool { return matches(e.node, selector) } diff --git a/element_test.go b/element_test.go index 29af72f..415d21a 100644 --- a/element_test.go +++ b/element_test.go @@ -1,12 +1,14 @@ package dom import ( + "strconv" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/html" + "golang.org/x/net/html/atom" "github.com/crhntr/dom/spec" ) @@ -243,3 +245,38 @@ func TestElement_IsSameNode(t *testing.T) { assert.True(t, e1.IsSameNode(e2)) }) } + +func TestElement_QuerySelector(t *testing.T) { + parseFirstElement := func(t *testing.T, fragment string) *Element { + nodes, err := html.ParseFragment(strings.NewReader(strings.TrimSpace(fragment)), &html.Node{ + Type: html.ElementNode, + Data: atom.Div.String(), + DataAtom: atom.Div, + }) + require.NoError(t, err) + node := NewNode(nodes[0]) + require.NotNil(t, node) + el, ok := node.(*Element) + require.True(t, ok) + return el + } + t.Run("tree of nested results", func(t *testing.T) { + element := parseFirstElement(t, + /* language=html */ `
+
+
+
+
+
`) + results := element.QuerySelectorAll("div") + require.NotNil(t, results) + assert.Equal(t, results.Length(), 3) + + for i := 0; i < results.Length(); i++ { + result := results.Item(i) + require.NotNil(t, result) + assert.Equal(t, "DIV", result.TagName()) + assert.Equal(t, "n"+strconv.Itoa(i+1), result.ID()) + } + }) +} diff --git a/fragment.go b/fragment.go index b1ea354..ebcbcaa 100644 --- a/fragment.go +++ b/fragment.go @@ -85,14 +85,14 @@ func (d *DocumentFragment) ChildElementCount() int { return count } -func (d *DocumentFragment) Append(nodes ...spec.ChildNode) { +func (d *DocumentFragment) Append(nodes ...spec.Node) { d.nodes = slices.Grow(d.nodes, len(nodes)) for _, node := range nodes { d.nodes = append(d.nodes, domNodeToHTMLNode(node)) } } -func (d *DocumentFragment) Prepend(nodes ...spec.ChildNode) { +func (d *DocumentFragment) Prepend(nodes ...spec.Node) { children := make([]*html.Node, 0, len(d.nodes)+len(nodes)) for _, node := range nodes { children = append(children, domNodeToHTMLNode(node)) @@ -100,7 +100,7 @@ func (d *DocumentFragment) Prepend(nodes ...spec.ChildNode) { d.nodes = append(children, d.nodes...) } -func (d *DocumentFragment) ReplaceChildren(nodes ...spec.ChildNode) { +func (d *DocumentFragment) ReplaceChildren(nodes ...spec.Node) { list := make([]*html.Node, 0, len(nodes)) for _, node := range nodes { list = append(list, domNodeToHTMLNode(node)) @@ -110,7 +110,7 @@ func (d *DocumentFragment) ReplaceChildren(nodes ...spec.ChildNode) { func (d *DocumentFragment) QuerySelector(query string) spec.Element { for _, n := range d.nodes { - el := querySelector(n, query) + el := querySelector(n, query, true) if el != nil { return el } @@ -121,7 +121,7 @@ func (d *DocumentFragment) QuerySelector(query string) spec.Element { func (d *DocumentFragment) QuerySelectorAll(query string) spec.NodeList[spec.Element] { var list nodeListHTMLElements for _, n := range d.nodes { - list = append(list, querySelectorAll(n, query)...) + list = append(list, querySelectorAll(n, query, true)...) } return slices.Clip(list) } diff --git a/fragment_test.go b/fragment_test.go index 756816b..29cfcc7 100644 --- a/fragment_test.go +++ b/fragment_test.go @@ -1,6 +1,7 @@ package dom_test import ( + "strconv" "strings" "testing" @@ -313,6 +314,37 @@ func TestDocumentFragment_QuerySelector(t *testing.T) { fragment := parseDocumentFragment(t, `
`) require.Nil(t, fragment.QuerySelector("#not-found")) }) + t.Run("direct descendant", func(t *testing.T) { + fragment := parseDocumentFragment(t, `

Peach

`) + result := fragment.QuerySelector("p") + require.NotNil(t, result) + assert.Equal(t, "Peach", result.TextContent()) + }) + t.Run("tree of nested results", func(t *testing.T) { + fragment := parseDocumentFragment(t, + /* language=html */ `
+
+
+
+
+
+
+
+
+
+
+
`) + results := fragment.QuerySelectorAll("div") + require.NotNil(t, results) + assert.Equal(t, results.Length(), 8) + + for i := 0; i < results.Length(); i++ { + result := results.Item(i) + require.NotNil(t, result) + assert.Equal(t, "DIV", result.TagName()) + assert.Equal(t, "n"+strconv.Itoa(i), result.ID()) + } + }) } func TestDocumentFragment_QuerySelectorAll(t *testing.T) { diff --git a/node.go b/node.go index 0fa5643..9391d31 100644 --- a/node.go +++ b/node.go @@ -3,6 +3,7 @@ package dom import ( "bytes" "io" + "slices" "strings" "github.com/andybalholm/cascadia" @@ -421,16 +422,24 @@ func hasClasses(elementClassesStr, classesStr string) bool { return len(set) == 0 } -func querySelector(node *html.Node, query string) spec.Element { - result := cascadia.Query(node, cascadia.MustCompile(query)) - if result == nil { - return nil +func querySelector(node *html.Node, query string, includeParent bool) spec.Element { + q := cascadia.MustCompile(query) + if includeParent && q.Match(node) { + return &Element{node: node} + } + if result := cascadia.Query(node, q); result != nil { + return &Element{node: result} } - return &Element{node: result} + return nil } -func querySelectorAll(node *html.Node, query string) nodeListHTMLElements { - return cascadia.QueryAll(node, cascadia.MustCompile(query)) +func querySelectorAll(node *html.Node, query string, includeParent bool) nodeListHTMLElements { + q := cascadia.MustCompile(query) + results := cascadia.QueryAll(node, cascadia.MustCompile(query)) + if includeParent && node.Type == html.ElementNode && q.Match(node) { + results = slices.Insert(results, 0, node) + } + return results } var _ spec.NodeList[spec.Element] = nodeListHTMLElements(nil)