diff --git a/src/main/scala/Node.scala b/src/main/scala/Node.scala index 5a022e727..72bd469b8 100644 --- a/src/main/scala/Node.scala +++ b/src/main/scala/Node.scala @@ -149,12 +149,15 @@ final case class Node[A]( import Tree.given final def mainline: List[Node[A]] = + mainlineReverse.reverse + + final def mainlineReverse: List[Node[A]] = @tailrec def loop(tree: Node[A], acc: List[Node[A]]): List[Node[A]] = tree.child match case Some(child) => loop(child, tree :: acc) case None => tree :: acc - loop(this, Nil).reverse + loop(this, Nil) // take the first n nodes in the mainline // keep all variations @@ -167,8 +170,20 @@ final case class Node[A]( node.child match case None => node :: acc case Some(child) => loop(n - 1, child, node.withoutChild :: acc) - if n == 0 then this - else loop(n, this, Nil).foldLeft(none[Node[A]])((acc, node) => node.withChild(acc).some).getOrElse(this) + if n <= 0 then this + else Tree.buildWithNodeReverse(loop(n, this, Nil)).getOrElse(this) + + // take nodes while mainline nodes satisfy the predicate + // keep all variations + def takeMainlineWhile(f: A => Boolean): Option[Node[A]] = + @tailrec + def loop(node: Node[A], acc: List[Node[A]]): List[Node[A]] = + if !f(node.value) then acc + else + node.child match + case None => node :: acc + case Some(child) => loop(child, node.withoutChild :: acc) + Tree.buildWithNodeReverse(loop(this, Nil)) // get the nth node of in the mainline def apply(n: Int): Option[Node[A]] = @@ -475,7 +490,10 @@ object Tree: build(s.zipWithIndex, f.tupled) def buildWithNode[A](s: Seq[Node[A]]): Option[Node[A]] = - s.reverse.foldLeft(none)((acc, a) => a.withChild(acc).some) + buildWithNodeReverse(s.reverse) + + def buildWithNodeReverse[A](s: Seq[Node[A]]): Option[Node[A]] = + s.foldLeft(none)((acc, a) => a.withChild(acc).some) def buildWithNode[A, B](s: Seq[A], f: A => Node[B]): Option[Node[B]] = s.reverse match diff --git a/test-kit/src/test/scala/NodeTest.scala b/test-kit/src/test/scala/NodeTest.scala index 9abecd031..bf4443086 100644 --- a/test-kit/src/test/scala/NodeTest.scala +++ b/test-kit/src/test/scala/NodeTest.scala @@ -73,6 +73,14 @@ class NodeTest extends ScalaCheckSuite: forAll: (node: Node[Int]) => node.take(node.mainline.size) == node + test("takeMainlineWhile"): + forAll: (node: Node[Int]) => + node.takeMainlineWhile(_ % 2 == 0).fold(true)(_.mainlineValues.forall(_ % 2 == 0)) + + test("takeMainlineWhile == identity when all mainline values satisfy the predicate"): + forAll: (node: Node[Int]) => + node.takeMainlineWhile(_ => true) == node.some + test("apply(n) return None if n >= node.mainline.size"): forAll: (node: Node[Int]) => node(node.mainline.size).isEmpty