Skip to content

Commit

Permalink
[SEDONA-703] Add utils for converting between RDD[Row] and SpatialRdd
Browse files Browse the repository at this point in the history
  • Loading branch information
jameswillis committed Jan 25, 2025
1 parent 3151f7d commit b4a18de
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 7 deletions.
105 changes: 99 additions & 6 deletions spark/common/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ object Adapter {
val fieldList = dataFrame.schema.toList.map(f => f.name)
val geomColId = fieldList.indexOf(geometryFieldName)
assert(geomColId >= 0)
toRdd(dataFrame, geomColId)
toStringEncodedRdd(dataFrame.rdd, geomColId)
}

/**
Expand All @@ -93,7 +93,7 @@ object Adapter {
geometryColId: Int,
fieldNames: Seq[String]): SpatialRDD[Geometry] = {
var spatialRDD = new SpatialRDD[Geometry]
spatialRDD.rawSpatialRDD = toRdd(dataFrame, geometryColId).toJavaRDD()
spatialRDD.rawSpatialRDD = toStringEncodedRdd(dataFrame.rdd, geometryColId).toJavaRDD()
import scala.jdk.CollectionConverters._
if (fieldNames.nonEmpty) spatialRDD.fieldNames = fieldNames.asJava
else spatialRDD.fieldNames = null
Expand Down Expand Up @@ -283,10 +283,10 @@ object Adapter {
leftGeom ++ leftUserData ++ rightGeom ++ rightUserData
}

private def toRdd(dataFrame: DataFrame, geometryColId: Int): RDD[Geometry] = {
dataFrame.rdd.map[Geometry](f => {
var geometry = f.get(geometryColId).asInstanceOf[Geometry]
var fieldSize = f.size
private def toStringEncodedRdd(rdd: RDD[Row], geometryColId: Int): RDD[Geometry] = {
rdd.map[Geometry](f => {
val geometry = f.get(geometryColId).asInstanceOf[Geometry]
val fieldSize = f.size
var userData: String = null
if (fieldSize > 1) {
userData = ""
Expand All @@ -300,6 +300,99 @@ object Adapter {
})
}

/**
 * Convert an RDD of Rows to a SpatialRDD with a geometry column at a specified index.
 *
 * Columns other than the geometry column are serialized into the geometry's user data as tab
 * separated strings. The fieldNames field is not set on the SpatialRDD because they are not
 * available in the input RDD.
 *
 * @param rdd
 *   the RDD of Rows to convert
 * @param geometryColId
 *   the index of the geometry column in the Row to make the spatial RDD of
 * @return
 *   the SpatialRDD where the geometry column is the geometry and the user data is the other
 *   columns
 */
def toStringEncodedSpatialRDD(rdd: RDD[Row], geometryColId: Int): SpatialRDD[Geometry] = {
  val spatialRDD = new SpatialRDD[Geometry]
  // Convert to a JavaRDD explicitly instead of relying on the implicit RDD-to-JavaRDD conversion.
  spatialRDD.setRawSpatialRDD(toStringEncodedRdd(rdd, geometryColId).toJavaRDD())
  spatialRDD
}

/**
 * Attach each Row to its geometry as user data and return the geometries.
 *
 * The geometry found at `geometryColId` is mutated in place (its user data is overwritten); it
 * is not copied.
 *
 * @param rdd
 *   the RDD of Rows to convert
 * @param geometryColId
 *   the index of the geometry column in each Row
 * @param deduplicateGeom
 *   when true, the geometry column is removed from the Row (and from its schema, if it has one)
 *   before the Row is stored as user data; when false, the Row is stored unchanged
 * @return
 *   the geometries, each carrying its (possibly reduced) Row as user data
 */
private def toRowEncodedRdd(
    rdd: RDD[Row],
    geometryColId: Int,
    deduplicateGeom: Boolean): RDD[Geometry] = {
  rdd.map { row =>
    val userData: Row =
      if (!deduplicateGeom) {
        row
      } else {
        val remainingValues = row.toSeq.patch(geometryColId, Nil, 1)
        // Row.schema is null for rows created without a schema (e.g. Row(...) / Row.fromSeq),
        // so only rebuild a schema when the input row actually carries one.
        Option(row.schema) match {
          case Some(schema) =>
            new GenericRowWithSchema(
              remainingValues.toArray,
              StructType(schema.patch(geometryColId, Nil, 1)))
          case None =>
            Row.fromSeq(remainingValues)
        }
      }

    val geometry = row.getAs[Geometry](geometryColId)
    geometry.setUserData(userData)
    geometry
  }
}

/**
 * Convert an RDD of Rows to a SpatialRDD with a geometry column at a specified index.
 *
 * The original Row is stored as the geometry's user data. This should make it easier to work
 * with and convert back to a DataFrame. The fieldNames field is not set in the output
 * SpatialRDD because they are encoded in the user data.
 *
 * @param rdd
 *   the RDD of Rows to convert
 * @param geometryColId
 *   the index of the geometry column in the Row to make the spatial RDD of
 * @param deduplicateGeom
 *   whether to remove the geometry from the user data to avoid duplication. Defaults to false,
 *   in which case the Row is stored unchanged. When true, pass the same column index to
 *   fromRowEncodedGeomRdd to put the geometry back when converting to Rows again.
 * @return
 *   the SpatialRDD where the geometry column is the geometry and the user data is the Row
 */
