Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-415] Add optional parameter to ST_Transform #1069

Merged
merged 10 commits into from
Nov 3, 2023
12 changes: 10 additions & 2 deletions docs/api/sql/Function.md
Original file line number Diff line number Diff line change
Expand Up @@ -2477,7 +2477,7 @@ MULTIPOLYGON (((-2 -3, -3 -3, -3 3, -2 3, -2 -3)), ((3 -3, 3 3, 4 3, 4 -3, 3 -3)

Introduction:

Transform the Spatial Reference System / Coordinate Reference System of A, from SourceCRS to TargetCRS. For SourceCRS and TargetCRS, WKT format is also available since v1.3.1.
Transform the Spatial Reference System / Coordinate Reference System of A, from SourceCRS to TargetCRS. For SourceCRS and TargetCRS, WKT format is also available since `v1.3.1`. Since `v1.5.1`, if the `SourceCRS` is not specified, CRS will be fetched from the geometry using [ST_SRID](#st_srid).

**Lon/Lat Order in the input geometry**

Expand Down Expand Up @@ -2530,7 +2530,15 @@ PROJCS["WGS 84 / Pseudo-Mercator",
Format:

```
ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String, [Optional] lenientMode: Boolean)
ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String, lenientMode: Boolean)
```

```
ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String)
```

```
ST_Transform (A: Geometry, TargetCRS: String)
```

Since: `v1.2.0`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@

public class FunctionsGeoTools {
public static class ST_Transform extends ScalarFunction {
@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = Geometry.class) Object o, @DataTypeHint("String") String targetCRS)
throws FactoryException, TransformException {
Geometry geom = (Geometry) o;
return org.apache.sedona.common.FunctionsGeoTools.transform(geom, targetCRS);
}

@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = Geometry.class) Object o, @DataTypeHint("String") String sourceCRS, @DataTypeHint("String") String targetCRS)
throws FactoryException, TransformException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ public void testTransform() {
, "epsg:4326", "epsg:3857"));
String result = first(transformedTable).getField(0).toString();
assertEquals("POINT (-13134586.718698347 3764623.3541299687)", result);

pointTable = pointTable.select(call(Functions.ST_SetSRID.class.getSimpleName(), $(pointColNames[0]), 4326)).as(pointColNames[0]);
transformedTable = pointTable.select(call(FunctionsGeoTools.ST_Transform.class.getSimpleName(), $(pointColNames[0]), "epsg:3857"))
.as(pointColNames[0]).select(call(Functions.ST_ReducePrecision.class.getSimpleName(), $(pointColNames[0]), 2));
result = first(transformedTable).getField(0).toString();
assertEquals("POINT (-13134586.72 3764623.35)", result);
}

@Test
Expand Down
15 changes: 13 additions & 2 deletions python/sedona/sql/st_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ def ST_SymDifference(a: ColumnOrName, b: ColumnOrName) -> Column:


@validate_argument_types
def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: ColumnOrName, disable_error: Optional[Union[ColumnOrName, bool]] = None) -> Column:
def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: Optional[Union[ColumnOrName, str]] = None, disable_error: Optional[Union[ColumnOrName, bool]] = None) -> Column:
"""Convert a geometry from one coordinate system to another coordinate system.

:param geometry: Geometry column to convert.
Expand All @@ -1250,7 +1250,18 @@ def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: C
:return: Geometry converted to the target coordinate system as an
:rtype: Column
"""
args = (geometry, source_crs, target_crs) if disable_error is None else (geometry, source_crs, target_crs, disable_error)

if disable_error is None:
args = (geometry, source_crs, target_crs)

# When 2 arguments are passed to the function.
# From python's perspective ST_Transform(geometry, source_crs) is provided
# that's why have to check if the target_crs is empty.
# the source_crs acts as target_crs when calling the function
if target_crs is None:
args = (geometry, source_crs)
else:
args = (geometry, source_crs, target_crs, disable_error)
return _call_st_function("ST_Transform", args)


Expand Down
1 change: 0 additions & 1 deletion python/tests/sql/test_dataframe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@
(stf.ST_SymDifference, ("", None)),
(stf.ST_Transform, (None, "", "")),
(stf.ST_Transform, ("", None, "")),
(stf.ST_Transform, ("", "", None)),
(stf.ST_Union, (None, "")),
(stf.ST_Union, ("", None)),
(stf.ST_X, (None,)),
Expand Down
11 changes: 9 additions & 2 deletions python/tests/sql/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,15 @@ def test_st_transform(self):
polygon_df.createOrReplaceTempView("polygondf")
polygon_df.show()
function_df = self.spark.sql(
"select ST_Transform(polygondf.countyshape, 'epsg:4326','epsg:3857', false) from polygondf")
function_df.show()
"select ST_ReducePrecision(ST_Transform(polygondf.countyshape, 'epsg:4326','epsg:3857', false), 2) from polygondf")
actual = function_df.take(1)[0][0].wkt
assert actual[:300] == "POLYGON ((-10800163.45 5161718.41, -10800164.34 5162103.12, -10800164.57 5162440.81, -10800164.57 5162443.95, -10800164.57 5162468.37, -10800164.57 5162501.93, -10800165.57 5163066.47, -10800166.9 5163158.61, -10800166.9 5163161.46, -10800167.01 5163167.9, -10800167.01 5163171.5, -10800170.24 516340"

