Skip to content

Commit

Permalink
[SEDONA-415] Add optional parameter to ST_Transform (apache#1069)
Browse files Browse the repository at this point in the history
  • Loading branch information
furqaankhan authored and jiayuasu committed Dec 8, 2023
1 parent 041f5c1 commit 90182b0
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 13 deletions.
12 changes: 10 additions & 2 deletions docs/api/sql/Function.md
Original file line number Diff line number Diff line change
Expand Up @@ -2507,7 +2507,7 @@ MULTIPOLYGON (((-2 -3, -3 -3, -3 3, -2 3, -2 -3)), ((3 -3, 3 3, 4 3, 4 -3, 3 -3)

Introduction:

Transform the Spatial Reference System / Coordinate Reference System of A, from SourceCRS to TargetCRS. For SourceCRS and TargetCRS, WKT format is also available since v1.3.1.
Transform the Spatial Reference System / Coordinate Reference System of A, from SourceCRS to TargetCRS. For SourceCRS and TargetCRS, WKT format is also available since `v1.3.1`. Since `v1.5.1`, if the `SourceCRS` is not specified, CRS will be fetched from the geometry using [ST_SRID](#st_srid).

**Lon/Lat Order in the input geometry**

Expand Down Expand Up @@ -2560,7 +2560,15 @@ PROJCS["WGS 84 / Pseudo-Mercator",
Format:

```
ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String, [Optional] lenientMode: Boolean)
ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String, lenientMode: Boolean)
```

```
ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String)
```

```
ST_Transform (A: Geometry, TargetCRS: String)
```

Since: `v1.2.0`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@

public class FunctionsGeoTools {
public static class ST_Transform extends ScalarFunction {
@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = Geometry.class) Object o, @DataTypeHint("String") String targetCRS)
throws FactoryException, TransformException {
Geometry geom = (Geometry) o;
return org.apache.sedona.common.FunctionsGeoTools.transform(geom, targetCRS);
}

@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = Geometry.class) Object o, @DataTypeHint("String") String sourceCRS, @DataTypeHint("String") String targetCRS)
throws FactoryException, TransformException {
Expand Down
6 changes: 6 additions & 0 deletions flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,12 @@ public void testTransform() {
, "epsg:4326", "epsg:3857"));
String result = first(transformedTable).getField(0).toString();
assertEquals("POINT (-13134586.718698347 3764623.3541299687)", result);

pointTable = pointTable.select(call(Functions.ST_SetSRID.class.getSimpleName(), $(pointColNames[0]), 4326)).as(pointColNames[0]);
transformedTable = pointTable.select(call(FunctionsGeoTools.ST_Transform.class.getSimpleName(), $(pointColNames[0]), "epsg:3857"))
.as(pointColNames[0]).select(call(Functions.ST_ReducePrecision.class.getSimpleName(), $(pointColNames[0]), 2));
result = first(transformedTable).getField(0).toString();
assertEquals("POINT (-13134586.72 3764623.35)", result);
}

@Test
Expand Down
15 changes: 13 additions & 2 deletions python/sedona/sql/st_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,7 @@ def ST_SymDifference(a: ColumnOrName, b: ColumnOrName) -> Column:


@validate_argument_types
def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: ColumnOrName, disable_error: Optional[Union[ColumnOrName, bool]] = None) -> Column:
def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: Optional[Union[ColumnOrName, str]] = None, disable_error: Optional[Union[ColumnOrName, bool]] = None) -> Column:
"""Convert a geometry from one coordinate system to another coordinate system.
:param geometry: Geometry column to convert.
Expand All @@ -1255,7 +1255,18 @@ def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: C
:return: Geometry converted to the target coordinate system as an
:rtype: Column
"""
args = (geometry, source_crs, target_crs) if disable_error is None else (geometry, source_crs, target_crs, disable_error)

if disable_error is None:
args = (geometry, source_crs, target_crs)

# When 2 arguments are passed to the function.
# From python's perspective ST_Transform(geometry, source_crs) is provided
# that's why have to check if the target_crs is empty.
# the source_crs acts as target_crs when calling the function
if target_crs is None:
args = (geometry, source_crs)
else:
args = (geometry, source_crs, target_crs, disable_error)
return _call_st_function("ST_Transform", args)


Expand Down
1 change: 0 additions & 1 deletion python/tests/sql/test_dataframe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@
(stf.ST_SymDifference, ("", None)),
(stf.ST_Transform, (None, "", "")),
(stf.ST_Transform, ("", None, "")),
(stf.ST_Transform, ("", "", None)),
(stf.ST_Union, (None, "")),
(stf.ST_Union, ("", None)),
(stf.ST_X, (None,)),
Expand Down
11 changes: 9 additions & 2 deletions python/tests/sql/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,15 @@ def test_st_transform(self):
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
function_df = self.spark.sql(
"select ST_Transform(polygondf.countyshape, 'epsg:4326','epsg:3857', false) from polygondf")
function_df.show()
"select ST_ReducePrecision(ST_Transform(polygondf.countyshape, 'epsg:4326','epsg:3857', false), 2) from polygondf")
actual = function_df.take(1)[0][0].wkt
assert actual[:300] == "POLYGON ((-10800163.45 5161718.41, -10800164.34 5162103.12, -10800164.57 5162440.81, -10800164.57 5162443.95, -10800164.57 5162468.37, -10800164.57 5162501.93, -10800165.57 5163066.47, -10800166.9 5163158.61, -10800166.9 5163161.46, -10800167.01 5163167.9, -10800167.01 5163171.5, -10800170.24 516340"

