Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-415] Add optional parameter to ST_Transform #1069

Merged
merged 10 commits into from
Nov 3, 2023
12 changes: 10 additions & 2 deletions docs/api/sql/Function.md
Original file line number Diff line number Diff line change
Expand Up @@ -2477,7 +2477,7 @@ MULTIPOLYGON (((-2 -3, -3 -3, -3 3, -2 3, -2 -3)), ((3 -3, 3 3, 4 3, 4 -3, 3 -3)

Introduction:

Transform the Spatial Reference System / Coordinate Reference System of A, from SourceCRS to TargetCRS. For SourceCRS and TargetCRS, WKT format is also available since v1.3.1.
Transform the Spatial Reference System / Coordinate Reference System of A, from SourceCRS to TargetCRS. For SourceCRS and TargetCRS, WKT format is also available since `v1.3.1`. Since `v1.5.1`, if the `SourceCRS` is not specified, CRS will be fetched from the geometry using [ST_SRID](#st_srid).

**Lon/Lat Order in the input geometry**

Expand Down Expand Up @@ -2530,7 +2530,15 @@ PROJCS["WGS 84 / Pseudo-Mercator",
Format:

```
ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String, [Optional] lenientMode: Boolean)
ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String, lenientMode: Boolean)
```

```
ST_Transform (A: Geometry, SourceCRS: String, TargetCRS: String)
```

```
ST_Transform (A: Geometry, TargetCRS: String)
```

Since: `v1.2.0`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@

public class FunctionsGeoTools {
public static class ST_Transform extends ScalarFunction {
@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = Geometry.class) Object o, @DataTypeHint("String") String targetCRS)
throws FactoryException, TransformException {
Geometry geom = (Geometry) o;
return org.apache.sedona.common.FunctionsGeoTools.transform(geom, targetCRS);
}

@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = Geometry.class) Object o, @DataTypeHint("String") String sourceCRS, @DataTypeHint("String") String targetCRS)
throws FactoryException, TransformException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ public void testTransform() {
, "epsg:4326", "epsg:3857"));
String result = first(transformedTable).getField(0).toString();
assertEquals("POINT (-13134586.718698347 3764623.3541299687)", result);

pointTable = pointTable.select(call(Functions.ST_SetSRID.class.getSimpleName(), $(pointColNames[0]), 4326)).as(pointColNames[0]);
transformedTable = pointTable.select(call(FunctionsGeoTools.ST_Transform.class.getSimpleName(), $(pointColNames[0]), "epsg:3857"))
.as(pointColNames[0]).select(call(Functions.ST_ReducePrecision.class.getSimpleName(), $(pointColNames[0]), 2));
result = first(transformedTable).getField(0).toString();
assertEquals("POINT (-13134586.72 3764623.35)", result);
}

