Skip to content

Commit

Permalink
[SEDONA-706] Fix Python dataframe api for multi-threaded environment (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Kontinuation authored Feb 3, 2025
1 parent f107af5 commit 3b09d9a
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 13 deletions.
8 changes: 2 additions & 6 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,9 @@ jobs:
- name: Run Spark Connect tests
env:
PYTHON_VERSION: ${{ matrix.python }}
SPARK_VERSION: ${{ matrix.spark }}
if: ${{ matrix.spark >= '3.4.0' }}
run: |
if [ ! -f "${VENV_PATH}/lib/python${PYTHON_VERSION}/site-packages/pyspark/sbin/start-connect-server.sh" ]
then
echo "Skipping connect tests for Spark $SPARK_VERSION"
exit
fi
export SPARK_HOME=${VENV_PATH}/lib/python${PYTHON_VERSION}/site-packages/pyspark
export SPARK_REMOTE=local
Expand Down
14 changes: 7 additions & 7 deletions python/sedona/sql/dataframe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import typing
from typing import Any, Callable, Iterable, List, Mapping, Tuple, Type, Union

from pyspark import SparkContext
from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as f

Expand Down Expand Up @@ -57,12 +58,6 @@ def _convert_argument_to_java_column(arg: Any) -> Column:
def call_sedona_function(
object_name: str, function_name: str, args: Union[Any, Tuple[Any]]
) -> Column:
spark = SparkSession.getActiveSession()
if spark is None:
raise ValueError(
"No active spark session was detected. Unable to call sedona function."
)

# apparently a Column is an Iterable so we need to check for it explicitly
if (not isinstance(args, Iterable)) or isinstance(
args, (str, Column, ConnectColumn)
Expand All @@ -75,7 +70,12 @@ def call_sedona_function(

args = map(_convert_argument_to_java_column, args)

jobject = getattr(spark._jvm, object_name)
jvm = SparkContext._jvm
if jvm is None:
raise ValueError(
"No active spark context was detected. Unable to call sedona function."
)
jobject = getattr(jvm, object_name)
jfunc = getattr(jobject, function_name)

jc = jfunc(*args)
Expand Down
27 changes: 27 additions & 0 deletions python/tests/sql/test_dataframe_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
from math import radians
import os
import threading
import concurrent.futures
from typing import Callable, Tuple

import pytest
Expand Down Expand Up @@ -1732,13 +1735,37 @@ def test_call_function_with_wrong_type(self, func, args):
):
func(*args)

def test_multi_thread(self):
df = self.spark.range(0, 100)

def run_spatial_query():
result = df.select(
stf.ST_Buffer(stc.ST_Point("id", f.col("id") + 1), 1.0).alias("geom")
).collect()
assert len(result) == 100

# Create and run 4 threads
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
futures = [executor.submit(run_spatial_query) for _ in range(4)]
concurrent.futures.wait(futures)
for future in futures:
future.result()

@pytest.mark.skipif(
os.getenv("SPARK_REMOTE") is not None,
reason="Checkpoint dir is not available in Spark Connect",
)
def test_dbscan(self):
df = self.spark.createDataFrame([{"id": 1, "x": 2, "y": 3}]).withColumn(
"geometry", f.expr("ST_Point(x, y)")
)

df.withColumn("dbscan", ST_DBSCAN("geometry", 1.0, 2, False)).collect()

@pytest.mark.skipif(
os.getenv("SPARK_REMOTE") is not None,
reason="Checkpoint dir is not available in Spark Connect",
)
def test_lof(self):
df = self.spark.createDataFrame([{"id": 1, "x": 2, "y": 3}]).withColumn(
"geometry", f.expr("ST_Point(x, y)")
Expand Down

0 comments on commit 3b09d9a

Please sign in to comment.