Transpile SnowFlake select statement #1193 (Closed)
@@ -19,6 +19,10 @@
 // =================================================================================
 lexer grammar commonlex;

+channels {
+    COMMENT_CHANNEL
+}
+
 // TODO: Remove the use of DUMMY for unfinished Snowflake grammar productions
 DUMMY:
     'DUMMY'
@@ -1427,7 +1431,7 @@ WS: SPACE+ -> skip;

 // Comments
 SQL_COMMENT : '/*' (SQL_COMMENT | .)*? '*/' -> channel(HIDDEN);
-LINE_COMMENT : ('--' | '//') ~[\r\n]* -> channel(HIDDEN);
+LINE_COMMENT : ('--' | '//') ~[\r\n]* -> channel(COMMENT_CHANNEL);

 // Identifiers
 ID : ( [A-Z_] | FullWidthLetter) ( [A-Z_#$@0-9] | FullWidthLetter)*;
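Routing LINE_COMMENT to a dedicated named channel (instead of HIDDEN) keeps line comments separately addressable in the buffered token stream, so later stages can attach them to IR nodes. A minimal sketch of how a caller could collect them, assuming `SnowflakeLexer` stands in for the lexer ANTLR generates from this grammar (ANTLR emits an `int` constant per named channel):

```scala
import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}
import scala.jdk.CollectionConverters._

// Hypothetical generated-lexer name; any lexer produced from commonlex works.
val lexer  = new SnowflakeLexer(CharStreams.fromString("SELECT 1 -- keep me\n"))
val tokens = new CommonTokenStream(lexer)
tokens.fill() // buffer the whole input, including off-channel tokens

// Parser rules see only the default channel; COMMENT_CHANNEL tokens remain
// in the buffer, available for retrieval afterwards.
val comments = tokens.getTokens.asScala
  .filter(_.getChannel == SnowflakeLexer.COMMENT_CHANNEL)
  .map(_.getText)
```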
@@ -9,7 +9,7 @@
 import com.databricks.labs.remorph.{KoResult, OkResult, Parsing, Optimizing}
 import com.databricks.labs.remorph.transpilers.SqlGenerator
 import com.typesafe.scalalogging.LazyLogging

-class Estimator(queryHistory: QueryHistoryProvider, planParser: PlanParser[_], analyzer: EstimationAnalyzer)
+class Estimator(queryHistory: QueryHistoryProvider, planParser: PlanParser, analyzer: EstimationAnalyzer)
     extends LazyLogging {

   def run(): EstimationReport = {
@@ -43,8 +43,7 @@ class Estimator(queryHistory: QueryHistoryProvider, planParser: PlanParser[_], a
       if (!parsedSet.contains(fingerprint)) {
         parsedSet += fingerprint
         planParser
-          .parse(Parsing(query.source, query.user.getOrElse("unknown") + "_" + query.id))
-          .flatMap(planParser.visit)
+          .parseLogicalPlan(Parsing(query.source, query.user.getOrElse("unknown") + "_" + query.id))
           .run(initialState) match {
           case KoResult(PARSE, error) =>
             Some(
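The same parse-then-visit collapse recurs in Anonymizer, QueryHistoryToQueryNodes, and TableGraph below: the former `parse(...)` followed by `flatMap(visit)` is folded into a single `parseLogicalPlan` on the now non-generic `PlanParser`. A minimal sketch of a caller, under the assumption that `parseLogicalPlan` subsumes the old two-step pipeline:

```scala
import com.databricks.labs.remorph.Parsing
import com.databricks.labs.remorph.parsers.PlanParser

// Sketch only: running the transformation against the initial Parsing state
// yields the same OkResult / KoResult outcomes matched on above.
def logicalPlanOf(parser: PlanParser, sql: String) = {
  val state = Parsing(sql)
  parser.parseLogicalPlan(state).run(state)
}
```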
@@ -45,7 +45,7 @@ case class Fingerprints(fingerprints: Seq[Fingerprint]) {
   def uniqueQueries: Int = fingerprints.map(_.fingerprint).distinct.size
 }

-class Anonymizer(parser: PlanParser[_]) extends LazyLogging {
+class Anonymizer(parser: PlanParser) extends LazyLogging {
   private val placeholder = Literal("?", UnresolvedType)

   def apply(history: QueryHistory): Fingerprints = Fingerprints(history.queries.map(fingerprint))
@@ -55,7 +55,7 @@ class Anonymizer(parser: PlanParser[_]) extends LazyLogging {
   def apply(query: String): String = fingerprint(query)

   private[discovery] def fingerprint(query: ExecutedQuery): Fingerprint = {
-    parser.parse(Parsing(query.source)).flatMap(parser.visit).run(Parsing(query.source)) match {
+    parser.parseLogicalPlan(Parsing(query.source)).run(Parsing(query.source)) match {
       case KoResult(WorkflowStage.PARSE, error) =>
         logger.warn(s"Failed to parse query: ${query.source} ${error.msg}")
         Fingerprint(
@@ -10,7 +10,7 @@
 import com.databricks.labs.remorph.parsers.PlanParser
 import com.databricks.labs.remorph.transpilers.{PySparkGenerator, SqlGenerator}

 class FileSetGenerator(
-    private val parser: PlanParser[_],
+    private val parser: PlanParser,
     private val sqlGen: SqlGenerator,
     private val pyGen: PySparkGenerator)
     extends Generator[JobNode, FileSet] {
@@ -7,16 +7,15 @@
 import com.databricks.labs.remorph.intermediate.Rule
 import com.databricks.labs.remorph.intermediate.workflows.JobNode
 import com.databricks.labs.remorph.parsers.PlanParser

-class QueryHistoryToQueryNodes(val parser: PlanParser[_]) extends Rule[JobNode] {
+class QueryHistoryToQueryNodes(val parser: PlanParser) extends Rule[JobNode] {
   override def apply(plan: JobNode): JobNode = plan match {
     case RawMigration(QueryHistory(queries)) => Migration(queries.par.map(executedQuery).seq)
   }

   private def executedQuery(query: ExecutedQuery): JobNode = {
     val state = Parsing(query.source, query.id)
     parser
-      .parse(state)
-      .flatMap(parser.visit)
+      .parseLogicalPlan(state)
       .flatMap(parser.optimize)
       .run(state) match {
       case OkResult((_, plan)) => QueryPlan(plan, query)
@@ -1,10 +1,10 @@
 package com.databricks.labs.remorph.generators.py

-import com.databricks.labs.remorph.intermediate.{Binary, DataType, Expression, Name, Plan, StringType, UnresolvedType, Attribute => IRAttribute}
+import com.databricks.labs.remorph.intermediate.{Binary, DataType, Expression, Name, Origin, Plan, StringType, UnresolvedType, Attribute => IRAttribute}

 // this is a subset of https://docs.python.org/3/library/ast.html

-abstract class Statement extends Plan[Statement] {
+abstract class Statement extends Plan[Statement]()(Origin.empty) {
   override def output: Seq[IRAttribute] = Nil
 }

@@ -293,7 +293,12 @@ class LogicalPlanGenerator(
       .sequence
       .map(_.mkString.stripSuffix(", "))

-    code"SELECT $sqlParts$fromClause"
+    val lineComments = if (proj.comments.length > 0) {
+      proj.comments.map(node => node.text).mkString("\n") + "\n"
+    } else {
+      ""
+    }
+    code"${lineComments}SELECT $sqlParts$fromClause"
   }

   private def orderBy(sort: ir.Sort): SQL = {
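The effect of the `lineComments` prefix above, sketched with a stand-in for the IR comment node (only a `.text` field is assumed; the real node type is not shown in this diff):

```scala
// Stand-in comment node for illustration.
final case class CommentNode(text: String)

def prependComments(comments: Seq[CommentNode], select: String): String = {
  // Mirrors the generator logic: join comment texts with newlines and emit
  // them ahead of the SELECT, or emit nothing when there are none.
  val prefix =
    if (comments.nonEmpty) comments.map(_.text).mkString("\n") + "\n" else ""
  prefix + select
}

// prependComments(Seq(CommentNode("-- revenue by region")),
//   "SELECT region, SUM(amount) FROM sales GROUP BY region")
// => "-- revenue by region\nSELECT region, SUM(amount) FROM sales GROUP BY region"
```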
@@ -9,7 +9,7 @@ protected case class Node(tableDefinition: TableDefinition, metadata: Map[String
 // `from` is the table which is sourced to create `to` table
 protected case class Edge(from: Node, to: Option[Node], metadata: Map[String, String])

-class TableGraph(parser: PlanParser[_]) extends DependencyGraph with LazyLogging {
+class TableGraph(parser: PlanParser) extends DependencyGraph with LazyLogging {
   private val nodes = scala.collection.mutable.Set.empty[Node]
   private val edges = scala.collection.mutable.Set.empty[Edge]
@@ -107,7 +107,7 @@ class TableGraph(parser: PlanParser[_]) extends DependencyGraph with LazyLogging
   def buildDependency(queryHistory: QueryHistory, tableDefinition: Set[TableDefinition]): Unit = {
     queryHistory.queries.foreach { query =>
       try {
-        val plan = parser.parse(Parsing(query.source)).flatMap(parser.visit).run(Parsing(query.source))
+        val plan = parser.parseLogicalPlan(Parsing(query.source)).run(Parsing(query.source))
         plan match {
           case KoResult(_, error) =>
             logger.warn(s"Failed to produce plan from query: ${query.id}")
@@ -5,11 +5,11 @@ trait Command extends LogicalPlan {
 }

 case class SqlCommand(sql: String, named_arguments: Map[String, Expression], pos_arguments: Seq[Expression])
-    extends LeafNode
+    extends LeafNode()()
     with Command

 case class CreateDataFrameViewCommand(child: Relation, name: String, is_global: Boolean, replace: Boolean)
-    extends LeafNode
+    extends LeafNode()()
     with Command

 abstract class TableSaveMethod
@@ -24,7 +24,7 @@ case object OverwriteSaveMode extends SaveMode
 case object ErrorIfExistsSaveMode extends SaveMode
 case object IgnoreSaveMode extends SaveMode

-case class SaveTable(table_name: String, save_method: TableSaveMethod) extends LeafNode with Command
+case class SaveTable(table_name: String, save_method: TableSaveMethod) extends LeafNode()() with Command

 case class BucketBy(bucket_column_names: Seq[String], num_buckets: Int)

@@ -39,7 +39,7 @@ case class WriteOperation(
     bucket_by: Option[BucketBy],
     options: Map[String, String],
     clustering_columns: Seq[String])
-    extends LeafNode
+    extends LeafNode()()
     with Command

 abstract class Mode
@@ -61,7 +61,7 @@ case class WriteOperationV2(
     mode: Mode,
     overwrite_condition: Option[Expression],
     clustering_columns: Seq[String])
-    extends LeafNode
+    extends LeafNode()()
     with Command

 case class Trigger(
@@ -85,11 +85,11 @@ case class WriteStreamOperationStart(
     sink_destination: SinkDestination,
     foreach_writer: Option[StreamingForeachFunction],
     foreach_batch: Option[StreamingForeachFunction])
-    extends LeafNode
+    extends LeafNode()()
     with Command

 // TODO: align snowflake and common IR implementations for `CreateVariable`
 case class CreateVariable(name: Id, dataType: DataType, defaultExpr: Option[Expression] = None, replace: Boolean)
-    extends LeafNode
+    extends LeafNode()()
     with Command

@@ -112,19 +112,19 @@ case class ColumnDeclaration(
     virtualColumnDeclaration: Option[Expression] = Option.empty,
     constraints: Seq[Constraint] = Seq.empty)

-case class CreateTableCommand(name: String, columns: Seq[ColumnDeclaration]) extends Catalog {}
+case class CreateTableCommand(name: String, columns: Seq[ColumnDeclaration]) extends Catalog()() {}

 // TODO Need to introduce TableSpecBase, TableSpec and UnresolvedTableSpec

-case class ReplaceTableCommand(name: String, columns: Seq[ColumnDeclaration], orCreate: Boolean) extends Catalog
+case class ReplaceTableCommand(name: String, columns: Seq[ColumnDeclaration], orCreate: Boolean) extends Catalog()()

 case class ReplaceTableAsSelect(
     table_name: String,
     query: LogicalPlan,
     writeOptions: Map[String, String],
     orCreate: Boolean,
     isAnalyzed: Boolean = false)
-    extends Catalog
+    extends Catalog()()

 sealed trait TableAlteration
 case class AddColumn(columnDeclaration: Seq[ColumnDeclaration]) extends TableAlteration
@@ -149,31 +149,31 @@ case class DropColumns(columnNames: Seq[String]) extends TableAlteration
 case class RenameConstraint(oldName: String, newName: String) extends TableAlteration
 case class RenameColumn(oldName: String, newName: String) extends TableAlteration

-case class AlterTableCommand(tableName: String, alterations: Seq[TableAlteration]) extends Catalog {}
+case class AlterTableCommand(tableName: String, alterations: Seq[TableAlteration]) extends Catalog()() {}

 // Catalog API (experimental / unstable)
-abstract class Catalog extends LeafNode {
+abstract class Catalog()(origin: Origin = Origin.empty) extends LeafNode()(origin) {
   override def output: Seq[Attribute] = Seq.empty
 }

-case class SetCurrentDatabase(db_name: String) extends Catalog {}
-case class ListDatabases(pattern: Option[String]) extends Catalog {}
-case class ListTables(db_name: Option[String], pattern: Option[String]) extends Catalog {}
-case class ListFunctions(db_name: Option[String], pattern: Option[String]) extends Catalog {}
-case class ListColumns(table_name: String, db_name: Option[String]) extends Catalog {}
-case class GetDatabase(db_name: String) extends Catalog {}
-case class GetTable(table_name: String, db_name: Option[String]) extends Catalog {}
-case class GetFunction(function_name: String, db_name: Option[String]) extends Catalog {}
-case class DatabaseExists(db_name: String) extends Catalog {}
-case class TableExists(table_name: String, db_name: Option[String]) extends Catalog {}
-case class FunctionExists(function_name: String, db_name: Option[String]) extends Catalog {}
+case class SetCurrentDatabase(db_name: String) extends Catalog()() {}
+case class ListDatabases(pattern: Option[String]) extends Catalog()() {}
+case class ListTables(db_name: Option[String], pattern: Option[String]) extends Catalog()() {}
+case class ListFunctions(db_name: Option[String], pattern: Option[String]) extends Catalog()() {}
+case class ListColumns(table_name: String, db_name: Option[String]) extends Catalog()() {}
+case class GetDatabase(db_name: String) extends Catalog()() {}
+case class GetTable(table_name: String, db_name: Option[String]) extends Catalog()() {}
+case class GetFunction(function_name: String, db_name: Option[String]) extends Catalog()() {}
+case class DatabaseExists(db_name: String) extends Catalog()() {}
+case class TableExists(table_name: String, db_name: Option[String]) extends Catalog()() {}
+case class FunctionExists(function_name: String, db_name: Option[String]) extends Catalog()() {}
 case class CreateExternalTable(
     table_name: String,
     path: Option[String],
     source: Option[String],
     description: Option[String],
     override val schema: DataType)
-    extends Catalog {}
+    extends Catalog()() {}

 // As per Spark v2Commands
 case class CreateTable(
@@ -182,7 +182,7 @@ case class CreateTable(
     source: Option[String],
     description: Option[String],
     override val schema: DataType)
-    extends Catalog {}
+    extends Catalog()() {}

 // As per Spark v2Commands
 case class CreateTableAsSelect(
@@ -191,19 +191,19 @@ case class CreateTableAsSelect(
     path: Option[String],
     source: Option[String],
     description: Option[String])
-    extends Catalog {}
+    extends Catalog()() {}

-case class DropTempView(view_name: String) extends Catalog {}
-case class DropGlobalTempView(view_name: String) extends Catalog {}
-case class RecoverPartitions(table_name: String) extends Catalog {}
-case class IsCached(table_name: String) extends Catalog {}
-case class CacheTable(table_name: String, storage_level: StorageLevel) extends Catalog {}
-case class UncachedTable(table_name: String) extends Catalog {}
-case class ClearCache() extends Catalog {}
-case class RefreshTable(table_name: String) extends Catalog {}
-case class RefreshByPath(path: String) extends Catalog {}
-case class SetCurrentCatalog(catalog_name: String) extends Catalog {}
-case class ListCatalogs(pattern: Option[String]) extends Catalog {}
+case class DropTempView(view_name: String) extends Catalog()() {}
+case class DropGlobalTempView(view_name: String) extends Catalog()() {}
+case class RecoverPartitions(table_name: String) extends Catalog()() {}
+case class IsCached(table_name: String) extends Catalog()() {}
+case class CacheTable(table_name: String, storage_level: StorageLevel) extends Catalog()() {}
+case class UncachedTable(table_name: String) extends Catalog()() {}
+case class ClearCache() extends Catalog()() {}
+case class RefreshTable(table_name: String) extends Catalog()() {}
+case class RefreshByPath(path: String) extends Catalog()() {}
+case class SetCurrentCatalog(catalog_name: String) extends Catalog()() {}
+case class ListCatalogs(pattern: Option[String]) extends Catalog()() {}

 case class TableIdentifier(table: String, database: Option[String])
 case class CatalogTable(
@@ -1,7 +1,7 @@
 package com.databricks.labs.remorph.intermediate

 // Used for DML other than SELECT
-abstract class Modification extends LogicalPlan
+abstract class Modification extends LogicalPlan()(Origin.empty)

 case class InsertIntoTable(
     target: LogicalPlan,
@@ -4,7 +4,7 @@ import java.util.{UUID}

 // Expression used to refer to fields, functions and similar. This can be used everywhere
 // expressions in SQL appear.
-abstract class Expression extends TreeNode[Expression] {
+abstract class Expression(origin: Origin = Origin.empty) extends TreeNode[Expression]()(origin) {
   lazy val resolved: Boolean = childrenResolved

   def dataType: DataType
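All of the `LeafNode()()`, `Catalog()()`, and `TreeNode[Expression]()(origin)` call sites in this diff follow one pattern: the tree types gained a second, curried parameter list carrying an `Origin` (source position) that defaults to `Origin.empty`. A simplified sketch of that shape with stand-in classes (the real types live in com.databricks.labs.remorph.intermediate and carry more than shown here):

```scala
// Assumed, simplified shape of Origin; the real class has richer fields.
final case class Origin(startLine: Option[Int] = None)
object Origin { val empty: Origin = Origin() }

abstract class TreeNode[T]()(val origin: Origin = Origin.empty)
abstract class LeafNode()(origin: Origin = Origin.empty)
    extends TreeNode[LeafNode]()(origin)

// Subclasses keep their own first parameter list and thread Origin through
// the second one, which is why case classes now read `extends LeafNode()()`.
case class ExampleCommand(name: String) extends LeafNode()()

// ExampleCommand("x").origin == Origin.empty; a node built with
// `extends LeafNode()(someOrigin)` would carry a real source position.
```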