diff --git a/integration-tests/src/main/scala/za/co/absa/spline/test/LineageWalker.scala b/integration-tests/src/main/scala/za/co/absa/spline/test/LineageWalker.scala index 9f599feb..973d1bc0 100644 --- a/integration-tests/src/main/scala/za/co/absa/spline/test/LineageWalker.scala +++ b/integration-tests/src/main/scala/za/co/absa/spline/test/LineageWalker.scala @@ -18,6 +18,14 @@ package za.co.absa.spline.test import za.co.absa.spline.producer.model._ +/** + * A class to walk through the execution plan. + * + * @param allOpMap A map of all operations in the execution plan. + * @param funMap A map of all functional expressions in the execution plan. + * @param litMap A map of all literals in the execution plan. + * @param attrMap A map of all attributes in the execution plan. + */ class LineageWalker( allOpMap: Map[String, Operation], funMap: Map[String, FunctionalExpression], @@ -25,28 +33,53 @@ class LineageWalker( attrMap: Map[String, Attribute] ) { - def attributeById(attributeId: String): Attribute = attrMap(attributeId) - - def operationById(operationId: String): Operation = allOpMap(operationId) - - def dependsOn(att: Attribute, onAtt: Attribute): Boolean = { - dependsOnRec(AttrRef(att.id), onAtt.id) + /** + * Retrieves an attribute by its ID. + * + * @param attributeId The ID of the attribute. + * @return The attribute with the given ID. + */ + def getAttributeById(attributeId: String): Attribute = attrMap(attributeId) + + /** + * Retrieves an operation by its ID. + * + * @param operationId The ID of the operation. + * @return The operation with the given ID. + */ + def getOperationById(operationId: String): Operation = allOpMap(operationId) + + /** + * Checks if an attribute depends on another attribute. + * + * @param sourceAttribute The attribute to check. + * @param targetAttribute The attribute that the first attribute may depend on. + * @return True if the first attribute depends on the second attribute, false otherwise. + */ + def dependsOn(sourceAttribute: Attribute, targetAttribute: Attribute): Boolean = { + dependsOnRecursively(AttrRef(sourceAttribute.id), targetAttribute.id) } - private def dependsOnRec(refs: Seq[AttrOrExprRef], id: String): Boolean = - refs.exists(dependsOnRec(_, id)) + private def dependsOnRecursively(refs: Seq[AttrOrExprRef], attrId: String): Boolean = + refs.exists(dependsOnRecursively(_, attrId)) - private def dependsOnRec(ref: AttrOrExprRef, id: String): Boolean = ref match { - case AttrRef(attrIfd) => - attrIfd == id || dependsOnRec(attrMap(attrIfd).childRefs, id) + private def dependsOnRecursively(ref: AttrOrExprRef, targetAttrId: String): Boolean = ref match { + case AttrRef(attrId) => + attrId == targetAttrId || dependsOnRecursively(attrMap(attrId).childRefs, targetAttrId) case ExprRef(exprId) => - exprId == id || !litMap.contains("exprId") && dependsOnRec(funMap(exprId).childRefs, id) + funMap.get(exprId).exists(expr => dependsOnRecursively(expr.childRefs, targetAttrId)) } } object LineageWalker { + /** + * Creates a LineageWalker instance from an execution plan. + * + * @param plan The execution plan. + * @return A LineageWalker instance. + */ def apply(plan: ExecutionPlan): LineageWalker = { val allOpMap = plan.operations.all .map(op => op.id -> op) diff --git a/integration-tests/src/main/scala/za/co/absa/spline/test/ProducerModelImplicits.scala b/integration-tests/src/main/scala/za/co/absa/spline/test/ProducerModelImplicits.scala index f061104e..9411771e 100644 --- a/integration-tests/src/main/scala/za/co/absa/spline/test/ProducerModelImplicits.scala +++ b/integration-tests/src/main/scala/za/co/absa/spline/test/ProducerModelImplicits.scala @@ -26,11 +26,11 @@ object ProducerModelImplicits { implicit class OperationOps(val operation: Operation) extends AnyVal { def outputAttributes(implicit walker: LineageWalker): IOAttributes = { - operation.output.map(walker.attributeById) + operation.output.map(walker.getAttributeById) } def childOperations(implicit walker: LineageWalker): Seq[Operation] = { - operation.childIds.map(walker.operationById) + operation.childIds.map(walker.getOperationById) } def childOperation(implicit walker: LineageWalker): Operation = {