def toRowEncodedSpatialRdd(
    rdd: RDD[Row],
    geometryColId: Int,
    deduplicateGeom: Boolean = false): SpatialRDD[Geometry] = {
  val spatialRdd = new SpatialRDD[Geometry]
  // Convert to a JavaRDD explicitly instead of relying on the implicit RDD-to-JavaRDD conversion.
  spatialRdd.setRawSpatialRDD(
    toRowEncodedRdd(rdd, geometryColId, deduplicateGeom).toJavaRDD())
  spatialRdd
}

/**
 * Turn an RDD of Geometries whose user data is a Row back into an RDD of Rows.
 *
 * This is the inverse of toRowEncodedSpatialRdd. Every geometry's user data is cast to a Row,
 * so the call fails at execution time for geometries carrying any other kind of user data.
 *
 * @param rdd
 *   the RDD of Geometries to convert
 * @param geomColId
 *   the position at which the geometry is inserted into the output Row. Leave it as null when
 *   the Row stored in the user data still contains the geometry; supply it only when the
 *   geometry was removed from the user data at creation time.
 * @return
 *   the RDD of Rows
 */
def fromRowEncodedGeomRdd(rdd: RDD[Geometry], geomColId: Integer = null): RDD[Row] =
  Option(geomColId) match {
    case None =>
      // The stored Row is complete: rebuild a Row from its values as-is.
      rdd.map { geometry =>
        val encodedRow = geometry.getUserData.asInstanceOf[Row]
        Row.fromSeq(encodedRow.toSeq)
      }
    case Some(insertAt) =>
      rdd.map { geometry =>
        val encodedRow = geometry.getUserData.asInstanceOf[Row]
        // Insert a copy stripped of user data so the emitted Row does not carry the Row again.
        val bareGeometry = geometry.copy
        bareGeometry.setUserData(null)
        Row.fromSeq(encodedRow.toSeq.patch(insertAt, Seq(bareGeometry), 0))
      }
  }

private def getGeomAndFields(
geom: Geometry,
fieldNames: Seq[String]): (Seq[Geometry], Seq[String]) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ import org.apache.sedona.core.formatMapper.shapefileParser.ShapefileReader
import org.apache.sedona.core.spatialOperator.JoinQuery
import org.apache.sedona.core.spatialRDD.{CircleRDD, PointRDD, PolygonRDD}
import org.apache.sedona.sql.utils.Adapter
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types._
import org.locationtech.jts.geom.Point
import org.locationtech.jts.geom.{Geometry, Point}
import org.scalatest.GivenWhenThen

class adapterTestScala extends TestBaseScala with GivenWhenThen {
Expand Down Expand Up @@ -552,6 +554,56 @@ class adapterTestScala extends TestBaseScala with GivenWhenThen {
assert(row.get(5).asInstanceOf[String] == "attr2")
}
}

}

it("can convert an RDD of Rows to a spatialRDD and back") {
  val srcDf = sparkSession.read
    .format("csv")
    .option("delimiter", "\t")
    .option("header", "false")
    .load(mixedWktGeometryInputLocation)
    .withColumn("geom", expr("ST_GeomFromWKT(_c0)"))
    .withColumn(
      "structColumn",
      expr("named_struct('structtext', 'spark', 'structint', 5, 'structbool', false)"))
  // Look the geometry column up by name so the test does not depend on how many columns the
  // input file happens to have.
  val geomIndex = srcDf.schema.fieldIndex("geom")
  val srcRdd = srcDf.rdd

  val spatialRDD = Adapter.toRowEncodedSpatialRdd(srcRdd, geomIndex)

  // Without deduplication the user data must be a Row with the full source schema.
  assert(
    spatialRDD.rawSpatialRDD
      .take(1)
      .get(0)
      .asInstanceOf[Geometry]
      .getUserData
      .asInstanceOf[Row]
      .schema == srcRdd.take(1)(0).schema)

  // Converting back must reproduce the source rows.
  val roundTripRdd = Adapter.fromRowEncodedGeomRdd(spatialRDD.rawSpatialRDD)
  assert(roundTripRdd.collect() sameElements srcRdd.collect())
}

it("can convert an RDD of Rows to a spatialRDD and back without geometry") {
  val srcDf = sparkSession.read
    .format("csv")
    .option("delimiter", "\t")
    .option("header", "false")
    .load(mixedWktGeometryInputLocation)
    .withColumn("geom", expr("ST_GeomFromWKT(_c0)"))
    .withColumn(
      "structColumn",
      expr("named_struct('structtext', 'spark', 'structint', 5, 'structbool', false)"))
  // Look the geometry column up by name so the test does not depend on how many columns the
  // input file happens to have.
  val geomIndex = srcDf.schema.fieldIndex("geom")
  val srcRdd = srcDf.rdd

  val spatialRDD = Adapter.toRowEncodedSpatialRdd(srcRdd, geomIndex, deduplicateGeom = true)

  // With deduplication the user data schema is the source schema minus the geometry column.
  assert(
    spatialRDD.rawSpatialRDD
      .take(1)
      .get(0)
      .asInstanceOf[Geometry]
      .getUserData
      .asInstanceOf[Row]
      .schema == StructType(srcRdd.take(1)(0).schema.patch(geomIndex, Nil, 1)))

  // Passing the index back re-inserts the geometry, so the source rows must be reproduced.
  val roundTripRdd = Adapter.fromRowEncodedGeomRdd(spatialRDD.rawSpatialRDD, geomIndex)
  assert(roundTripRdd.collect() sameElements srcRdd.collect())
}
}
}

0 comments on commit b4a18de

Please sign in to comment.