From d99ab70a1c8b25f58f3c770c2004771a71fde2f1 Mon Sep 17 00:00:00 2001 From: Furqaan Khan <46216254+furqaankhan@users.noreply.github.com> Date: Fri, 3 Nov 2023 06:52:36 +0530 Subject: [PATCH] [SEDONA-415] Add optional parameter to ST_Transform (#1069) --- docs/api/sql/Function.md | 12 ++++++++++-- .../flink/expressions/FunctionsGeoTools.java | 7 +++++++ .../org/apache/sedona/flink/FunctionTest.java | 6 ++++++ python/sedona/sql/st_functions.py | 15 +++++++++++++-- python/tests/sql/test_dataframe_api.py | 1 - python/tests/sql/test_function.py | 11 +++++++++-- .../sql/sedona_sql/expressions/Functions.scala | 3 ++- .../sql/sedona_sql/expressions/st_functions.scala | 2 ++ .../apache/sedona/sql/dataFrameAPITestScala.scala | 14 +++++++++++--- .../org/apache/sedona/sql/functionTestScala.scala | 9 +++++++-- 10 files changed, 67 insertions(+), 13 deletions(-) diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md index ec1117a371..eede8e0ba4 100644 --- a/docs/api/sql/Function.md +++ b/docs/api/sql/Function.md @@ -2507,7 +2507,7 @@ MULTIPOLYGON (((-2 -3, -3 -3, -3 3, -2 3, -2 -3)), ((3 -3, 3 3, 4 3, 4 -3, 3 -3) Introduction: -Transform the Spatial Reference System / Coordinate Reference System of A, from SourceCRS to TargetCRS. For SourceCRS and TargetCRS, WKT format is also available since v1.3.1. +Transform the Spatial Reference System / Coordinate Reference System of A, from SourceCRS to TargetCRS. For SourceCRS and TargetCRS, WKT format is also available since `v1.3.1`. Since `v1.5.1`, if the `SourceCRS` is not specified, CRS will be fetched from the geometry using [ST_SRID](#st_srid). **Lon/Lat Order in the input geometry** @@ -2560,7 +2560,15 @@ PROJCS["WGS 84 / Pseudo-Mercator", Format: ``` -ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String, [Optional] lenientMode: Boolean) +ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String, lenientMode: Boolean) +``` + +``` +ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String) +``` + +``` +ST_Transform (A: Geometry, TargetCRS: String) ``` Since: `v1.2.0` diff --git a/flink/src/main/java/org/apache/sedona/flink/expressions/FunctionsGeoTools.java b/flink/src/main/java/org/apache/sedona/flink/expressions/FunctionsGeoTools.java index f51a9faccc..65054cf11b 100644 --- a/flink/src/main/java/org/apache/sedona/flink/expressions/FunctionsGeoTools.java +++ b/flink/src/main/java/org/apache/sedona/flink/expressions/FunctionsGeoTools.java @@ -21,6 +21,13 @@ public class FunctionsGeoTools { public static class ST_Transform extends ScalarFunction { + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) + public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = Geometry.class) Object o, @DataTypeHint("String") String targetCRS) + throws FactoryException, TransformException { + Geometry geom = (Geometry) o; + return org.apache.sedona.common.FunctionsGeoTools.transform(geom, targetCRS); + } + @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = Geometry.class) Object o, @DataTypeHint("String") String sourceCRS, @DataTypeHint("String") String targetCRS) throws FactoryException, TransformException { diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java index 59485e51b5..78ac0f83bf 100644 --- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java +++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java 
@@ -232,6 +232,12 @@ public void testTransform() { , "epsg:4326", "epsg:3857")); String result = first(transformedTable).getField(0).toString(); assertEquals("POINT (-13134586.718698347 3764623.3541299687)", result); + + pointTable = pointTable.select(call(Functions.ST_SetSRID.class.getSimpleName(), $(pointColNames[0]), 4326)).as(pointColNames[0]); + transformedTable = pointTable.select(call(FunctionsGeoTools.ST_Transform.class.getSimpleName(), $(pointColNames[0]), "epsg:3857")) + .as(pointColNames[0]).select(call(Functions.ST_ReducePrecision.class.getSimpleName(), $(pointColNames[0]), 2)); + result = first(transformedTable).getField(0).toString(); + assertEquals("POINT (-13134586.72 3764623.35)", result); } @Test diff --git a/python/sedona/sql/st_functions.py b/python/sedona/sql/st_functions.py index 64fe6c286e..da9d845ce6 100644 --- a/python/sedona/sql/st_functions.py +++ b/python/sedona/sql/st_functions.py @@ -1241,7 +1241,7 @@ def ST_SymDifference(a: ColumnOrName, b: ColumnOrName) -> Column: @validate_argument_types -def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: ColumnOrName, disable_error: Optional[Union[ColumnOrName, bool]] = None) -> Column: +def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: Optional[Union[ColumnOrName, str]] = None, disable_error: Optional[Union[ColumnOrName, bool]] = None) -> Column: """Convert a geometry from one coordinate system to another coordinate system. :param geometry: Geometry column to convert. @@ -1255,7 +1255,18 @@ def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: C :return: Geometry converted to the target coordinate system as an :rtype: Column """ - args = (geometry, source_crs, target_crs) if disable_error is None else (geometry, source_crs, target_crs, disable_error) + + if disable_error is None: + args = (geometry, source_crs, target_crs) + + # When only 2 arguments are passed to the function, # Python sees the call as ST_Transform(geometry, source_crs), # so we have to check whether target_crs is empty: 
+ # the source_crs acts as target_crs when calling the function + if target_crs is None: + args = (geometry, source_crs) + else: + args = (geometry, source_crs, target_crs, disable_error) return _call_st_function("ST_Transform", args) diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 7bdbbd1476..df6bb3567d 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -305,7 +305,6 @@ (stf.ST_SymDifference, ("", None)), (stf.ST_Transform, (None, "", "")), (stf.ST_Transform, ("", None, "")), - (stf.ST_Transform, ("", "", None)), (stf.ST_Union, (None, "")), (stf.ST_Union, ("", None)), (stf.ST_X, (None,)), diff --git a/python/tests/sql/test_function.py b/python/tests/sql/test_function.py index ea958eb4ec..67f21220a0 100644 --- a/python/tests/sql/test_function.py +++ b/python/tests/sql/test_function.py @@ -189,8 +189,15 @@ def test_st_transform(self): polygon_df.createOrReplaceTempView("polygondf") polygon_df.show() function_df = self.spark.sql( - "select ST_Transform(polygondf.countyshape, 'epsg:4326','epsg:3857', false) from polygondf") - function_df.show() + "select ST_ReducePrecision(ST_Transform(polygondf.countyshape, 'epsg:4326','epsg:3857', false), 2) from polygondf") + actual = function_df.take(1)[0][0].wkt + assert actual[:300] == "POLYGON ((-10800163.45 5161718.41, -10800164.34 5162103.12, -10800164.57 5162440.81, -10800164.57 5162443.95, -10800164.57 5162468.37, -10800164.57 5162501.93, -10800165.57 5163066.47, -10800166.9 5163158.61, -10800166.9 5163161.46, -10800167.01 5163167.9, -10800167.01 5163171.5, -10800170.24 516340" + + function_df = self.spark.sql( + "select ST_ReducePrecision(ST_Transform(ST_SetSRID(polygondf.countyshape, 4326), 'epsg:3857'), 2) from polygondf" + ) + actual = function_df.take(1)[0][0].wkt + assert actual[:300] == "POLYGON ((-10800163.45 5161718.41, -10800164.34 5162103.12, -10800164.57 5162440.81, -10800164.57 5162443.95, -10800164.57 5162468.37, -10800164.57 5162501.93, -10800165.57 5163066.47, -10800166.9 5163158.61, -10800166.9 5163161.46, -10800167.01 5163167.9, -10800167.01 5163171.5, -10800170.24 516340" def test_st_intersection_intersects_but_not_contains(self): test_table = self.spark.sql( diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index cd5366fda0..6db9cc3f43 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -219,7 +219,8 @@ case class ST_Centroid(inputExpressions: Seq[Expression]) * @param inputExpressions */ case class ST_Transform(inputExpressions: Seq[Expression]) - extends InferredExpression(inferrableFunction4(FunctionsGeoTools.transform)) { + extends InferredExpression(inferrableFunction4(FunctionsGeoTools.transform), inferrableFunction3(FunctionsGeoTools.transform), + inferrableFunction2(FunctionsGeoTools.transform)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala index 4da9579704..c5ee85d691 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala +++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala @@ -280,6 +280,8 @@ object st_functions extends DataFrameAPI { def ST_Transform(geometry: String, sourceCRS: String, targetCRS: String): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, false) def ST_Transform(geometry: Column, sourceCRS: Column, targetCRS: Column, disableError: Column): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, disableError) def ST_Transform(geometry: String, sourceCRS: String, targetCRS: String, disableError: Boolean): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, disableError) + def ST_Transform(geometry: String, targetCRS: String): Column = wrapExpression[ST_Transform](geometry, targetCRS) + def ST_Transform(geometry: Column, targetCRS: Column): Column = wrapExpression[ST_Transform](geometry, targetCRS) def ST_Union(a: Column, b: Column): Column = wrapExpression[ST_Union](a, b) def ST_Union(a: String, b: String): Column = wrapExpression[ST_Union](a, b) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala index f00da2ff44..7cc7f2ed6c 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala @@ -289,11 +289,19 @@ class dataFrameAPITestScala extends TestBaseScala { } it("Passed ST_Transform") { - val pointDf = sparkSession.sql("SELECT ST_Point(1.0, 1.0) AS geom") - val df = pointDf.select(ST_Transform($"geom", lit("EPSG:4326"), lit("EPSG:32649")).as("geom")).selectExpr("ST_ReducePrecision(geom, 2)") + var pointDf = sparkSession.sql("SELECT ST_Point(1.0, 1.0) AS geom") + var df = pointDf.select(ST_Transform($"geom", lit("EPSG:4326"), lit("EPSG:32649")).as("geom")).selectExpr("ST_ReducePrecision(geom, 2)") val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText() val expectedResult = "POINT (-33741810.95 1823994.03)" - assert(actualResult == expectedResult) + assertEquals(expectedResult, actualResult) + + pointDf = sparkSession.sql("SELECT ST_Point(40.0, 100.0) AS geom") + .select(ST_SetSRID("geom", 4326).as("geom")) + df = pointDf.select(ST_Transform($"geom", lit("EPSG:32649")).as("geom")) + .selectExpr("ST_ReducePrecision(geom, 2)") + val actual = df.take(1)(0).get(0).asInstanceOf[Geometry].toText() + val expected = "POINT (1560393.11 10364606.84)" + assertEquals(expected, actual) } it("Passed ST_Intersection") { diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala index f7563beb98..91c79b4347 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala @@ -189,11 +189,16 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample } it("Passed ST_Transform") { - val point = "POINT (120 60)" + var point = "POINT (120 60)" val transformedResult = sparkSession.sql(s"""select ST_Transform(ST_geomFromWKT('$point'),'EPSG:4326', 'EPSG:3857', false)""").rdd.map(row => row.getAs[Geometry](0)).collect()(0) - println(transformedResult) assertEquals(1.3358338895192828E7, transformedResult.getCoordinate.x, FP_TOLERANCE) assertEquals(8399737.889818355, transformedResult.getCoordinate.y, FP_TOLERANCE) + + point = "POINT (100 40)" + val result = 
sparkSession.sql( + s"""SELECT + |ST_AsText(ST_ReducePrecision(ST_Transform(ST_SetSRID(ST_GeomFromWKT('$point'), 4326), 'EPSG:3857'), 2))""".stripMargin).first().get(0) + assertEquals("POINT (11131949.08 4865942.28)", result) } it("Passed ST_transform WKT version"){
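
For quick reference, a minimal usage sketch (illustrative, not part of the patch) of the new single-CRS overload added above, using the Scala DataFrame API. It assumes a `SparkSession` named `spark` with Sedona registered and the `st_functions` object touched by this patch; the expected output is taken from the `functionTestScala` assertion above.

```scala
// Minimal sketch, not part of the patch: the two-argument ST_Transform reads
// the source CRS from the geometry's SRID (via ST_SRID), so only the target
// CRS is supplied. Assumes a SparkSession `spark` with Sedona registered.
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.sedona_sql.expressions.st_functions._
import spark.implicits._

// Tag the point with its source CRS (EPSG:4326) instead of passing it to ST_Transform.
val pointDf = spark.sql("SELECT ST_SetSRID(ST_GeomFromWKT('POINT (100 40)'), 4326) AS geom")

// Only the target CRS is given; the source CRS comes from the geometry's SRID.
val projected = pointDf.select(ST_Transform($"geom", lit("EPSG:3857")).as("geom"))

projected.selectExpr("ST_AsText(ST_ReducePrecision(geom, 2))").show(false)
// Per the functionTestScala assertion above: POINT (11131949.08 4865942.28)
```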