From 2f4f9ddb6f038938e25f3dbdcc948f6fd837d3b8 Mon Sep 17 00:00:00 2001 From: Kristin Cowalcijk Date: Thu, 2 Nov 2023 01:15:28 +0800 Subject: [PATCH] [SEDONA-414] Make ST_MakeLine in sedona-spark work with array inputs. (#1068) --- .../java/org/apache/sedona/flink/FunctionTest.java | 4 ++++ .../spark/sql/sedona_sql/expressions/Functions.scala | 2 +- .../apache/sedona/sql/dataFrameAPITestScala.scala | 4 ++++ .../org/apache/sedona/sql/functionTestScala.scala | 12 ++++++++---- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java index 4117863692..7a098b403d 100644 --- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java +++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java @@ -794,6 +794,10 @@ public void testMakeLine() { table = table.select(call(Functions.ST_MakeLine.class.getSimpleName(), $("point1"), $("point2"))); Geometry result = (Geometry) first(table).getField(0); assertEquals("LINESTRING (0 0, 1 1)", result.toString()); + + table = tableEnv.sqlQuery("SELECT ST_MakeLine(ARRAY[ST_Point(2, 2), ST_Point(3, 3)]) AS line"); + result = (Geometry) first(table).getField(0); + assertEquals("LINESTRING (2 2, 3 3)", result.toString()); } @Test diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 5548343bb3..30ae7a32c7 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -687,7 +687,7 @@ case class ST_SubDivideExplode(children: Seq[Expression]) } case class ST_MakeLine(inputExpressions: Seq[Expression]) - extends InferredExpression(InferrableFunction.allowRightNull(Functions.makeLine _)) { + extends InferredExpression(inferrableFunction2(Functions.makeLine), inferrableFunction1(Functions.makeLine)) { protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { copy(inputExpressions = newChildren) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala index 0e5d094db4..b7b837698b 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala @@ -314,6 +314,10 @@ class dataFrameAPITestScala extends TestBaseScala { val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText() val expectedResult = "LINESTRING (0 0, 1 1)" assert(actualResult == expectedResult) + + val df2 = sparkSession.sql("SELECT ST_MakeLine(ARRAY(ST_Point(0, 0), ST_Point(1, 1), ST_Point(2, 2)))") + val actualResult2 = df2.take(1)(0).get(0).asInstanceOf[Geometry].toText() + assert(actualResult2 == "LINESTRING (0 0, 1 1, 2 2)") } it("Passed ST_MakeValid On Invalid Polygon") { diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala index 6e2c8e359d..f7563beb98 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala @@ -322,11 +322,15 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample } it("Passed ST_MakeLine") { - - var testtable = sparkSession.sql( - "SELECT ST_MakeLine(ST_GeomFromText('POINT(1 2)'), ST_GeomFromText('POINT(3 4)'))" + val testtable = sparkSession.sql( + """SELECT + |ST_MakeLine(ST_GeomFromText('POINT(1 2)'), ST_GeomFromText('POINT(3 4)')), + |ST_MakeLine(ARRAY(ST_Point(5, 6), ST_Point(7, 8), ST_Point(9, 10))) + |""".stripMargin ) - assert(testtable.take(1)(0).get(0).asInstanceOf[Geometry].toText.equals("LINESTRING (1 2, 3 4)")) + val row = testtable.take(1)(0) + assert(row.get(0).asInstanceOf[Geometry].toText.equals("LINESTRING (1 2, 3 4)")) + assert(row.get(1).asInstanceOf[Geometry].toText.equals("LINESTRING (5 6, 7 8, 9 10)")) } it("Passed ST_Polygon") {