Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor origin in tree node #1189

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ import java.util.{UUID}

// Expression used to refer to fields, functions and similar. This can be used everywhere
// expressions in SQL appear.
abstract class Expression extends TreeNode[Expression] {
abstract class Expression(_origin: Option[Origin] = Option.empty) extends TreeNode[Expression](_origin) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


lazy val resolved: Boolean = childrenResolved

def dataType: DataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ package com.databricks.labs.remorph.intermediate
* the [[Command]] type that is used to execute commands on the server. [[Plan]] is a union of Spark's LogicalPlan and
* QueryPlan.
*/
abstract class Plan[PlanType <: Plan[PlanType]] extends TreeNode[PlanType] {
abstract class Plan[PlanType <: Plan[PlanType]](_origin: Option[Origin] = Option.empty)
extends TreeNode[PlanType](_origin) {
self: PlanType =>

def output: Seq[Attribute]
Expand Down Expand Up @@ -80,9 +81,7 @@ abstract class Plan[PlanType <: Plan[PlanType]] extends TreeNode[PlanType] {
var changed = false

@inline def transformExpression(e: Expression): Expression = {
val newE = CurrentOrigin.withOrigin(e.origin) {
f(e)
}
val newE = f(e)
if (newE.fastEquals(e)) {
e
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package com.databricks.labs.remorph.intermediate

import com.databricks.labs.remorph.utils.Strings.truncatedString
import com.fasterxml.jackson.annotation.JsonIgnore
import org.antlr.v4.runtime.ParserRuleContext
import org.antlr.v4.runtime.tree.ParseTree

import scala.reflect.ClassTag
import scala.util.control.NonFatal
Expand All @@ -10,36 +12,31 @@ import scala.util.control.NonFatal
private class MutableInt(var i: Int)

case class Origin(
line: Option[Int] = None,
startPosition: Option[Int] = None,
startIndex: Option[Int] = None,
stopIndex: Option[Int] = None,
sqlText: Option[String] = None,
objectType: Option[String] = None,
objectName: Option[String] = None)

object CurrentOrigin {
private val value = new ThreadLocal[Origin]() {
override def initialValue: Origin = Origin()
}

def get: Origin = value.get()

def setPosition(line: Int, start: Int): Unit = {
value.set(value.get.copy(line = Some(line), startPosition = Some(start)))
startLine: Int,
startColumn: Int,
endLine: Int,
endColumn: Int,
startTokenIndex: Int,
endTokenIndex: Int)

object Origin {

def fromParseTree(tree: ParseTree): Option[Origin] = {
tree match {
case parserRuleContext: ParserRuleContext => Some(Origin.fromParserRuleContext(parserRuleContext))
case other => Option.empty
}
}

def withOrigin[A](o: Origin)(f: => A): A = {
set(o)
val ret =
try f
finally { reset() }
ret
def fromParserRuleContext(ctx: ParserRuleContext): Origin = {
Origin(
startLine = ctx.start.getLine,
startColumn = ctx.start.getCharPositionInLine,
endLine = ctx.stop.getLine,
endColumn = ctx.stop.getCharPositionInLine + ctx.stop.getStopIndex - ctx.stop.getStartIndex,
startTokenIndex = ctx.start.getTokenIndex,
endTokenIndex = ctx.stop.getTokenIndex)
}

def set(o: Origin): Unit = value.set(o)

def reset(): Unit = value.set(Origin())
}

class TreeNodeException[TreeType <: TreeNode[_]](@transient val tree: TreeType, msg: String, cause: Throwable)
Expand All @@ -57,14 +54,15 @@ class TreeNodeException[TreeType <: TreeNode[_]](@transient val tree: TreeType,
}

// scalastyle:off
abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
abstract class TreeNode[BaseType <: TreeNode[BaseType]](_origin: Option[Origin] = Option.empty) extends Product {
// scalastyle:on
self: BaseType =>

@JsonIgnore lazy val containsChild: Set[TreeNode[_]] = children.toSet
private lazy val _hashCode: Int = productHash(this, scala.util.hashing.MurmurHash3.productSeed)
private lazy val allChildren: Set[TreeNode[_]] = (children ++ innerChildren).toSet[TreeNode[_]]
@JsonIgnore val origin: Origin = CurrentOrigin.get

def origin: Option[Origin] = _origin

/**
* Returns a Seq of the children of this node. Children should not change. Immutability required for containsChild
Expand Down Expand Up @@ -273,10 +271,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
.getOrElse(ctors.maxBy(_.getParameterTypes.length)) // fall back to older heuristic

try {
CurrentOrigin.withOrigin(origin) {
val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
res
}
val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
res
} catch {
case e: java.lang.IllegalArgumentException =>
throw new TreeNodeException(
Expand Down Expand Up @@ -368,9 +364,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* the function used to transform this nodes children
*/
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRule = CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
}
val afterRule = rule.applyOrElse(this, identity[BaseType])

// Check if unchanged and then possibly return old copy to avoid gc churn.
if (this fastEquals afterRule) {
Expand All @@ -391,13 +385,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRuleOnChildren = mapChildren(_.transformUp(rule))
val newNode = if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
}
rule.applyOrElse(this, identity[BaseType])
} else {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
}
rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
}
// If the transform function replaces this node with a new one, carry over the tags.
newNode
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.databricks.labs.remorph.intermediate.workflows

import com.databricks.labs.remorph.intermediate.TreeNode
import com.databricks.labs.remorph.intermediate.{Origin, TreeNode}

abstract class JobNode extends TreeNode[JobNode]
abstract class JobNode(_origin: Option[Origin] = Option.empty) extends TreeNode[JobNode](_origin)

abstract class LeafJobNode extends JobNode {
override def children: Seq[JobNode] = Seq()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class SnowflakeAstBuilder(override val vc: SnowflakeVisitorCoordinator)
protected override def unresolved(ruleText: String, message: String): ir.LogicalPlan =
ir.UnresolvedRelation(ruleText = ruleText, message = message)

// Concrete visitors

override def visitSnowflakeFile(ctx: SnowflakeFileContext): ir.LogicalPlan = {
// This very top level visitor does not ignore any valid statements for the batch, instead
// we prepend any errors to the batch plan, so they are generated first in the output.
Expand Down
Loading