function_df = self.spark.sql(
"select ST_ReducePrecision(ST_Transform(ST_SetSRID(polygondf.countyshape, 4326), 'epsg:3857'), 2) from polygondf"
)
actual = function_df.take(1)[0][0].wkt
assert actual[:300] == "POLYGON ((-10800163.45 5161718.41, -10800164.34 5162103.12, -10800164.57 5162440.81, -10800164.57 5162443.95, -10800164.57 5162468.37, -10800164.57 5162501.93, -10800165.57 5163066.47, -10800166.9 5163158.61, -10800166.9 5163161.46, -10800167.01 5163167.9, -10800167.01 5163171.5, -10800170.24 516340"

def test_st_intersection_intersects_but_not_contains(self):
test_table = self.spark.sql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ case class ST_Centroid(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Transform(inputExpressions: Seq[Expression])
extends InferredExpression(inferrableFunction4(FunctionsGeoTools.transform)) {
extends InferredExpression(inferrableFunction4(FunctionsGeoTools.transform), inferrableFunction3(FunctionsGeoTools.transform),
inferrableFunction2(FunctionsGeoTools.transform)) {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ object st_functions extends DataFrameAPI {
def ST_Transform(geometry: String, sourceCRS: String, targetCRS: String): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, false)
def ST_Transform(geometry: Column, sourceCRS: Column, targetCRS: Column, disableError: Column): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, disableError)
def ST_Transform(geometry: String, sourceCRS: String, targetCRS: String, disableError: Boolean): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, disableError)
def ST_Transform(geometry: String, targetCRS: String): Column = wrapExpression[ST_Transform](geometry, targetCRS)
def ST_Transform(geometry: Column, targetCRS: Column): Column = wrapExpression[ST_Transform](geometry, targetCRS)

def ST_Union(a: Column, b: Column): Column = wrapExpression[ST_Union](a, b)
def ST_Union(a: String, b: String): Column = wrapExpression[ST_Union](a, b)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,19 @@ class dataFrameAPITestScala extends TestBaseScala {
}

it("Passed ST_Transform") {
val pointDf = sparkSession.sql("SELECT ST_Point(1.0, 1.0) AS geom")
val df = pointDf.select(ST_Transform($"geom", lit("EPSG:4326"), lit("EPSG:32649")).as("geom")).selectExpr("ST_ReducePrecision(geom, 2)")
var pointDf = sparkSession.sql("SELECT ST_Point(1.0, 1.0) AS geom")
var df = pointDf.select(ST_Transform($"geom", lit("EPSG:4326"), lit("EPSG:32649")).as("geom")).selectExpr("ST_ReducePrecision(geom, 2)")
val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
val expectedResult = "POINT (-33741810.95 1823994.03)"
assert(actualResult == expectedResult)
assertEquals(expectedResult, actualResult)

pointDf = sparkSession.sql("SELECT ST_Point(40.0, 100.0) AS geom")
.select(ST_SetSRID("geom", 4326).as("geom"))
df = pointDf.select(ST_Transform($"geom", lit("EPSG:32649")).as("geom"))
.selectExpr("ST_ReducePrecision(geom, 2)")
val actual = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
val expected = "POINT (1560393.11 10364606.84)"
assertEquals(expected, actual)
}

it("Passed ST_Intersection") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,16 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
}

it("Passed ST_Transform") {
val point = "POINT (120 60)"
var point = "POINT (120 60)"
val transformedResult = sparkSession.sql(s"""select ST_Transform(ST_geomFromWKT('$point'),'EPSG:4326', 'EPSG:3857', false)""").rdd.map(row => row.getAs[Geometry](0)).collect()(0)
println(transformedResult)
assertEquals(1.3358338895192828E7, transformedResult.getCoordinate.x, FP_TOLERANCE)
assertEquals(8399737.889818355, transformedResult.getCoordinate.y, FP_TOLERANCE)

point = "POINT (100 40)"
val result = sparkSession.sql(
s"""SELECT
|ST_AsText(ST_ReducePrecision(ST_Transform(ST_SetSRID(ST_GeomFromWKT('$point'), 4326), 'EPSG:3857'), 2))""".stripMargin).first().get(0)
assertEquals("POINT (11131949.08 4865942.28)", result)
}

it("Passed ST_transform WKT version"){
Expand Down

0 comments on commit 90182b0

Please sign in to comment.