function_df = self.spark.sql(
"select ST_ReducePrecision(ST_Transform(ST_SetSRID(polygondf.countyshape, 4326), 'epsg:3857'), 2) from polygondf"
)
actual = function_df.take(1)[0][0].wkt
assert actual[:300] == "POLYGON ((-10800163.45 5161718.41, -10800164.34 5162103.12, -10800164.57 5162440.81, -10800164.57 5162443.95, -10800164.57 5162468.37, -10800164.57 5162501.93, -10800165.57 5163066.47, -10800166.9 5163158.61, -10800166.9 5163161.46, -10800167.01 5163167.9, -10800167.01 5163171.5, -10800170.24 516340"

def test_st_intersection_intersects_but_not_contains(self):
test_table = self.spark.sql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ case class ST_Centroid(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Transform(inputExpressions: Seq[Expression])
extends InferredExpression(inferrableFunction4(FunctionsGeoTools.transform)) {
extends InferredExpression(inferrableFunction4(FunctionsGeoTools.transform), inferrableFunction3(FunctionsGeoTools.transform),
inferrableFunction2(FunctionsGeoTools.transform)) {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ object st_functions extends DataFrameAPI {
def ST_Transform(geometry: String, sourceCRS: String, targetCRS: String): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, false)
def ST_Transform(geometry: Column, sourceCRS: Column, targetCRS: Column, disableError: Column): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, disableError)
def ST_Transform(geometry: String, sourceCRS: String, targetCRS: String, disableError: Boolean): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, disableError)
def ST_Transform(geometry: String, targetCRS: String): Column = wrapExpression[ST_Transform](geometry, targetCRS)
def ST_Transform(geometry: Column, targetCRS: Column): Column = wrapExpression[ST_Transform](geometry, targetCRS)

def ST_Union(a: Column, b: Column): Column = wrapExpression[ST_Union](a, b)
def ST_Union(a: String, b: String): Column = wrapExpression[ST_Union](a, b)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,19 @@ class dataFrameAPITestScala extends TestBaseScala {
}

it("Passed ST_Transform") {
val pointDf = sparkSession.sql("SELECT ST_Point(1.0, 1.0) AS geom")
val df = pointDf.select(ST_Transform($"geom", lit("EPSG:4326"), lit("EPSG:32649")).as("geom")).selectExpr("ST_ReducePrecision(geom, 2)")
var pointDf = sparkSession.sql("SELECT ST_Point(1.0, 1.0) AS geom")
var df = pointDf.select(ST_Transform($"geom", lit("EPSG:4326"), lit("EPSG:32649")).as("geom")).selectExpr("ST_ReducePrecision(geom, 2)")
val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
val expectedResult = "POINT (-33741810.95 1823994.03)"
assert(actualResult == expectedResult)
assertEquals(expectedResult, actualResult)

pointDf = sparkSession.sql("SELECT ST_Point(40.0, 100.0) AS geom")
.select(ST_SetSRID("geom", 4326).as("geom"))
df = pointDf.select(ST_Transform($"geom", lit("EPSG:32649")).as("geom"))
.selectExpr("ST_ReducePrecision(geom, 2)")
val actual = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
val expected = "POINT (1560393.11 10364606.84)"
assertEquals(expected, actual)
}

it("Passed ST_Intersection") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,16 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
}

it("Passed ST_Transform") {
val point = "POINT (120 60)"
var point = "POINT (120 60)"
val transformedResult = sparkSession.sql(s"""select ST_Transform(ST_geomFromWKT('$point'),'EPSG:4326', 'EPSG:3857', false)""").rdd.map(row => row.getAs[Geometry](0)).collect()(0)
println(transformedResult)
assertEquals(1.3358338895192828E7, transformedResult.getCoordinate.x, FP_TOLERANCE)
assertEquals(8399737.889818355, transformedResult.getCoordinate.y, FP_TOLERANCE)

point = "POINT (100 40)"
val result = sparkSession.sql(
s"""SELECT
|ST_AsText(ST_ReducePrecision(ST_Transform(ST_SetSRID(ST_GeomFromWKT('$point'), 4326), 'EPSG:3857'), 2))""".stripMargin).first().get(0)
assertEquals("POINT (11131949.08 4865942.28)", result)
}

it("Passed ST_transform WKT version"){
Expand Down
Loading