diff --git a/README.md b/README.md
index 8327533355..2e236de94d 100644
--- a/README.md
+++ b/README.md
@@ -93,25 +93,43 @@ This example loads NYC taxi trip records and taxi zone information stored as .CS
 #### Load NYC taxi trips and taxi zones data from CSV Files Stored on AWS S3
 
 ```python
-taxidf = sedona.read.format('csv').option("header","true").option("delimiter", ",").load("s3a://your-directory/data/nyc-taxi-data.csv")
-taxidf = taxidf.selectExpr('ST_Point(CAST(Start_Lon AS Decimal(24,20)), CAST(Start_Lat AS Decimal(24,20))) AS pickup', 'Trip_Pickup_DateTime', 'Payment_Type', 'Fare_Amt')
+taxidf = (
+    sedona.read.format("csv")
+    .option("header", "true")
+    .option("delimiter", ",")
+    .load("s3a://your-directory/data/nyc-taxi-data.csv")
+)
+taxidf = taxidf.selectExpr(
+    "ST_Point(CAST(Start_Lon AS Decimal(24,20)), CAST(Start_Lat AS Decimal(24,20))) AS pickup",
+    "Trip_Pickup_DateTime",
+    "Payment_Type",
+    "Fare_Amt",
+)
 ```
 
 ```python
-zoneDf = sedona.read.format('csv').option("delimiter", ",").load("s3a://your-directory/data/TIGER2018_ZCTA5.csv")
-zoneDf = zoneDf.selectExpr('ST_GeomFromWKT(_c0) as zone', '_c1 as zipcode')
+zoneDf = (
+    sedona.read.format("csv")
+    .option("delimiter", ",")
+    .load("s3a://your-directory/data/TIGER2018_ZCTA5.csv")
+)
+zoneDf = zoneDf.selectExpr("ST_GeomFromWKT(_c0) as zone", "_c1 as zipcode")
 ```
 
 #### Spatial SQL query to only return Taxi trips in Manhattan
 
 ```python
-taxidf_mhtn = taxidf.where('ST_Contains(ST_PolygonFromEnvelope(-74.01,40.73,-73.93,40.79), pickup)')
+taxidf_mhtn = taxidf.where(
+    "ST_Contains(ST_PolygonFromEnvelope(-74.01,40.73,-73.93,40.79), pickup)"
+)
 ```
 
 #### Spatial Join between Taxi Dataframe and Zone Dataframe to Find taxis in each zone
 
 ```python
-taxiVsZone = sedona.sql('SELECT zone, zipcode, pickup, Fare_Amt FROM zoneDf, taxiDf WHERE ST_Contains(zone, pickup)')
+taxiVsZone = sedona.sql(
+    "SELECT zone, zipcode, pickup, Fare_Amt FROM zoneDf, taxiDf WHERE ST_Contains(zone, pickup)"
+)
 ```
 
 #### Show a map of the loaded Spatial Dataframes using GeoPandas
@@ -120,14 +138,14 @@ taxiVsZone = sedona.sql('SELECT zone, zipcode, pickup, Fare_Amt FROM zoneDf, tax
 zoneGpd = gpd.GeoDataFrame(zoneDf.toPandas(), geometry="zone")
 taxiGpd = gpd.GeoDataFrame(taxidf.toPandas(), geometry="pickup")
 
-zone = zoneGpd.plot(color='yellow', edgecolor='black', zorder=1)
-zone.set_xlabel('Longitude (degrees)')
-zone.set_ylabel('Latitude (degrees)')
+zone = zoneGpd.plot(color="yellow", edgecolor="black", zorder=1)
+zone.set_xlabel("Longitude (degrees)")
+zone.set_ylabel("Latitude (degrees)")
 zone.set_xlim(-74.1, -73.8)
 zone.set_ylim(40.65, 40.9)
 
-taxi = taxiGpd.plot(ax=zone, alpha=0.01, color='red', zorder=3)
+taxi = taxiGpd.plot(ax=zone, alpha=0.01, color="red", zorder=3)
 ```
 
 ## Docker image
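A natural next step for the README join above is rolling the matches up per zone. A minimal sketch, assuming the `taxiVsZone` DataFrame produced by the snippet in that hunk (the temp view name is illustrative):

```python
# Sketch: rank zones by trip count and average fare, assuming the
# taxiVsZone DataFrame from the README example above.
taxiVsZone.createOrReplaceTempView("taxiVsZone")
sedona.sql(
    """
    SELECT zipcode,
           COUNT(*) AS trips,
           AVG(CAST(Fare_Amt AS DOUBLE)) AS avg_fare
    FROM taxiVsZone
    GROUP BY zipcode
    ORDER BY trips DESC
    """
).show(10)
```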
diff --git a/docs/api/sql/Raster-visualizer.md b/docs/api/sql/Raster-visualizer.md
index 1ff207a5e5..f0b83cb165 100644
--- a/docs/api/sql/Raster-visualizer.md
+++ b/docs/api/sql/Raster-visualizer.md
@@ -78,9 +78,14 @@ Example:
 
 ```python
 from sedona.raster_utils.SedonaUtils import SedonaUtils
+
 # Or from sedona.spark import *
 
-df = sedona.read.format('binaryFile').load(DATA_DIR + 'raster.tiff').selectExpr("RS_FromGeoTiff(content) as raster")
+df = (
+    sedona.read.format("binaryFile")
+    .load(DATA_DIR + "raster.tiff")
+    .selectExpr("RS_FromGeoTiff(content) as raster")
+)
 htmlDF = df.selectExpr("RS_AsImage(raster, 500) as raster_image")
 SedonaUtils.display_image(htmlDF)
 ```
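`RS_AsImage` takes the rendered width in pixels as its second argument, so the same raster can be previewed at several sizes in one pass. A sketch building on the `df` from the snippet above:

```python
# Sketch, reusing the df with the "raster" column from the example above.
# The second argument to RS_AsImage is the output width in pixels.
htmlDF = df.selectExpr(
    "RS_AsImage(raster, 200) as thumbnail",
    "RS_AsImage(raster, 800) as detail",
)
SedonaUtils.display_image(htmlDF)
```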
diff --git a/docs/api/sql/Spider.md b/docs/api/sql/Spider.md
index f3870d5fc0..5dfa6569fe 100644
--- a/docs/api/sql/Spider.md
+++ b/docs/api/sql/Spider.md
@@ -24,9 +24,18 @@ Sedona offers a spatial data generator called Spider. It is a data source that g
 Once you have your [`SedonaContext` object created](../Overview#quick-start), you can create a DataFrame with the `spider` data source.
 
 ```python
-df_random_points = sedona.read.format("spider").load(n=1000, distribution='uniform')
-df_random_boxes = sedona.read.format("spider").load(n=1000, distribution='gaussian', geometryType='box', maxWidth=0.05, maxHeight=0.05)
-df_random_polygons = sedona.read.format("spider").load(n=1000, distribution='bit', geometryType='polygon', minSegment=3, maxSegment=5, maxSize=0.1)
+df_random_points = sedona.read.format("spider").load(n=1000, distribution="uniform")
+df_random_boxes = sedona.read.format("spider").load(
+    n=1000, distribution="gaussian", geometryType="box", maxWidth=0.05, maxHeight=0.05
+)
+df_random_polygons = sedona.read.format("spider").load(
+    n=1000,
+    distribution="bit",
+    geometryType="polygon",
+    minSegment=3,
+    maxSegment=5,
+    maxSize=0.1,
+)
 ```
 
 Now we have three DataFrames with random spatial data. We can show the first three rows of the `df_random_points` DataFrame to verify the data is generated correctly.
@@ -57,22 +66,24 @@ import matplotlib.pyplot as plt
 import geopandas as gpd
 
 # Convert DataFrames to GeoDataFrames
-gdf_random_points = gpd.GeoDataFrame(df_random_points.toPandas(), geometry='geometry')
-gdf_random_boxes = gpd.GeoDataFrame(df_random_boxes.toPandas(), geometry='geometry')
-gdf_random_polygons = gpd.GeoDataFrame(df_random_polygons.toPandas(), geometry='geometry')
+gdf_random_points = gpd.GeoDataFrame(df_random_points.toPandas(), geometry="geometry")
+gdf_random_boxes = gpd.GeoDataFrame(df_random_boxes.toPandas(), geometry="geometry")
+gdf_random_polygons = gpd.GeoDataFrame(
+    df_random_polygons.toPandas(), geometry="geometry"
+)
 
 # Create a figure and a set of subplots
 fig, axes = plt.subplots(1, 3, figsize=(15, 5))
 
 # Plot each GeoDataFrame on a different subplot
-gdf_random_points.plot(ax=axes[0], color='blue', markersize=5)
-axes[0].set_title('Random Points')
+gdf_random_points.plot(ax=axes[0], color="blue", markersize=5)
+axes[0].set_title("Random Points")
 
-gdf_random_boxes.boundary.plot(ax=axes[1], color='red')
-axes[1].set_title('Random Boxes')
+gdf_random_boxes.boundary.plot(ax=axes[1], color="red")
+axes[1].set_title("Random Boxes")
 
-gdf_random_polygons.boundary.plot(ax=axes[2], color='green')
-axes[2].set_title('Random Polygons')
+gdf_random_polygons.boundary.plot(ax=axes[2], color="green")
+axes[2].set_title("Random Polygons")
 
 # Adjust the layout
 plt.tight_layout()
@@ -122,8 +133,11 @@ Example:
 
 ```python
 import geopandas as gpd
-df = sedona.read.format("spider").load(n=300, distribution='uniform', geometryType='box', maxWidth=0.05, maxHeight=0.05)
-gpd.GeoDataFrame(df.toPandas(), geometry='geometry').boundary.plot()
+
+df = sedona.read.format("spider").load(
+    n=300, distribution="uniform", geometryType="box", maxWidth=0.05, maxHeight=0.05
+)
+gpd.GeoDataFrame(df.toPandas(), geometry="geometry").boundary.plot()
 ```
 
 ![Uniform Distribution](../../image/spider/spider-uniform.png)
@@ -145,8 +159,11 @@ Example:
 
 ```python
 import geopandas as gpd
-df = sedona.read.format("spider").load(n=300, distribution='gaussian', geometryType='polygon', maxSize=0.05)
-gpd.GeoDataFrame(df.toPandas(), geometry='geometry').boundary.plot()
+
+df = sedona.read.format("spider").load(
+    n=300, distribution="gaussian", geometryType="polygon", maxSize=0.05
+)
+gpd.GeoDataFrame(df.toPandas(), geometry="geometry").boundary.plot()
 ```
 
 ![Gaussian Distribution](../../image/spider/spider-gaussian.png)
@@ -170,8 +187,11 @@ Example:
 
 ```python
 import geopandas as gpd
-df = sedona.read.format("spider").load(n=300, distribution='bit', geometryType='point', probability=0.2, digits=10)
-gpd.GeoDataFrame(df.toPandas(), geometry='geometry').plot(markersize=1)
+
+df = sedona.read.format("spider").load(
+    n=300, distribution="bit", geometryType="point", probability=0.2, digits=10
+)
+gpd.GeoDataFrame(df.toPandas(), geometry="geometry").plot(markersize=1)
 ```
 
 ![Bit Distribution](../../image/spider/spider-bit.png)
@@ -195,8 +215,11 @@ Example:
 
 ```python
 import geopandas as gpd
-df = sedona.read.format("spider").load(n=300, distribution='diagonal', geometryType='point', percentage=0.5, buffer=0.5)
-gpd.GeoDataFrame(df.toPandas(), geometry='geometry').plot(markersize=1)
+
+df = sedona.read.format("spider").load(
+    n=300, distribution="diagonal", geometryType="point", percentage=0.5, buffer=0.5
+)
+gpd.GeoDataFrame(df.toPandas(), geometry="geometry").plot(markersize=1)
 ```
 
 ![Diagonal Distribution](../../image/spider/spider-diagonal.png)
@@ -218,8 +241,11 @@ Example:
 
 ```python
 import geopandas as gpd
-df = sedona.read.format("spider").load(n=2000, distribution='sierpinski', geometryType='point')
-gpd.GeoDataFrame(df.toPandas(), geometry='geometry').plot(markersize=1)
+
+df = sedona.read.format("spider").load(
+    n=2000, distribution="sierpinski", geometryType="point"
+)
+gpd.GeoDataFrame(df.toPandas(), geometry="geometry").plot(markersize=1)
 ```
 
 ![Sierpinski Distribution](../../image/spider/spider-sierpinski.png)
@@ -237,8 +263,11 @@ Example:
 
 ```python
 import geopandas as gpd
-df = sedona.read.format("spider").load(n=300, distribution='parcel', dither=0.5, splitRange=0.5)
-gpd.GeoDataFrame(df.toPandas(), geometry='geometry').boundary.plot()
+
+df = sedona.read.format("spider").load(
+    n=300, distribution="parcel", dither=0.5, splitRange=0.5
+)
+gpd.GeoDataFrame(df.toPandas(), geometry="geometry").boundary.plot()
 ```
 
 ![Parcel Distribution](../../image/spider/spider-parcel.png)
@@ -274,8 +303,11 @@ Example:
 
 ```python
 import geopandas as gpd
-df_random_points = sedona.read.format("spider").load(n=1000, distribution='uniform', translateX=0.5, translateY=0.5, scaleX=2, scaleY=2)
-gpd.GeoDataFrame(df_random_points.toPandas(), geometry='geometry').plot(markersize=1)
+
+df_random_points = sedona.read.format("spider").load(
+    n=1000, distribution="uniform", translateX=0.5, translateY=0.5, scaleX=2, scaleY=2
+)
+gpd.GeoDataFrame(df_random_points.toPandas(), geometry="geometry").plot(markersize=1)
 ```
 
 The data is now in the region `[0.5, 2.5] x [0.5, 2.5]`.
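Spider output is an ordinary DataFrame, so the generated data can drive spatial SQL directly, which is handy for benchmarking joins on synthetic workloads. A sketch, assuming the `geometry` column used throughout this page and an `id` column per generated feature:

```python
# Sketch: join synthetic points against synthetic boxes.
points = sedona.read.format("spider").load(n=1000, distribution="uniform")
boxes = sedona.read.format("spider").load(
    n=100, distribution="uniform", geometryType="box", maxWidth=0.1, maxHeight=0.1
)
points.createOrReplaceTempView("points")
boxes.createOrReplaceTempView("boxes")

# Count generated points falling inside each generated box.
sedona.sql(
    """
    SELECT b.id, COUNT(*) AS num_points
    FROM boxes b, points p
    WHERE ST_Contains(b.geometry, p.geometry)
    GROUP BY b.id
    """
).show(5)
```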
diff --git a/docs/api/sql/Stac.md b/docs/api/sql/Stac.md
index 26ebd082b7..bf91d4c965 100644
--- a/docs/api/sql/Stac.md
+++ b/docs/api/sql/Stac.md
@@ -34,7 +34,9 @@ df.show()
 You can load a STAC collection from an S3 collection file object:
 
 ```python
-df = sedona.read.format("stac").load("s3a://example.com/stac_bucket/stac_collection.json")
+df = sedona.read.format("stac").load(
+    "s3a://example.com/stac_bucket/stac_collection.json"
+)
 df.printSchema()
 df.show()
 ```
@@ -42,7 +44,9 @@ df.show()
 You can also load a STAC collection from an HTTP/HTTPS endpoint:
 
 ```python
-df = sedona.read.format("stac").load("https://earth-search.aws.element84.com/v1/collections/sentinel-2-pre-c1-l2a")
+df = sedona.read.format("stac").load(
+    "https://earth-search.aws.element84.com/v1/collections/sentinel-2-pre-c1-l2a"
+)
 df.printSchema()
 df.show()
 ```
@@ -225,9 +229,7 @@ client = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
 ```
 
 ```python
 items = client.search(
-    collection_id="aster-l1t",
-    datetime="2020",
-    return_dataframe=False
+    collection_id="aster-l1t", datetime="2020", return_dataframe=False
 )
 ```
@@ -235,10 +237,7 @@
 
 ```python
 items = client.search(
-    collection_id="aster-l1t",
-    datetime="2020-05",
-    return_dataframe=False,
-    max_items=5
+    collection_id="aster-l1t", datetime="2020-05", return_dataframe=False, max_items=5
 )
 ```
@@ -250,22 +249,15 @@ items = client.search(
     ids=["AST_L1T_00312272006020322_20150518201805"],
     bbox=[-180.0, -90.0, 180.0, 90.0],
     datetime=["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"],
-    return_dataframe=False
+    return_dataframe=False,
 )
 ```
 
 ### Search Multiple Items with Multiple Bounding Boxes
 
 ```python
-bbox_list = [
-    [-180.0, -90.0, 180.0, 90.0],
-    [-100.0, -50.0, 100.0, 50.0]
-]
-items = client.search(
-    collection_id="aster-l1t",
-    bbox=bbox_list,
-    return_dataframe=False
-)
+bbox_list = [[-180.0, -90.0, 180.0, 90.0], [-100.0, -50.0, 100.0, 50.0]]
+items = client.search(collection_id="aster-l1t", bbox=bbox_list, return_dataframe=False)
 ```
 
 ### Search Items and Get DataFrame as Return with Multiple Intervals
 
 ```python
 interval_list = [
     ["2020-01-01T00:00:00Z", "2020-06-01T00:00:00Z"],
-    ["2020-07-01T00:00:00Z", "2021-01-01T00:00:00Z"]
+    ["2020-07-01T00:00:00Z", "2021-01-01T00:00:00Z"],
 ]
 df = client.search(
-    collection_id="aster-l1t",
-    datetime=interval_list,
-    return_dataframe=True
+    collection_id="aster-l1t", datetime=interval_list, return_dataframe=True
 )
 df.show()
 ```
@@ -288,9 +278,7 @@ df.show()
 
 ```python
 # Save items in DataFrame to GeoParquet with both bounding boxes and intervals
 client.get_collection("aster-l1t").save_to_geoparquet(
-    output_path="/path/to/output",
-    bbox=bbox_list,
-    datetime="2020-05"
+    output_path="/path/to/output", bbox=bbox_list, datetime="2020-05"
 )
 ```
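Because loaded STAC items form a plain DataFrame, metadata filters compose with Sedona's spatial predicates. A sketch, assuming the item footprint column is named `geometry` as in the schema printed by `df.printSchema()` above:

```python
# Sketch: spatially filter STAC items after loading, assuming a
# `geometry` column in the loaded schema.
df = sedona.read.format("stac").load(
    "https://earth-search.aws.element84.com/v1/collections/sentinel-2-pre-c1-l2a"
)
df.where(
    "ST_Intersects(geometry, ST_PolygonFromEnvelope(-10.0, 35.0, 5.0, 45.0))"
).show(5)
```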
"org.apache.sedona.viz.sql.SedonaVizExtensions,org.apache.sedona.sql.SedonaSqlExtensions") \ +config = ( + SedonaContext.builder() + .config( + "spark.jars.packages", + "org.apache.sedona:sedona-spark-shaded-3.4_2.12-1.6.1," + "org.datasyslab:geotools-wrapper-1.6.1-28.2", + ) + .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .config( + "spark.kryo.registrator", "org.apache.sedona.core.serde.SedonaKryoRegistrator" + ) + .config( + "spark.sql.extensions", + "org.apache.sedona.viz.sql.SedonaVizExtensions,org.apache.sedona.sql.SedonaSqlExtensions", + ) .getOrCreate() +) sedona = SedonaContext.create(config) ``` diff --git a/docs/setup/databricks.md b/docs/setup/databricks.md index 434a43f012..0a9e7cda9e 100644 --- a/docs/setup/databricks.md +++ b/docs/setup/databricks.md @@ -56,6 +56,7 @@ SedonaSQLRegistrator.registerAll(spark) ```python from sedona.register.geo_registrator import SedonaRegistrator + SedonaRegistrator.registerAll(spark) ``` diff --git a/docs/setup/install-python.md b/docs/setup/install-python.md index c2f66e4624..5d65f82ba7 100644 --- a/docs/setup/install-python.md +++ b/docs/setup/install-python.md @@ -64,12 +64,20 @@ You can get it using one of the following methods: ```python from sedona.spark import * -config = SedonaContext.builder(). \ - config('spark.jars.packages', - 'org.apache.sedona:sedona-spark-3.3_2.12:{{ sedona.current_version }},' - 'org.datasyslab:geotools-wrapper:{{ sedona.current_geotools }}'). \ - config('spark.jars.repositories', 'https://artifacts.unidata.ucar.edu/repository/unidata-all'). \ - getOrCreate() + +config = ( + SedonaContext.builder() + .config( + "spark.jars.packages", + "org.apache.sedona:sedona-spark-3.3_2.12:{{ sedona.current_version }}," + "org.datasyslab:geotools-wrapper:{{ sedona.current_geotools }}", + ) + .config( + "spark.jars.repositories", + "https://artifacts.unidata.ucar.edu/repository/unidata-all", + ) + .getOrCreate() +) sedona = SedonaContext.create(config) ``` @@ -81,15 +89,18 @@ SedonaRegistrator is deprecated in Sedona 1.4.1 and later versions. Please use t from pyspark.sql import SparkSession from sedona.register import SedonaRegistrator from sedona.utils import SedonaKryoRegistrator, KryoSerializer -spark = SparkSession. \ - builder. \ - appName('appName'). \ - config("spark.serializer", KryoSerializer.getName). \ - config("spark.kryo.registrator", SedonaKryoRegistrator.getName). \ - config('spark.jars.packages', - 'org.apache.sedona:sedona-spark-shaded-3.3_2.12:{{ sedona.current_version }},' - 'org.datasyslab:geotools-wrapper:{{ sedona.current_geotools }}'). 
diff --git a/docs/setup/install-python.md b/docs/setup/install-python.md
index c2f66e4624..5d65f82ba7 100644
--- a/docs/setup/install-python.md
+++ b/docs/setup/install-python.md
@@ -64,12 +64,20 @@ You can get it using one of the following methods:
 
 ```python
 from sedona.spark import *
 
-config = SedonaContext.builder(). \
-    config('spark.jars.packages',
-           'org.apache.sedona:sedona-spark-3.3_2.12:{{ sedona.current_version }},'
-           'org.datasyslab:geotools-wrapper:{{ sedona.current_geotools }}'). \
-    config('spark.jars.repositories', 'https://artifacts.unidata.ucar.edu/repository/unidata-all'). \
-    getOrCreate()
+
+config = (
+    SedonaContext.builder()
+    .config(
+        "spark.jars.packages",
+        "org.apache.sedona:sedona-spark-3.3_2.12:{{ sedona.current_version }},"
+        "org.datasyslab:geotools-wrapper:{{ sedona.current_geotools }}",
+    )
+    .config(
+        "spark.jars.repositories",
+        "https://artifacts.unidata.ucar.edu/repository/unidata-all",
+    )
+    .getOrCreate()
+)
 
 sedona = SedonaContext.create(config)
 ```
@@ -81,15 +89,18 @@ SedonaRegistrator is deprecated in Sedona 1.4.1 and later versions. Please use t
 from pyspark.sql import SparkSession
 from sedona.register import SedonaRegistrator
 from sedona.utils import SedonaKryoRegistrator, KryoSerializer
-spark = SparkSession. \
-    builder. \
-    appName('appName'). \
-    config("spark.serializer", KryoSerializer.getName). \
-    config("spark.kryo.registrator", SedonaKryoRegistrator.getName). \
-    config('spark.jars.packages',
-           'org.apache.sedona:sedona-spark-shaded-3.3_2.12:{{ sedona.current_version }},'
-           'org.datasyslab:geotools-wrapper:{{ sedona.current_geotools }}'). \
-    getOrCreate()
+
+spark = (
+    SparkSession.builder.appName("appName")
+    .config("spark.serializer", KryoSerializer.getName)
+    .config("spark.kryo.registrator", SedonaKryoRegistrator.getName)
+    .config(
+        "spark.jars.packages",
+        "org.apache.sedona:sedona-spark-shaded-3.3_2.12:{{ sedona.current_version }},"
+        "org.datasyslab:geotools-wrapper:{{ sedona.current_geotools }}",
+    )
+    .getOrCreate()
+)
 SedonaRegistrator.registerAll(spark)
 ```
diff --git a/docs/setup/release-notes.md b/docs/setup/release-notes.md
index 2931e1ce61..7a5666846a 100644
--- a/docs/setup/release-notes.md
+++ b/docs/setup/release-notes.md
@@ -392,8 +392,10 @@ Sedona 1.6.0 is compiled against Spark 3.3 / Spark 3.4 / Spark 3.5, Flink 1.19,
 
 ```python
 from pyspark.sql.types import DoubleType
 
+
 def mean_udf(raster):
-  return float(raster.as_numpy().mean())
+    return float(raster.as_numpy().mean())
+
 
 sedona.udf.register("mean_udf", mean_udf, DoubleType())
 df_raster.withColumn("mean", expr("mean_udf(rast)")).show()
@@ -1064,11 +1066,15 @@ Sedona 1.4.1 is compiled against Spark 3.3 / Spark 3.4 / Flink 1.12, Java 8.
 
   ```python
   from sedona.spark import *
 
-  config = SedonaContext.builder().\
-      config('spark.jars.packages',
-             'org.apache.sedona:sedona-spark-shaded-3.3_2.12:1.4.1,'
-             'org.datasyslab:geotools-wrapper:1.4.0-28.2'). \
-      getOrCreate()
+  config = (
+      SedonaContext.builder()
+      .config(
+          "spark.jars.packages",
+          "org.apache.sedona:sedona-spark-shaded-3.3_2.12:1.4.1,"
+          "org.datasyslab:geotools-wrapper:1.4.0-28.2",
+      )
+      .getOrCreate()
+  )
   sedona = SedonaContext.create(config)
   sedona.sql("SELECT ST_GeomFromWKT(XXX) FROM")
   ```
diff --git a/docs/tutorial/concepts/clustering-algorithms.md b/docs/tutorial/concepts/clustering-algorithms.md
index 830b0667ee..79cb72cd8f 100644
--- a/docs/tutorial/concepts/clustering-algorithms.md
+++ b/docs/tutorial/concepts/clustering-algorithms.md
@@ -50,21 +50,24 @@ Let’s create a Spark DataFrame with this data and then run the clustering with
 
 ```python
 df = (
-    sedona.createDataFrame([
-        (1, 8.0, 2.0),
-        (2, 2.6, 4.0),
-        (3, 2.5, 4.0),
-        (4, 8.5, 2.5),
-        (5, 2.8, 4.3),
-        (6, 12.8, 4.5),
-        (7, 2.5, 4.2),
-        (8, 8.2, 2.5),
-        (9, 8.0, 3.0),
-        (10, 1.0, 5.0),
-        (11, 8.0, 2.5),
-        (12, 5.0, 6.0),
-        (13, 4.0, 3.0),
-    ], ["id", "x", "y"])
+    sedona.createDataFrame(
+        [
+            (1, 8.0, 2.0),
+            (2, 2.6, 4.0),
+            (3, 2.5, 4.0),
+            (4, 8.5, 2.5),
+            (5, 2.8, 4.3),
+            (6, 12.8, 4.5),
+            (7, 2.5, 4.2),
+            (8, 8.2, 2.5),
+            (9, 8.0, 3.0),
+            (10, 1.0, 5.0),
+            (11, 8.0, 2.5),
+            (12, 5.0, 6.0),
+            (13, 4.0, 3.0),
+        ],
+        ["id", "x", "y"],
+    )
 ).withColumn("point", ST_Point(col("x"), col("y")))
 ```
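To close the loop on the clustering tutorial above, the `point` column can be handed to Sedona's DBSCAN helper. A sketch; the `dbscan` import path and argument names below match recent Sedona releases, so verify them against the version in use:

```python
# Sketch, assuming the df with the "point" column built above and the
# DBSCAN helper shipped in recent Sedona releases.
from sedona.stats.clustering.dbscan import dbscan

# The connected-component step behind dbscan may need a checkpoint dir.
sedona.sparkContext.setCheckpointDir("/tmp/sedona-checkpoints")

# epsilon = 1.0 (neighborhood radius), min_pts = 3 (core-point threshold).
dbscan(df, 1.0, 3, geometry="point").show()
```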
diff --git a/docs/tutorial/raster.md b/docs/tutorial/raster.md
index 1ac0908718..641b441428 100644
--- a/docs/tutorial/raster.md
+++ b/docs/tutorial/raster.md
@@ -279,7 +279,12 @@ For multiple raster data files use the following code to load the data [from pat
 === "Python"
 
     ```python
-    rawDf = sedona.read.format("binaryFile").option("recursiveFileLookup", "true").option("pathGlobFilter", "*.tif*").load(path_to_raster_data_folder)
+    rawDf = (
+        sedona.read.format("binaryFile")
+        .option("recursiveFileLookup", "true")
+        .option("pathGlobFilter", "*.tif*")
+        .load(path_to_raster_data_folder)
+    )
     rawDf.createOrReplaceTempView("rawdf")
     rawDf.show()
     ```
@@ -608,7 +613,11 @@
 Sedona allows collecting Dataframes with raster columns and working with them locally in Python.
 The raster objects are represented as `SedonaRaster` objects in Python, which can be used to perform raster operations.
 
 ```python
-df_raster = sedona.read.format("binaryFile").load("/path/to/raster.tif").selectExpr("RS_FromGeoTiff(content) as rast")
+df_raster = (
+    sedona.read.format("binaryFile")
+    .load("/path/to/raster.tif")
+    .selectExpr("RS_FromGeoTiff(content) as rast")
+)
 rows = df_raster.collect()
 raster = rows[0].rast
 raster  #
 ```
@@ -617,18 +626,18 @@
 You can retrieve the metadata of the raster by accessing the properties of the `SedonaRaster` object.
 
 ```python
-raster.width # width of the raster
-raster.height # height of the raster
-raster.affine_trans # affine transformation matrix
-raster.crs_wkt # coordinate reference system as WKT
+raster.width  # width of the raster
+raster.height  # height of the raster
+raster.affine_trans  # affine transformation matrix
+raster.crs_wkt  # coordinate reference system as WKT
 ```
 
 You can get a numpy array containing the band data of the raster using the `as_numpy` or `as_numpy_masked` method.
 The band data is organized in CHW order.
 
 ```python
-raster.as_numpy() # numpy array of the raster
-raster.as_numpy_masked() # numpy array with nodata values masked as nan
+raster.as_numpy()  # numpy array of the raster
+raster.as_numpy_masked()  # numpy array with nodata values masked as nan
 ```
 
 If you want to work with the raster data using `rasterio`, you can retrieve a `rasterio.DatasetReader` object using the
@@ -640,7 +649,7 @@ If you want to work with the raster data using `r
 
 ```python
 ds = raster.as_rasterio()  # rasterio.DatasetReader object
 # Work with the raster using rasterio
-band1 = ds.read(1) # read the first band
+band1 = ds.read(1)  # read the first band
 ```
 
 ## Writing Python UDF to work with raster data
@@ -651,9 +660,11 @@ return any Spark data type as output. This is an example of a Python UDF that ca
 
 ```python
 from pyspark.sql.types import DoubleType
 
+
 def mean_udf(raster):
     return float(raster.as_numpy().mean())
+
 
 sedona.udf.register("mean_udf", mean_udf, DoubleType())
 df_raster.withColumn("mean", expr("mean_udf(rast)")).show()
 ```
@@ -674,13 +685,17 @@ objects yet. However, you can write a UDF that returns the band data as an array
 
 ```python
 from pyspark.sql.types import ArrayType, DoubleType
 import numpy as np
 
+
 def mask_udf(raster):
-    band1 = raster.as_numpy()[0,:,:]
+    band1 = raster.as_numpy()[0, :, :]
     mask = (band1 < 1400).astype(np.float64)
     return mask.flatten().tolist()
+
 
 sedona.udf.register("mask_udf", mask_udf, ArrayType(DoubleType()))
-df_raster.withColumn("mask", expr("mask_udf(rast)")).withColumn("mask_rast", expr("RS_MakeRaster(rast, 'I', mask)")).show()
+df_raster.withColumn("mask", expr("mask_udf(rast)")).withColumn(
+    "mask_rast", expr("RS_MakeRaster(rast, 'I', mask)")
+).show()
 ```
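The same UDF pattern extends to nodata-aware statistics: `as_numpy_masked` yields NaN where the raster has no data, and numpy's nan-functions skip those cells. A sketch building on the `df_raster` from this tutorial (the UDF name is illustrative):

```python
import numpy as np
from pyspark.sql.functions import expr
from pyspark.sql.types import DoubleType


def nodata_mean_udf(raster):
    # as_numpy_masked() marks nodata cells as NaN; nanmean ignores them
    # instead of letting them skew the band average.
    return float(np.nanmean(raster.as_numpy_masked()))


sedona.udf.register("nodata_mean_udf", nodata_mean_udf, DoubleType())
df_raster.withColumn("mean", expr("nodata_mean_udf(rast)")).show()
```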