@Test
Expand Down
7 changes: 6 additions & 1 deletion python/sedona/sql/st_functions.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New Python DataFrame API needs to be tested.

Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,12 @@ def ST_Transform(geometry: ColumnOrName, source_crs: ColumnOrName, target_crs: C
:return: Geometry converted to the target coordinate system as an
:rtype: Column
"""
args = (geometry, source_crs, target_crs) if disable_error is None else (geometry, source_crs, target_crs, disable_error)
if disable_error is None:
args = (geometry, source_crs, target_crs)
if source_crs is None:
args = (geometry, target_crs)
else:
args = (geometry, source_crs, target_crs, disable_error)
return _call_st_function("ST_Transform", args)


Expand Down
5 changes: 5 additions & 0 deletions python/tests/sql/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ def test_st_transform(self):
"select ST_Transform(polygondf.countyshape, 'epsg:4326','epsg:3857', false) from polygondf")
function_df.show()

function_df = self.spark.sql(
"select ST_Transform(polygondf.countyshape, 'epsg:3857', false) from polygondf"
)
function_df.show()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thing here. Do not do show. Use assertion to check the correctness


def test_st_intersection_intersects_but_not_contains(self):
test_table = self.spark.sql(
"select ST_GeomFromWKT('POLYGON((1 1, 8 1, 8 8, 1 8, 1 1))') as a,ST_GeomFromWKT('POLYGON((2 2, 9 2, 9 9, 2 9, 2 2))') as b")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ case class ST_Centroid(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Transform(inputExpressions: Seq[Expression])
extends InferredExpression(inferrableFunction4(FunctionsGeoTools.transform)) {
extends InferredExpression(inferrableFunction4(FunctionsGeoTools.transform), inferrableFunction3(FunctionsGeoTools.transform),
inferrableFunction2(FunctionsGeoTools.transform)) {

protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ object st_functions extends DataFrameAPI {
def ST_Transform(geometry: String, sourceCRS: String, targetCRS: String): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, false)
def ST_Transform(geometry: Column, sourceCRS: Column, targetCRS: Column, disableError: Column): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, disableError)
def ST_Transform(geometry: String, sourceCRS: String, targetCRS: String, disableError: Boolean): Column = wrapExpression[ST_Transform](geometry, sourceCRS, targetCRS, disableError)
def ST_Transform(geometry: String, targetCRS: String): Column = wrapExpression[ST_Transform](geometry, targetCRS)
def ST_Transform(geometry: Column, targetCRS: Column): Column = wrapExpression[ST_Transform](geometry, targetCRS)

def ST_Union(a: Column, b: Column): Column = wrapExpression[ST_Union](a, b)
def ST_Union(a: String, b: String): Column = wrapExpression[ST_Union](a, b)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,19 @@ class dataFrameAPITestScala extends TestBaseScala {
}

it("Passed ST_Transform") {
val pointDf = sparkSession.sql("SELECT ST_Point(1.0, 1.0) AS geom")
val df = pointDf.select(ST_Transform($"geom", lit("EPSG:4326"), lit("EPSG:32649")).as("geom")).selectExpr("ST_ReducePrecision(geom, 2)")
var pointDf = sparkSession.sql("SELECT ST_Point(1.0, 1.0) AS geom")
var df = pointDf.select(ST_Transform($"geom", lit("EPSG:4326"), lit("EPSG:32649")).as("geom")).selectExpr("ST_ReducePrecision(geom, 2)")
val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
val expectedResult = "POINT (-33741810.95 1823994.03)"
assert(actualResult == expectedResult)
assertEquals(expectedResult, actualResult)

pointDf = sparkSession.sql("SELECT ST_Point(40.0, 100.0) AS geom")
.select(ST_SetSRID("geom", 4326).as("geom"))
df = pointDf.select(ST_Transform($"geom", lit("EPSG:32649")).as("geom"))
.selectExpr("ST_ReducePrecision(geom, 2)")
val actual = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
val expected = "POINT (1560393.11 10364606.84)"
assertEquals(expected, actual)
}

it("Passed ST_Intersection") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,16 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
}

it("Passed ST_Transform") {
val point = "POINT (120 60)"
var point = "POINT (120 60)"
val transformedResult = sparkSession.sql(s"""select ST_Transform(ST_geomFromWKT('$point'),'EPSG:4326', 'EPSG:3857', false)""").rdd.map(row => row.getAs[Geometry](0)).collect()(0)
println(transformedResult)
assertEquals(1.3358338895192828E7, transformedResult.getCoordinate.x, FP_TOLERANCE)
assertEquals(8399737.889818355, transformedResult.getCoordinate.y, FP_TOLERANCE)

point = "POINT (100 40)"
val result = sparkSession.sql(
s"""SELECT
|ST_AsText(ST_ReducePrecision(ST_Transform(ST_SetSRID(ST_GeomFromWKT('$point'), 4326), 'EPSG:3857'), 2))""".stripMargin).first().get(0)
assertEquals("POINT (11131949.08 4865942.28)", result)
}

it("Passed ST_transform WKT version"){
Expand Down