diff --git a/domtest/document.go b/domtest/document.go index 88e6af6..61934c2 100644 --- a/domtest/document.go +++ b/domtest/document.go @@ -58,9 +58,15 @@ func Reader(t T, r io.Reader) spec.Document { return dom.NewNode(node).(spec.Document) } -func DocumentFragment(t T, r io.Reader, parent atom.Atom) []spec.Element { +func DocumentFragmentReader(t T, r io.Reader, parent atom.Atom) spec.DocumentFragment { t.Helper() - nodes, err := html.ParseFragment(r, &html.Node{ + + body, err := io.ReadAll(r) + if err != nil { + t.Error(err) + return nil + } + nodes, err := html.ParseFragment(bytes.NewReader(body), &html.Node{ Type: html.ElementNode, Data: parent.String(), DataAtom: parent, @@ -69,9 +75,17 @@ func DocumentFragment(t T, r io.Reader, parent atom.Atom) []spec.Element { t.Error(err) return nil } - var result []spec.Element - for _, node := range nodes { - result = append(result, dom.NewNode(node).(spec.Element)) + return dom.NewDocumentFragment(nodes) +} + +func DocumentFragmentResponse(t T, res *http.Response, parent atom.Atom) spec.DocumentFragment { + t.Helper() + defer closeAndCheckError(t, res.Body) + return DocumentFragmentReader(t, res.Body, parent) +} + +func closeAndCheckError(t T, c io.Closer) { + if err := c.Close(); err != nil { + t.Error(err) } - return result } diff --git a/domtest/document_test.go b/domtest/document_test.go index 7bd2280..2f893ef 100644 --- a/domtest/document_test.go +++ b/domtest/document_test.go @@ -11,9 +11,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/net/html/atom" "github.com/crhntr/dom/domtest" "github.com/crhntr/dom/internal/fakes" + "github.com/crhntr/dom/spec" ) var ( @@ -136,6 +138,23 @@ func TestResponse(t *testing.T) { }) } +func TestDocumentFragment(t *testing.T) { + t.Run("when a valid html document is passed", func(t *testing.T) { + testingT := new(fakes.T) + res := &http.Response{ + Body: io.NopCloser(strings.NewReader(fragmentHTML)), + } + fragment := domtest.DocumentFragmentResponse(testingT, res, atom.Body) + + assert.Equal(t, testingT.ErrorCallCount(), 0, "it should not report errors") + assert.Equal(t, testingT.LogCallCount(), 0) + assert.NotZero(t, testingT.HelperCallCount()) + + require.NotNil(t, fragment) + require.Equal(t, spec.NodeTypeDocumentFragment, fragment.NodeType()) + }) +} + func TestReader(t *testing.T) { testingT := new(fakes.T) r := iotest.ErrReader(errors.New("banana"))