From b4a18de34e8c563f8cd014b1c3c5016e7c2d4e50 Mon Sep 17 00:00:00 2001 From: jameswillis Date: Fri, 24 Jan 2025 23:04:44 -0800 Subject: [PATCH] [SEDONA-703] Add utils for converting between RDD[Row] and SpatialRdd --- .../org/apache/sedona/sql/utils/Adapter.scala | 105 +++++++++++++++++- .../apache/sedona/sql/adapterTestScala.scala | 54 ++++++++- 2 files changed, 152 insertions(+), 7 deletions(-) diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala b/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala index 9b1067a25a..25086d670e 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala @@ -76,7 +76,7 @@ object Adapter { val fieldList = dataFrame.schema.toList.map(f => f.name) val geomColId = fieldList.indexOf(geometryFieldName) assert(geomColId >= 0) - toRdd(dataFrame, geomColId) + toStringEncodedRdd(dataFrame.rdd, geomColId) } /** @@ -93,7 +93,7 @@ object Adapter { geometryColId: Int, fieldNames: Seq[String]): SpatialRDD[Geometry] = { var spatialRDD = new SpatialRDD[Geometry] - spatialRDD.rawSpatialRDD = toRdd(dataFrame, geometryColId).toJavaRDD() + spatialRDD.rawSpatialRDD = toStringEncodedRdd(dataFrame.rdd, geometryColId).toJavaRDD() import scala.jdk.CollectionConverters._ if (fieldNames.nonEmpty) spatialRDD.fieldNames = fieldNames.asJava else spatialRDD.fieldNames = null @@ -283,10 +283,10 @@ object Adapter { leftGeom ++ leftUserData ++ rightGeom ++ rightUserData } - private def toRdd(dataFrame: DataFrame, geometryColId: Int): RDD[Geometry] = { - dataFrame.rdd.map[Geometry](f => { - var geometry = f.get(geometryColId).asInstanceOf[Geometry] - var fieldSize = f.size + private def toStringEncodedRdd(rdd: RDD[Row], geometryColId: Int): RDD[Geometry] = { + rdd.map[Geometry](f => { + val geometry = f.get(geometryColId).asInstanceOf[Geometry] + val fieldSize = f.size var userData: String = null if (fieldSize > 1) { userData = "" @@ -300,6 +300,99 @@ object Adapter { }) } + /** + * Convert an RDD of Rows to a SpatialRDD with a geometry column at a specified index. + * + * columns other than the geometry column are serialized into the geometry's user data as tab + * separated strings. The fieldnames field is not set on the SpatialRDD because they are not + * available in the input RDD. + * + * @param rdd + * the RDD of Rows to convert + * @param geometryColId + * the index of the geometry column in the Row to make the spatial RDD of + * @param deduplicateGeom + * whether to remove the geometry from the user data to avoid duplication + * @return + * the SpatialRDD where the geometry column is the geometry and the user data is the other + * columns + */ + def toStringEncodedSpatialRDD(rdd: RDD[Row], geometryColId: Int): SpatialRDD[Geometry] = { + val spatialRDD = new SpatialRDD[Geometry] + spatialRDD.setRawSpatialRDD(toStringEncodedRdd(rdd, geometryColId)) + spatialRDD + } + + private def toRowEncodedRdd( + rdd: RDD[Row], + geometryColId: Int, + deduplicateGeom: Boolean): RDD[Geometry] = { + rdd.map(f => { + if (deduplicateGeom) { + val withoutGeom = new GenericRowWithSchema( + f.toSeq.patch(geometryColId, Nil, 1).toArray, + StructType(f.schema.patch(geometryColId, Nil, 1))) + + val geometry = f.getAs[Geometry](geometryColId) + geometry.setUserData(withoutGeom) + geometry + } else { + val geometry = f.getAs[Geometry](geometryColId) + geometry.setUserData(f) + geometry + } + }) + } + + /** + * Convert an RDD of Rows to a SpatialRDD with a geometry column at a specified index. + * + * The original Row is stored as the geometry's user data. This should make it easier to work + * with and convert back to a DataFrame. The fieldNames field is not set in the output + * SpatialRDD because they are encoded in the user data. + * + * @param rdd + * the RDD of Rows to convert + * @param geometryColId + * the index of the geometry column in the Row to make the spatial RDD of + * @return + * the SpatialRDD where the geometry column is the geometry and the user data is the Row + */ + def toRowEncodedSpatialRdd( + rdd: RDD[Row], + geometryColId: Int, + deduplicateGeom: Boolean = false): SpatialRDD[Geometry] = { + val spatialRdd = new SpatialRDD[Geometry] + spatialRdd.setRawSpatialRDD(toRowEncodedRdd(rdd, geometryColId, deduplicateGeom)) + spatialRdd + } + + /** + * Convert an RDD of Geometries where the userData is a Row to an RDD of Rows. + * + * This is the inverse of toRowEncodedSpatialRdd. + * + * @param rdd + * the RDD of Geometries to convert + * @param geomColId + * the index in which to replace the geometry column in the Row. Only use if the geometry was + * removed from the user data at creation time. + * @return + * the RDD of Rows + */ + def fromRowEncodedGeomRdd(rdd: RDD[Geometry], geomColId: Integer = null): RDD[Row] = { + if (geomColId == null) { + rdd.map(geom => Row.fromSeq(geom.getUserData.asInstanceOf[Row].toSeq)) + } else { + rdd.map(geom => { + val userData = geom.getUserData.asInstanceOf[Row] + val geomWithoutUserData = geom.copy + geomWithoutUserData.setUserData(null) + Row.fromSeq(userData.toSeq.patch(geomColId, Seq(geomWithoutUserData), 0)) + }) + } + } + private def getGeomAndFields( geom: Geometry, fieldNames: Seq[String]): (Seq[Geometry], Seq[String]) = { diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala index 0d267e6256..21e39e4b60 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala @@ -25,9 +25,11 @@ import org.apache.sedona.core.formatMapper.shapefileParser.ShapefileReader import org.apache.sedona.core.spatialOperator.JoinQuery import org.apache.sedona.core.spatialRDD.{CircleRDD, PointRDD, PolygonRDD} import org.apache.sedona.sql.utils.Adapter +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.expr import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import org.apache.spark.sql.types._ -import org.locationtech.jts.geom.Point +import org.locationtech.jts.geom.{Geometry, Point} import org.scalatest.GivenWhenThen class adapterTestScala extends TestBaseScala with GivenWhenThen { @@ -552,6 +554,56 @@ class adapterTestScala extends TestBaseScala with GivenWhenThen { assert(row.get(5).asInstanceOf[String] == "attr2") } } + + } + + it("can convert an RDD of Rows to a spatialRDD and back") { + val srcRdd = sparkSession.read + .format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixedWktGeometryInputLocation) + .withColumn("geom", expr("ST_GeomFromWKT(_c0)")) + .withColumn( + "structColumn", + expr("named_struct('structtext', 'spark', 'structint', 5, 'structbool', false)")) + .rdd + val spatialRDD = Adapter.toRowEncodedSpatialRdd(srcRdd, 18) + assert( + spatialRDD.rawSpatialRDD + .take(1) + .get(0) + .asInstanceOf[Geometry] + .getUserData + .asInstanceOf[Row] + .schema == srcRdd.take(1)(0).schema) + val roundTripRdd = Adapter.fromRowEncodedGeomRdd(spatialRDD.rawSpatialRDD) + assert(roundTripRdd.collect() sameElements srcRdd.collect()) + } + + it("can convert an RDD of Rows to a spatialRDD and back without geometry") { + val srcRdd = sparkSession.read + .format("csv") + .option("delimiter", "\t") + .option("header", "false") + .load(mixedWktGeometryInputLocation) + .withColumn("geom", expr("ST_GeomFromWKT(_c0)")) + .withColumn( + "structColumn", + expr("named_struct('structtext', 'spark', 'structint', 5, 'structbool', false)")) + .rdd + val geomIndex = 18 + val spatialRDD = Adapter.toRowEncodedSpatialRdd(srcRdd, geomIndex, deduplicateGeom = true) + assert( + spatialRDD.rawSpatialRDD + .take(1) + .get(0) + .asInstanceOf[Geometry] + .getUserData + .asInstanceOf[Row] + .schema == StructType(srcRdd.take(1)(0).schema.patch(geomIndex, Nil, 1))) + val roundTripRdd = Adapter.fromRowEncodedGeomRdd(spatialRDD.rawSpatialRDD, geomIndex) + assert(roundTripRdd.collect() sameElements srcRdd.collect()) } } }