From abb55939c284fee6478b75100922074907c0e706 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Mon, 23 Sep 2024 17:02:00 -0400
Subject: [PATCH] feat(pyspark): add official support and ci testing with spark connect (#10187)

## Description of changes

This PR adds testing for the PySpark Ibis backend running against Spark Connect. This is done by running a Spark Connect instance as a Docker Compose service, similar to our other client-server backends. The main piece of functionality that isn't tested is UDFs (which means JSON unwrapping is also untested, because it's implemented as a UDF). UDFs effectively require a clone of the Python environment on the server, and that seems out of scope for initial Spark Connect support.

---
 .env | 1 +
 .github/workflows/ibis-backends.yml | 33 +-
 compose.yaml | 20 +
 docker/spark-connect/conf.properties | 12 +
 docker/spark-connect/log4j2.properties | 68 +++
 ibis/backends/pyspark/__init__.py | 38 +-
 ibis/backends/pyspark/tests/conftest.py | 483 +++++++++++----------
 ibis/backends/pyspark/tests/test_array.py | 174 --------
 ibis/backends/pyspark/tests/test_basic.py | 20 +-
 ibis/backends/pyspark/tests/test_ddl.py | 3 +-
 ibis/backends/pyspark/tests/test_udf.py | 5 +
 ibis/backends/tests/errors.py | 15 +-
 ibis/backends/tests/test_aggregation.py | 31 +-
 ibis/backends/tests/test_array.py | 39 +-
 ibis/backends/tests/test_client.py | 7 +
 ibis/backends/tests/test_export.py | 3 +
 ibis/backends/tests/test_expr_caching.py | 2 +
 ibis/backends/tests/test_json.py | 8 +
 ibis/backends/tests/test_register.py | 9 +-
 ibis/backends/tests/test_temporal.py | 19 +-
 ibis/backends/tests/test_udf.py | 9 +-
 ibis/backends/tests/test_vectorized_udf.py | 15 +-
 ibis/backends/tests/test_window.py | 26 ++
 ibis/conftest.py | 2 +
 pyproject.toml | 1 +
 25 files changed, 581 insertions(+), 462 deletions(-)
 create mode 100644 docker/spark-connect/conf.properties
 create mode 100644 docker/spark-connect/log4j2.properties
 delete mode 100644 ibis/backends/pyspark/tests/test_array.py

diff --git a/.env b/.env
index 05c6411ce4e2..70f4abe8c082 100644
--- a/.env
+++ b/.env
@@ -5,3 +5,4 @@ PGPASSWORD="postgres"
 MYSQL_PWD="ibis"
 MSSQL_SA_PASSWORD="1bis_Testing!"
DRUID_URL="druid://localhost:8082/druid/v2/sql" +SPARK_CONFIG=./docker/spark-connect/conf.properties diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml index d44fbd0b0c48..70ffa1ab4af4 100644 --- a/.github/workflows/ibis-backends.yml +++ b/.github/workflows/ibis-backends.yml @@ -442,13 +442,9 @@ jobs: - name: download backend data run: just download-data - - name: show docker compose version - if: matrix.backend.services != null - run: docker compose version - - name: start services if: matrix.backend.services != null - run: docker compose up --wait ${{ join(matrix.backend.services, ' ') }} + run: just up ${{ join(matrix.backend.services, ' ') }} - name: install python uses: actions/setup-python@v5 @@ -600,7 +596,7 @@ jobs: - name: start services if: matrix.backend.services != null - run: docker compose up --wait ${{ join(matrix.backend.services, ' ') }} + run: just up ${{ join(matrix.backend.services, ' ') }} - name: install python uses: actions/setup-python@v5 @@ -653,7 +649,7 @@ jobs: run: docker compose logs test_pyspark: - name: PySpark ${{ matrix.pyspark-minor-version }} ubuntu-latest python-${{ matrix.python-version }} + name: PySpark ${{ matrix.tag }} ${{ matrix.pyspark-minor-version }} ubuntu-latest python-${{ matrix.python-version }} runs-on: ubuntu-latest strategy: fail-fast: false @@ -665,12 +661,14 @@ jobs: deps: - "'pandas@<2'" - "'numpy@<1.24'" + tag: local - python-version: "3.11" pyspark-version: "3.5.2" pyspark-minor-version: "3.5" deps: - "'pandas@>2'" - "'numpy@>1.24'" + tag: local - python-version: "3.12" pyspark-version: "3.5.2" pyspark-minor-version: "3.5" @@ -678,6 +676,14 @@ jobs: - "'pandas@>2'" - "'numpy@>1.24'" - setuptools + tag: local + - python-version: "3.12" + pyspark-version: "3.5.2" + pyspark-minor-version: "3.5" + deps: + - setuptools + tag: remote + SPARK_REMOTE: "sc://localhost:15002" steps: - name: checkout uses: actions/checkout@v4 @@ -691,6 +697,10 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: start services + if: matrix.tag == 'remote' + run: just up spark-connect + - name: download backend data run: just download-data @@ -730,7 +740,14 @@ jobs: shell: bash run: just download-iceberg-jar ${{ matrix.pyspark-minor-version }} - - name: run tests + - name: run spark connect tests + if: matrix.tag == 'remote' + run: just ci-check -m pyspark + env: + SPARK_REMOTE: ${{ matrix.SPARK_REMOTE }} + + - name: run spark tests + if: matrix.tag == 'local' run: just ci-check -m pyspark - name: check that no untracked files were produced diff --git a/compose.yaml b/compose.yaml index d82c4bc5be33..91107904d7c4 100644 --- a/compose.yaml +++ b/compose.yaml @@ -589,6 +589,24 @@ services: networks: - risingwave + spark-connect: + image: bitnami/spark:3.5.2 + ports: + - 15002:15002 + command: /opt/bitnami/spark/sbin/start-connect-server.sh --name ibis_testing --packages org.apache.spark:spark-connect_2.12:3.5.2,org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:1.5.2 + healthcheck: + test: + - CMD-SHELL + - bash -c 'printf \"GET / HTTP/1.1\n\n\" > /dev/tcp/127.0.0.1/15002; exit $$?;' + interval: 5s + retries: 6 + volumes: + - spark-connect:/data + - $PWD/docker/spark-connect/conf.properties:/opt/bitnami/spark/conf/spark-defaults.conf:ro + # - $PWD/docker/spark-connect/log4j2.properties:/opt/bitnami/spark/conf/log4j2.properties:ro + networks: + - spark-connect + networks: impala: # docker defaults to naming networks "$PROJECT_$NETWORK" but the Java Hive @@ -606,6 +624,7 @@ networks: exasol: flink: risingwave: + 
spark-connect: volumes: clickhouse: @@ -617,3 +636,4 @@ volumes: exasol: impala: risingwave: + spark-connect: diff --git a/docker/spark-connect/conf.properties b/docker/spark-connect/conf.properties new file mode 100644 index 000000000000..ffc1a253def2 --- /dev/null +++ b/docker/spark-connect/conf.properties @@ -0,0 +1,12 @@ +spark.driver.extraJavaOptions=-Duser.timezone=GMT +spark.executor.extraJavaOptions=-Duser.timezone=GMT +spark.jars.packages=org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:1.5.2 +spark.sql.catalog.local.type=hadoop +spark.sql.catalog.local.warehouse=warehouse +spark.sql.catalog.local=org.apache.iceberg.spark.SparkCatalog +spark.sql.extensions=org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions +spark.sql.legacy.timeParserPolicy=LEGACY +spark.sql.session.timeZone=UTC +spark.sql.streaming.schemaInference=true +spark.ui.enabled=false +spark.ui.showConsoleProgress=false diff --git a/docker/spark-connect/log4j2.properties b/docker/spark-connect/log4j2.properties new file mode 100644 index 000000000000..deab3438d4dd --- /dev/null +++ b/docker/spark-connect/log4j2.properties @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the console +rootLogger.level = error +rootLogger.appenderRef.stdout.ref = console + +# In the pattern layout configuration below, we specify an explicit `%ex` conversion +# pattern for logging Throwables. If this was omitted, then (by default) Log4J would +# implicitly add an `%xEx` conversion pattern which logs stacktraces with additional +# class packaging information. That extra information can sometimes add a substantial +# performance overhead, so we disable it in our default logging config. +# For more information, see SPARK-39361. +appender.console.type = Console +appender.console.name = console +appender.console.target = SYSTEM_ERR +appender.console.layout.type = PatternLayout +appender.console.layout.pattern = %d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n%ex + +# Set the default spark-shell/spark-sql log level to WARN. When running the +# spark-shell/spark-sql, the log level for these classes is used to overwrite +# the root logger's log level, so that the user can have different defaults +# for the shell and regular Spark apps. 
+logger.repl.name = org.apache.spark.repl.Main +logger.repl.level = error + +logger.thriftserver.name = org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver +logger.thriftserver.level = error + +# Settings to quiet third party logs that are too verbose +logger.jetty1.name = org.sparkproject.jetty +logger.jetty1.level = error +logger.jetty2.name = org.sparkproject.jetty.util.component.AbstractLifeCycle +logger.jetty2.level = error +logger.replexprTyper.name = org.apache.spark.repl.SparkIMain$exprTyper +logger.replexprTyper.level = error +logger.replSparkILoopInterpreter.name = org.apache.spark.repl.SparkILoop$SparkILoopInterpreter +logger.replSparkILoopInterpreter.level = error +logger.parquet1.name = org.apache.parquet +logger.parquet1.level = error +logger.parquet2.name = parquet +logger.parquet2.level = error + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +logger.RetryingHMSHandler.name = org.apache.hadoop.hive.metastore.RetryingHMSHandler +logger.RetryingHMSHandler.level = fatal +logger.FunctionRegistry.name = org.apache.hadoop.hive.ql.exec.FunctionRegistry +logger.FunctionRegistry.level = error + +# For deploying Spark ThriftServer +# SPARK-34128: Suppress undesirable TTransportException warnings involved in THRIFT-4805 +appender.console.filter.1.type = RegexFilter +appender.console.filter.1.regex = .*Thrift error occurred during processing of message.* +appender.console.filter.1.onMatch = deny +appender.console.filter.1.onMismatch = neutral diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index babbd124eeb6..c5f921381139 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -31,9 +31,15 @@ from ibis.util import deprecated try: - from pyspark.errors import AnalysisException, ParseException + from pyspark.errors import ParseException + from pyspark.errors.exceptions.connect import SparkConnectGrpcException except ImportError: - from pyspark.sql.utils import AnalysisException, ParseException + from pyspark.sql.utils import ParseException + + # Use a dummy class for when spark connect is not available + class SparkConnectGrpcException(Exception): + pass + if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -186,13 +192,6 @@ def do_connect( # Databricks Serverless compute only supports limited properties # and any attempt to set unsupported properties will result in an error. 
# https://docs.databricks.com/en/spark/conf.html - try: - from pyspark.errors.exceptions.connect import SparkConnectGrpcException - except ImportError: - # Use a dummy class for when spark connect is not available - class SparkConnectGrpcException(Exception): - pass - with contextlib.suppress(SparkConnectGrpcException): self._session.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN") @@ -456,7 +455,9 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: df.createTempView(op.name) def _finalize_memtable(self, name: str) -> None: - self._session.catalog.dropTempView(name) + """No-op, otherwise a deadlock can occur when using Spark Connect.""" + if isinstance(session := self._session, pyspark.sql.SparkSession): + session.catalog.dropTempView(name) @contextlib.contextmanager def _safe_raw_sql(self, query: str) -> Any: @@ -579,16 +580,20 @@ def get_schema( table_loc = self._to_sqlglot_table((catalog, database)) catalog, db = self._to_catalog_db_tuple(table_loc) + session = self._session with self._active_catalog_database(catalog, db): try: - df = self._session.table(table_name) - except AnalysisException as e: - if not self._session.catalog.tableExists(table_name): + df = session.table(table_name) + # this is intentionally included in the try block because when + # using spark connect, the table-not-found exception coming + # from the server will *NOT* be raised until the schema + # property is accessed + struct = PySparkType.to_ibis(df.schema) + except Exception as e: + if not session.catalog.tableExists(table_name): raise com.TableNotFound(table_name) from e raise - struct = PySparkType.to_ibis(df.schema) - return sch.Schema(struct) def create_table( @@ -752,7 +757,7 @@ def _create_cached_table(self, name, expr): query = self.compile(expr) t = self._session.sql(query).cache() assert t.is_cached - t.createOrReplaceTempView(name) + t.createTempView(name) # store the underlying spark dataframe so we can release memory when # asked to, instead of when the session ends self._cached_dataframes[name] = t @@ -761,7 +766,6 @@ def _create_cached_table(self, name, expr): def _drop_cached_table(self, name): self._session.catalog.dropTempView(name) t = self._cached_dataframes.pop(name) - assert t.is_cached t.unpersist() assert not t.is_cached diff --git a/ibis/backends/pyspark/tests/conftest.py b/ibis/backends/pyspark/tests/conftest.py index 3d53e52e24be..65df70f78142 100644 --- a/ibis/backends/pyspark/tests/conftest.py +++ b/ibis/backends/pyspark/tests/conftest.py @@ -1,7 +1,9 @@ from __future__ import annotations +import abc import os from datetime import datetime, timedelta, timezone +from pathlib import Path from typing import TYPE_CHECKING, Any from unittest import mock @@ -9,28 +11,33 @@ import pandas as pd import pytest from filelock import FileLock -from packaging.version import parse as vparse import ibis from ibis import util from ibis.backends.conftest import TEST_TABLES from ibis.backends.pyspark import Backend from ibis.backends.pyspark.datatypes import PySparkSchema -from ibis.backends.tests.base import BackendTest +from ibis.backends.tests.base import BackendTest, ServiceBackendTest from ibis.backends.tests.data import json_types, topk, win +from ibis.conftest import IS_SPARK_REMOTE, SPARK_REMOTE if TYPE_CHECKING: - from pathlib import Path + from collections.abc import Iterable def set_pyspark_database(con, database): con._session.catalog.setCurrentDatabase(database) -class TestConf(BackendTest): - deps = ("pyspark",) +class BaseSparkTestConf(abc.ABC): + @property + 
@abc.abstractmethod + def parquet_dir(self) -> str: + """Directory containing Parquet files.""" def _load_data(self, **_: Any) -> None: + import pyspark.sql.functions as F + import pyspark.sql.types as pt from pyspark.sql import Row s = self.connection._session @@ -39,7 +46,7 @@ def _load_data(self, **_: Any) -> None: sort_cols = {"functional_alltypes": "id"} for name in TEST_TABLES: - path = str(self.data_dir / "parquet" / f"{name}.parquet") + path = os.path.join(self.parquet_dir, f"{name}.parquet") t = s.read.parquet(path).repartition(num_partitions) if (sort_col := sort_cols.get(name)) is not None: t = t.sort(sort_col) @@ -138,246 +145,264 @@ def _load_data(self, **_: Any) -> None: s.createDataFrame(win).createOrReplaceTempView("win") s.createDataFrame(topk.to_pandas()).createOrReplaceTempView("topk") - @staticmethod - def connect(*, tmpdir, worker_id, **kw): - # Spark internally stores timestamps as UTC values, and timestamp - # data that is brought in without a specified time zone is - # converted as local time to UTC with microsecond resolution. - # https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics - - import pyspark - from pyspark.sql import SparkSession - - pyspark_version = vparse(pyspark.__version__) - pyspark_minor_version = f"{pyspark_version.major:d}.{pyspark_version.minor:d}" - - config = ( - SparkSession.builder.appName("ibis_testing") - .master("local[1]") - .config("spark.cores.max", 1) - .config("spark.default.parallelism", 1) - .config("spark.driver.extraJavaOptions", "-Duser.timezone=GMT") - .config("spark.dynamicAllocation.enabled", False) - .config("spark.executor.extraJavaOptions", "-Duser.timezone=GMT") - .config("spark.executor.heartbeatInterval", "3600s") - .config("spark.executor.instances", 1) - .config("spark.network.timeout", "4200s") - .config("spark.rdd.compress", False) - .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .config("spark.shuffle.compress", False) - .config("spark.shuffle.spill.compress", False) - .config("spark.sql.legacy.timeParserPolicy", "LEGACY") - .config("spark.sql.session.timeZone", "UTC") - .config("spark.sql.shuffle.partitions", 1) - .config("spark.storage.blockManagerSlaveTimeoutMs", "4200s") - .config("spark.ui.enabled", False) - .config("spark.ui.showConsoleProgress", False) - .config("spark.sql.execution.arrow.pyspark.enabled", False) - .config("spark.sql.streaming.schemaInference", True) - .config( - "spark.sql.extensions", - "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions", - ) - .config( - "spark.jars.packages", - f"org.apache.iceberg:iceberg-spark-runtime-{pyspark_minor_version}_2.12:1.5.2", - ) - .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog") - .config("spark.sql.catalog.local.type", "hadoop") - .config("spark.sql.catalog.local.warehouse", "icehouse") + s.range(0, 10).withColumn("str_col", F.lit("value")).createTempView( + "basic_table" ) - try: - from delta.pip_utils import configure_spark_with_delta_pip - except ImportError: - configure_spark_with_delta_pip = lambda cfg: cfg - else: - config = config.config( - "spark.sql.catalog.spark_catalog", - "org.apache.spark.sql.delta.catalog.DeltaCatalog", - ).config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + df_nulls = s.createDataFrame( + [ + ["k1", np.nan, "Alfred", None], + ["k1", 3.0, None, "joker"], + ["k2", 27.0, "Batman", "batmobile"], + ["k2", None, "Catwoman", "motorcycle"], + ], + ["key", "age", "user", "toy"], + ) + 
df_nulls.createTempView("null_table") - spark = configure_spark_with_delta_pip(config).getOrCreate() - return ibis.pyspark.connect(spark, **kw) + df_dates = s.createDataFrame( + [["2018-01-02"], ["2018-01-03"], ["2018-01-04"]], ["date_str"] + ) + df_dates.createTempView("date_table") + df_time_indexed = s.createDataFrame( + [ + [datetime(2017, 1, 2, 5, tzinfo=timezone.utc), 1, 1.0], + [datetime(2017, 1, 2, 5, tzinfo=timezone.utc), 2, 2.0], + [datetime(2017, 1, 2, 6, tzinfo=timezone.utc), 1, 3.0], + [datetime(2017, 1, 2, 6, tzinfo=timezone.utc), 2, 4.0], + [datetime(2017, 1, 2, 7, tzinfo=timezone.utc), 1, 5.0], + [datetime(2017, 1, 2, 7, tzinfo=timezone.utc), 2, 6.0], + [datetime(2017, 1, 4, 8, tzinfo=timezone.utc), 1, 7.0], + [datetime(2017, 1, 4, 8, tzinfo=timezone.utc), 2, 8.0], + ], + ["time", "key", "value"], + ) -class TestConfForStreaming(BackendTest): - deps = ("pyspark",) + df_time_indexed.createTempView("time_indexed_table") + + if not IS_SPARK_REMOTE: + # TODO(cpcloud): understand why this doesn't work with spark connect + df_interval = s.createDataFrame( + [ + [ + timedelta(days=10), + timedelta(hours=10), + timedelta(minutes=10), + timedelta(seconds=10), + ] + ], + pt.StructType( + [ + pt.StructField( + "interval_day", + pt.DayTimeIntervalType( + pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.DAY + ), + ), + pt.StructField( + "interval_hour", + pt.DayTimeIntervalType( + pt.DayTimeIntervalType.HOUR, pt.DayTimeIntervalType.HOUR + ), + ), + pt.StructField( + "interval_minute", + pt.DayTimeIntervalType( + pt.DayTimeIntervalType.MINUTE, + pt.DayTimeIntervalType.MINUTE, + ), + ), + pt.StructField( + "interval_second", + pt.DayTimeIntervalType( + pt.DayTimeIntervalType.SECOND, + pt.DayTimeIntervalType.SECOND, + ), + ), + ] + ), + ) - def _load_data(self, **_: Any) -> None: - s = self.connection._session - num_partitions = 4 + df_interval.createTempView("interval_table") + + +if IS_SPARK_REMOTE: + + class TestConf(BaseSparkTestConf, ServiceBackendTest): + deps = ("pyspark",) + data_volume = "/data" + service_name = "spark-connect" - watermark_cols = {"functional_alltypes": "timestamp_col"} + @property + def parquet_dir(self) -> str: + return self.data_volume - for name, schema in TEST_TABLES.items(): - path = str(self.data_dir / "directory" / "parquet" / name) - t = ( - s.readStream.schema(PySparkSchema.from_ibis(schema)) - .parquet(path) - .repartition(num_partitions) + @property + def test_files(self) -> Iterable[Path]: + return self.data_dir.joinpath("parquet").glob("*.parquet") + + @staticmethod + def connect(*, tmpdir, worker_id, **kw): + from pyspark.sql import SparkSession + + spark = ( + SparkSession.builder.appName("ibis_testing") + .remote(SPARK_REMOTE) + .getOrCreate() ) - if (watermark_col := watermark_cols.get(name)) is not None: - t = t.withWatermark(watermark_col, "10 seconds") - t.createOrReplaceTempView(name) + return ibis.pyspark.connect(spark, **kw) - @classmethod - def load_data( - cls, data_dir: Path, tmpdir: Path, worker_id: str, **kw: Any - ) -> BackendTest: - """Load testdata from `data_dir`.""" - # handling for multi-processes pytest + @pytest.fixture(scope="session") + def con_streaming(data_dir, tmp_path_factory, worker_id): + pytest.skip("Streaming tests are not supported in remote mode") - # get the temp directory shared by all workers - root_tmp_dir = tmpdir.getbasetemp() / "streaming" - if worker_id != "master": - root_tmp_dir = root_tmp_dir.parent + def write_to_memory(self, expr, table_name): + assert self.mode == "batch" + raise NotImplementedError 
+else: - fn = root_tmp_dir / cls.name() - with FileLock(f"{fn}.lock"): - cls.skip_if_missing_deps() + class TestConf(BaseSparkTestConf, BackendTest): + deps = ("pyspark",) - inst = cls(data_dir=data_dir, tmpdir=tmpdir, worker_id=worker_id, **kw) + @property + def parquet_dir(self) -> str: + return str(self.data_dir / "parquet") - if inst.stateful: - inst.stateful_load(fn, **kw) - else: - inst.stateless_load(**kw) - inst.postload(tmpdir=tmpdir, worker_id=worker_id, **kw) - return inst + @staticmethod + def connect(*, tmpdir, worker_id, **kw): + from pyspark.sql import SparkSession - @staticmethod - def connect(*, tmpdir, worker_id, **kw): - from pyspark.sql import SparkSession + config = SparkSession.builder.appName("ibis_testing") - # SparkContext is shared globally; only one SparkContext should be active - # per JVM. We need to create a new SparkSession for streaming tests but - # this session shares the same SparkContext. - spark = SparkSession.getActiveSession().newSession() - con = ibis.pyspark.connect(spark, mode="streaming", **kw) - return con + # load from properties file, yuck + with Path( + os.environ.get( + "SPARK_CONFIG", + Path(ibis.__file__) + .parents[1] + .joinpath("docker", "spark-connect", "conf.properties"), + ) + ).open(mode="r") as config_file: + for line in config_file: + config = config.config(*map(str.strip, line.strip().split("=", 1))) + + config = ( + config.config("spark.cores.max", "1") + .config("spark.default.parallelism", "1") + .config("spark.dynamicAllocation.enabled", "false") + .config("spark.executor.heartbeatInterval", "3600s") + .config("spark.executor.instances", "1") + .config("spark.network.timeout", "4200s") + .config("spark.rdd.compress", "false") + .config( + "spark.serializer", "org.apache.spark.serializer.KryoSerializer" + ) + .config("spark.shuffle.compress", "false") + .config("spark.shuffle.spill.compress", "false") + .config("spark.sql.execution.arrow.pyspark.enabled", "false") + .config("spark.sql.shuffle.partitions", "1") + .config("spark.storage.blockManagerSlaveTimeoutMs", "4200s") + ) + try: + from delta.pip_utils import configure_spark_with_delta_pip + except ImportError: + configure_spark_with_delta_pip = lambda cfg: cfg + else: + config = config.config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ).config( + "spark.sql.extensions", + "io.delta.sql.DeltaSparkSessionExtension", + ) -@pytest.fixture(scope="session") -def con(data_dir, tmp_path_factory, worker_id): - import pyspark.sql.functions as F - import pyspark.sql.types as pt + spark = configure_spark_with_delta_pip(config).getOrCreate() + return ibis.pyspark.connect(spark, **kw) - backend_test = TestConf.load_data(data_dir, tmp_path_factory, worker_id) - con = backend_test.connection + class TestConfForStreaming(BackendTest): + deps = ("pyspark",) - df = con._session.range(0, 10) - df = df.withColumn("str_col", F.lit("value")) - df.createTempView("basic_table") - - df_nulls = con._session.createDataFrame( - [ - ["k1", np.nan, "Alfred", None], - ["k1", 3.0, None, "joker"], - ["k2", 27.0, "Batman", "batmobile"], - ["k2", None, "Catwoman", "motorcycle"], - ], - ["key", "age", "user", "toy"], - ) - df_nulls.createTempView("null_table") - - df_dates = con._session.createDataFrame( - [["2018-01-02"], ["2018-01-03"], ["2018-01-04"]], ["date_str"] - ) - df_dates.createTempView("date_table") - - df_arrays = con._session.createDataFrame( - [ - ["k1", [1, 2, 3], ["a"]], - ["k2", [4, 5], ["test1", "test2", "test3"]], - ["k3", [6], ["w", "x", 
"y", "z"]], - ["k1", [], ["cat", "dog"]], - ["k1", [7, 8], []], - ], - ["key", "array_int", "array_str"], - ) - df_arrays.createTempView("array_table") - - df_time_indexed = con._session.createDataFrame( - [ - [datetime(2017, 1, 2, 5, tzinfo=timezone.utc), 1, 1.0], - [datetime(2017, 1, 2, 5, tzinfo=timezone.utc), 2, 2.0], - [datetime(2017, 1, 2, 6, tzinfo=timezone.utc), 1, 3.0], - [datetime(2017, 1, 2, 6, tzinfo=timezone.utc), 2, 4.0], - [datetime(2017, 1, 2, 7, tzinfo=timezone.utc), 1, 5.0], - [datetime(2017, 1, 2, 7, tzinfo=timezone.utc), 2, 6.0], - [datetime(2017, 1, 4, 8, tzinfo=timezone.utc), 1, 7.0], - [datetime(2017, 1, 4, 8, tzinfo=timezone.utc), 2, 8.0], - ], - ["time", "key", "value"], - ) - - df_time_indexed.createTempView("time_indexed_table") - - df_interval = con._session.createDataFrame( - [ - [ - timedelta(days=10), - timedelta(hours=10), - timedelta(minutes=10), - timedelta(seconds=10), - ] - ], - pt.StructType( - [ - pt.StructField( - "interval_day", - pt.DayTimeIntervalType( - pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.DAY - ), - ), - pt.StructField( - "interval_hour", - pt.DayTimeIntervalType( - pt.DayTimeIntervalType.HOUR, pt.DayTimeIntervalType.HOUR - ), - ), - pt.StructField( - "interval_minute", - pt.DayTimeIntervalType( - pt.DayTimeIntervalType.MINUTE, pt.DayTimeIntervalType.MINUTE - ), - ), - pt.StructField( - "interval_second", - pt.DayTimeIntervalType( - pt.DayTimeIntervalType.SECOND, pt.DayTimeIntervalType.SECOND - ), - ), - ] - ), - ) + def _load_data(self, **_: Any) -> None: + s = self.connection._session + num_partitions = 4 - df_interval.createTempView("interval_table") + watermark_cols = {"functional_alltypes": "timestamp_col"} - df_interval_invalid = con._session.createDataFrame( - [[timedelta(days=10, hours=10, minutes=10, seconds=10)]], - pt.StructType( - [ - pt.StructField( - "interval_day_hour", - pt.DayTimeIntervalType( - pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.HOUR - ), + for name, schema in TEST_TABLES.items(): + path = str(self.data_dir / "directory" / "parquet" / name) + t = ( + s.readStream.schema(PySparkSchema.from_ibis(schema)) + .parquet(path) + .repartition(num_partitions) ) - ] - ), - ) + if (watermark_col := watermark_cols.get(name)) is not None: + t = t.withWatermark(watermark_col, "10 seconds") + t.createOrReplaceTempView(name) + + @classmethod + def load_data( + cls, data_dir: Path, tmpdir: Path, worker_id: str, **kw: Any + ) -> BackendTest: + """Load testdata from `data_dir`.""" + # handling for multi-processes pytest + + # get the temp directory shared by all workers + root_tmp_dir = tmpdir.getbasetemp() / "streaming" + if worker_id != "master": + root_tmp_dir = root_tmp_dir.parent + + fn = root_tmp_dir / cls.name() + with FileLock(f"{fn}.lock"): + cls.skip_if_missing_deps() + + inst = cls(data_dir=data_dir, tmpdir=tmpdir, worker_id=worker_id, **kw) + + if inst.stateful: + inst.stateful_load(fn, **kw) + else: + inst.stateless_load(**kw) + inst.postload(tmpdir=tmpdir, worker_id=worker_id, **kw) + return inst + + @staticmethod + def connect(*, tmpdir, worker_id, **kw): + from pyspark.sql import SparkSession + + # SparkContext is shared globally; only one SparkContext should be active + # per JVM. We need to create a new SparkSession for streaming tests but + # this session shares the same SparkContext. 
+ spark = SparkSession.getActiveSession().newSession() + con = ibis.pyspark.connect(spark, mode="streaming", **kw) + return con + + @pytest.fixture(scope="session") + def con_streaming(data_dir, tmp_path_factory, worker_id): + backend_test = TestConfForStreaming.load_data( + data_dir, tmp_path_factory, worker_id + ) + return backend_test.connection - df_interval_invalid.createTempView("invalid_interval_table") + @pytest.fixture(autouse=True, scope="function") + def stop_active_jobs(con_streaming): + yield + for sq in con_streaming._session.streams.active: + sq.stop() + sq.awaitTermination() - return con + def write_to_memory(self, expr, table_name): + assert self.mode == "streaming" + df = self._session.sql(expr.compile()) + df.writeStream.format("memory").queryName(table_name).start() @pytest.fixture(scope="session") -def con_streaming(data_dir, tmp_path_factory, worker_id): - backend_test = TestConfForStreaming.load_data(data_dir, tmp_path_factory, worker_id) - return backend_test.connection +def con(data_dir, tmp_path_factory, worker_id): + backend_test = TestConf.load_data(data_dir, tmp_path_factory, worker_id) + con = backend_test.connection + + return con class IbisWindow: @@ -388,12 +413,7 @@ def __init__(self, windows): def get_windows(self): # Return a list of Ibis windows return [ - ibis.window( - preceding=w[0], - following=w[1], - order_by="time", - group_by="key", - ) + ibis.window(preceding=w[0], following=w[1], order_by="time", group_by="key") for w in self.windows ] @@ -440,21 +460,6 @@ def default_session_fixture(): yield -def write_to_memory(self, expr, table_name): - if self.mode == "batch": - raise NotImplementedError - df = self._session.sql(expr.compile()) - df.writeStream.format("memory").queryName(table_name).start() - - -@pytest.fixture(autouse=True, scope="function") -def stop_active_jobs(con_streaming): - yield - for sq in con_streaming._session.streams.active: - sq.stop() - sq.awaitTermination() - - @pytest.fixture def awards_players_schema(): return TEST_TABLES["awards_players"] diff --git a/ibis/backends/pyspark/tests/test_array.py b/ibis/backends/pyspark/tests/test_array.py deleted file mode 100644 index b253b084cc70..000000000000 --- a/ibis/backends/pyspark/tests/test_array.py +++ /dev/null @@ -1,174 +0,0 @@ -from __future__ import annotations - -import numpy as np -import pandas as pd -import pandas.testing as tm -import pytest - -import ibis - -pytest.importorskip("pyspark") - - -@pytest.fixture -def t(con): - return con.table("array_table") - - -@pytest.fixture -def df(con): - return con._session.table("array_table").toPandas() - - -def test_array_length(t, df): - result = t.mutate(length=t.array_int.length()).execute() - expected = df.assign(length=df.array_int.map(lambda a: len(a))) - tm.assert_frame_equal(result, expected) - - -def test_array_length_scalar(con): - raw_value = [1, 2, 3] - value = ibis.literal(raw_value) - expr = value.length() - result = con.execute(expr) - expected = len(raw_value) - assert result == expected - - -@pytest.mark.parametrize( - ["start", "stop"], - [ - (1, 3), - (1, 1), - (2, 3), - (2, 5), - (None, 3), - (None, None), - (3, None), - (-3, None), - (None, -3), - (-3, -1), - ], -) -def test_array_slice(t, df, start, stop): - result = t.mutate(sliced=t.array_int[start:stop]).execute() - expected = df.assign(sliced=df.array_int.map(lambda a: a[start:stop])) - tm.assert_frame_equal(result, expected) - - -@pytest.mark.parametrize( - ["start", "stop"], - [ - (1, 3), - (1, 1), - (2, 3), - (2, 5), - (None, 3), - (None, None), - (3, 
None), - (-3, None), - (None, -3), - (-3, -1), - ], -) -def test_array_slice_scalar(con, start, stop): - raw_value = [-11, 42, 10] - value = ibis.literal(raw_value) - expr = value[start:stop] - result = con.execute(expr) - expected = raw_value[start:stop] - assert result == expected - - -@pytest.mark.parametrize("index", [1, 3, 4, 11, -11]) -def test_array_index(t, df, index): - expr = t.select(t.array_int[index].name("indexed")) - result = expr.execute() - - expected = pd.DataFrame( - { - "indexed": df.array_int.apply( - lambda x: x[index] if -len(x) <= index < len(x) else np.nan - ) - } - ) - tm.assert_frame_equal(result, expected) - - -@pytest.mark.parametrize("index", [1, 3, 4, 11]) -def test_array_index_scalar(con, index): - raw_value = [-10, 1, 2, 42] - value = ibis.literal(raw_value) - expr = value[index] - result = con.execute(expr) - expected = raw_value[index] if index < len(raw_value) else np.nan - assert result == expected or (np.isnan(result) and np.isnan(expected)) - - -@pytest.mark.parametrize("op", [lambda x, y: x + y, lambda x, y: y + x]) -def test_array_concat(t, df, op): - x = t.array_int.cast("array") - y = t.array_str - expr = op(x, y).name("array_result") - result = expr.execute() - - expected = op(df.array_int.apply(lambda x: list(map(str, x))), df.array_str).rename( - "array_result" - ) - tm.assert_series_equal(result, expected) - - -@pytest.mark.parametrize("op", [lambda x, y: x + y, lambda x, y: y + x]) -def test_array_concat_scalar(con, op): - raw_left = [1, 2, 3] - raw_right = [3, 4] - left = ibis.literal(raw_left) - right = ibis.literal(raw_right) - expr = op(left, right) - result = con.execute(expr) - assert result == op(raw_left, raw_right) - - -@pytest.mark.parametrize("n", [1, 3, 4, 7, -2]) # negative returns empty list -@pytest.mark.parametrize("mul", [lambda x, n: x * n, lambda x, n: n * x]) -def test_array_repeat(t, df, n, mul): - expr = t.select(mul(t.array_int, n).name("repeated")) - result = expr.execute() - - expected = pd.DataFrame({"repeated": df.array_int * n}) - tm.assert_frame_equal(result, expected) - - -@pytest.mark.parametrize("n", [1, 3, 4, 7, -2]) # negative returns empty list -@pytest.mark.parametrize("mul", [lambda x, n: x * n, lambda x, n: n * x]) -def test_array_repeat_scalar(con, n, mul): - raw_array = [1, 2] - array = ibis.literal(raw_array) - expr = mul(array, n) - result = con.execute(expr) - expected = mul(raw_array, n) - assert result == expected - - -def test_array_collect(t, df): - expr = t.group_by(t.key).aggregate(collected=t.array_int.collect()) - result = expr.execute().sort_values("key").reset_index(drop=True) - - expected = ( - df.groupby("key") - .array_int.apply(list) - .reset_index() - .rename(columns={"array_int": "collected"}) - ) - tm.assert_frame_equal(result, expected) - - -def test_array_filter(t, df): - expr = t.select(t.array_int.filter(lambda item: item != 3).name("array_int")) - result = expr.execute() - - df["array_int"] = df["array_int"].apply( - lambda ar: [item for item in ar if item != 3] - ) - expected = df[["array_int"]] - tm.assert_frame_equal(result, expected) diff --git a/ibis/backends/pyspark/tests/test_basic.py b/ibis/backends/pyspark/tests/test_basic.py index 192d06e4f081..95bafb4690dd 100644 --- a/ibis/backends/pyspark/tests/test_basic.py +++ b/ibis/backends/pyspark/tests/test_basic.py @@ -1,7 +1,10 @@ from __future__ import annotations +from datetime import timedelta + import pandas as pd import pandas.testing as tm +import pyspark.sql.types as pt import pytest from pytest import param @@ -116,7 
+119,22 @@ def test_alias_after_select(t, df): def test_interval_columns_invalid(con): - msg = r"DayTimeIntervalType\(0, 1\) couldn't be converted to Interval" + df_interval_invalid = con._session.createDataFrame( + [[timedelta(days=10, hours=10, minutes=10, seconds=10)]], + pt.StructType( + [ + pt.StructField( + "interval_day_hour", + pt.DayTimeIntervalType( + pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.SECOND + ), + ) + ] + ), + ) + + df_interval_invalid.createTempView("invalid_interval_table") + msg = r"DayTimeIntervalType.+ couldn't be converted to Interval" with pytest.raises(IbisTypeError, match=msg): con.table("invalid_interval_table") diff --git a/ibis/backends/pyspark/tests/test_ddl.py b/ibis/backends/pyspark/tests/test_ddl.py index 975d2840ca0e..cb00750a0003 100644 --- a/ibis/backends/pyspark/tests/test_ddl.py +++ b/ibis/backends/pyspark/tests/test_ddl.py @@ -59,10 +59,9 @@ def temp_db(con, temp_base): def test_create_database_with_location(con, temp_db): - base = os.path.dirname(temp_db) name = os.path.basename(temp_db) con.create_database(name, path=temp_db) - assert os.path.exists(base) + assert name in con.list_databases() def test_drop_table_not_exist(con): diff --git a/ibis/backends/pyspark/tests/test_udf.py b/ibis/backends/pyspark/tests/test_udf.py index d5c80ac27c35..e6b8789f8fce 100644 --- a/ibis/backends/pyspark/tests/test_udf.py +++ b/ibis/backends/pyspark/tests/test_udf.py @@ -5,6 +5,7 @@ import ibis from ibis.backends.pyspark import PYSPARK_LT_35 +from ibis.conftest import IS_SPARK_REMOTE pytest.importorskip("pyspark") @@ -46,6 +47,10 @@ def test_python_udf(t, df): @pytest.mark.xfail(PYSPARK_LT_35, reason="pyarrow UDFs require PySpark 3.5+") +@pytest.mark.xfail( + IS_SPARK_REMOTE, + reason="pyarrow UDFs aren't tested with spark remote due to environment setup complexities", +) def test_pyarrow_udf(t, df): result = t.mutate(repeated=pyarrow_repeat(t.str_col, 2)).execute() expected = df.assign(repeated=df.str_col * 2) diff --git a/ibis/backends/tests/errors.py b/ibis/backends/tests/errors.py index 436f74dab71e..17bb81b97849 100644 --- a/ibis/backends/tests/errors.py +++ b/ibis/backends/tests/errors.py @@ -27,7 +27,6 @@ None ) - try: from pyexasol.exceptions import ExaQueryError except ImportError: @@ -47,11 +46,21 @@ try: # PySpark 3.5.0 - from pyspark.errors.exceptions.captured import ( + from pyspark.errors.exceptions.base import ( + AnalysisException as PySparkAnalysisException, + ) + from pyspark.errors.exceptions.base import ( ArithmeticException as PySparkArithmeticException, ) + from pyspark.errors.exceptions.base import ParseException as PySparkParseException + from pyspark.errors.exceptions.base import PythonException as PySparkPythonException + from pyspark.errors.exceptions.connect import ( + SparkConnectGrpcException as PySparkConnectGrpcException, + ) except ImportError: - PySparkArithmeticException = None + PySparkParseException = PySparkAnalysisException = PySparkArithmeticException = ( + PySparkPythonException + ) = PySparkConnectGrpcException = None try: from google.api_core.exceptions import BadRequest as GoogleBadRequest diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 1db71e2eccc9..62b2142d19a3 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -26,9 +26,11 @@ PyDruidProgrammingError, PyODBCProgrammingError, PySparkAnalysisException, + PySparkPythonException, SnowflakeProgrammingError, TrinoUserError, ) +from ibis.conftest import 
IS_SPARK_REMOTE from ibis.legacy.udf.vectorized import reduction np = pytest.importorskip("numpy") @@ -73,6 +75,12 @@ def mean_udf(s): reason="no udf support", raises=com.OperationNotDefinedError, ), + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkPythonException, + reason="remote udfs not yet tested due to environment complexities", + ), ], ), param(lambda t: t.double_col.min(), lambda t: t.double_col.min(), id="min"), @@ -211,14 +219,9 @@ def test_aggregate_grouped(backend, alltypes, df, result_fn, expected_fn): ], raises=com.OperationNotDefinedError, ) -@pytest.mark.notimpl( +@pytest.mark.notyet( ["pyspark"], - raises=NotImplementedError, - reason=( - "Invalid return type with grouped aggregate Pandas UDFs: " - "StructType([StructField('mean', DoubleType(), True), " - "StructField('std', DoubleType(), True)]) is not supported" - ), + raises=(PySparkPythonException, NotImplementedError), ) def test_aggregate_multikey_group_reduction_udf(backend, alltypes, df): """Tests .aggregate() on a multi-key group_by with a reduction @@ -1426,7 +1429,7 @@ def test_topk_filter_op(con, alltypes, df, result_fn, expected_fn): @pytest.mark.parametrize( - "agg_fn", [lambda s: list(s), lambda s: np.array(s)], ids=lambda obj: obj.__name__ + "agg_fn", [lambda s: list(s), lambda s: np.array(s)], ids=["list", "ndarray"] ) @pytest.mark.notimpl( [ @@ -1450,6 +1453,12 @@ def test_topk_filter_op(con, alltypes, df, result_fn, expected_fn): ], raises=com.OperationNotDefinedError, ) +@pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkPythonException, + reason="remote udfs not yet tested due to environment complexities", +) def test_aggregate_list_like(backend, alltypes, df, agg_fn): """Tests .aggregate() where the result of an aggregation is a list-like. @@ -1492,6 +1501,12 @@ def test_aggregate_list_like(backend, alltypes, df, agg_fn): ], raises=com.OperationNotDefinedError, ) +@pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkPythonException, + reason="remote udfs not yet tested due to environment complexities", +) def test_aggregate_mixed_udf(backend, alltypes, df): """Tests .aggregate() with multiple aggregations with mixed result types. 
diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 123aa60409a3..c49d808412c9 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -31,6 +31,7 @@ TrinoUserError, ) from ibis.common.collections import frozendict +from ibis.conftest import IS_SPARK_REMOTE np = pytest.importorskip("numpy") pd = pytest.importorskip("pandas") @@ -170,6 +171,17 @@ def test_array_radd_concat(con): assert np.array_equal(result, expected) +@pytest.mark.parametrize("op", [lambda x, y: x + y, lambda x, y: y + x]) +def test_array_concat_scalar(con, op): + raw_left = [1, 2, 3] + raw_right = [3, 4] + left = ibis.literal(raw_left) + right = ibis.literal(raw_right) + expr = op(left, right) + result = con.execute(expr) + assert result == op(raw_left, raw_right) + + def test_array_length(con): expr = ibis.literal([1, 2, 3]).length() assert con.execute(expr.name("tmp")) == 3 @@ -432,7 +444,13 @@ def test_array_slice(backend, start, stop): ["bigquery"], raises=GoogleBadRequest, reason="BigQuery doesn't support arrays with null elements", - ) + ), + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=AssertionError, + reason="somehow, transformed results are different types", + ), ], id="nulls", ), @@ -444,11 +462,6 @@ def test_array_slice(backend, start, stop): [lambda x: x + 1, partial(lambda x, y: x + y, y=1), ibis._ + 1], ids=["lambda", "partial", "deferred"], ) -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="TODO(Kexiang): seems a bug", -) def test_array_map(con, input, output, func): t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) @@ -672,6 +685,14 @@ def test_array_remove(con, input, expected): {"a": [[1, 3, 3], [], [42, 42], [], [None], None]}, [{3, 1}, set(), {42}, set(), {None}, None], id="null", + marks=[ + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=AssertionError, + reason="somehow, transformed results are different types", + ), + ], ), param( {"a": [[1, 3, 3], [], [42, 42], [], None]}, @@ -743,6 +764,12 @@ def test_array_sort(con, data): raises=AssertionError, reason="DataFusion transforms null elements to NAN", ), + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=AssertionError, + reason="somehow, transformed results are different types", + ), ], ), param( diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 069d6a525b2a..7d1ae751e58b 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -35,6 +35,7 @@ PyODBCProgrammingError, SnowflakeProgrammingError, ) +from ibis.conftest import IS_SPARK_REMOTE from ibis.util import gen_name if TYPE_CHECKING: @@ -1686,6 +1687,12 @@ def test_insert_into_table_missing_columns(con, temp_table): @pytest.mark.notyet( ["bigquery"], raises=AssertionError, reason="test is flaky", strict=False ) +@pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=AssertionError, + reason="likely deadlock when using spark connect", +) def test_memtable_cleanup(con): name = ibis.util.gen_name("temp_memtable") t = ibis.memtable({"a": [1, 2, 3], "b": list("def")}, name=name) diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index 112bcbf3f128..977243519862 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -23,6 +23,7 @@ SnowflakeProgrammingError, TrinoUserError, ) +from 
ibis.conftest import IS_SPARK_REMOTE pd = pytest.importorskip("pandas") pa = pytest.importorskip("pyarrow") @@ -215,6 +216,8 @@ def test_table_to_parquet_dir(tmp_path, backend, awards_players): outparquet_dir = tmp_path / "out" if backend.name() == "pyspark": + if IS_SPARK_REMOTE: + pytest.skip("writes to remote output directory") # pyspark already writes more than one file awards_players.to_parquet_dir(outparquet_dir) else: diff --git a/ibis/backends/tests/test_expr_caching.py b/ibis/backends/tests/test_expr_caching.py index 75aba279553e..03c73e3e9772 100644 --- a/ibis/backends/tests/test_expr_caching.py +++ b/ibis/backends/tests/test_expr_caching.py @@ -5,6 +5,7 @@ import ibis import ibis.common.exceptions as com +from ibis.conftest import IS_SPARK_REMOTE pa = pytest.importorskip("pyarrow") ds = pytest.importorskip("pyarrow.dataset") @@ -56,6 +57,7 @@ def test_persist_expression_contextmanager(backend, con, alltypes): @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables") @mark.notyet( ["pyspark"], + condition=not IS_SPARK_REMOTE, raises=AssertionError, reason=( "PySpark holds on to `cached_table` in the stack frame of an internal function. " diff --git a/ibis/backends/tests/test_json.py b/ibis/backends/tests/test_json.py index ae2374d8dcd1..d4d772beda26 100644 --- a/ibis/backends/tests/test_json.py +++ b/ibis/backends/tests/test_json.py @@ -8,6 +8,8 @@ from packaging.version import parse as vparse import ibis.expr.types as ir +from ibis.backends.tests.errors import PySparkPythonException +from ibis.conftest import IS_SPARK_REMOTE np = pytest.importorskip("numpy") pd = pytest.importorskip("pandas") @@ -109,6 +111,12 @@ def test_json_array(backend, json_t): ) @pytest.mark.notimpl(["risingwave"]) @pytest.mark.notyet(["flink"], reason="should work but doesn't deserialize JSON") +@pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkPythonException, + reason="environment issues", +) @pytest.mark.parametrize( ("typ", "expected_data"), [ diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py index 02df85ef1cb0..0ed6925a2ce7 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_register.py @@ -12,13 +12,20 @@ import ibis from ibis.backends.conftest import TEST_TABLES +from ibis.backends.tests.errors import PySparkAnalysisException +from ibis.conftest import IS_SPARK_REMOTE if TYPE_CHECKING: from collections.abc import Iterator import pyarrow as pa -pytestmark = pytest.mark.notimpl(["druid", "exasol", "oracle"]) +pytestmark = [ + pytest.mark.notimpl(["druid", "exasol", "oracle"]), + pytest.mark.notyet( + ["pyspark"], condition=IS_SPARK_REMOTE, raises=PySparkAnalysisException + ), +] @contextlib.contextmanager diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 0bc48c3038b9..9c08d7fe3245 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -33,10 +33,12 @@ Py4JJavaError, PyDruidProgrammingError, PyODBCProgrammingError, + PySparkConnectGrpcException, SnowflakeProgrammingError, TrinoUserError, ) from ibis.common.annotations import ValidationError +from ibis.conftest import IS_SPARK_REMOTE np = pytest.importorskip("numpy") pd = pytest.importorskip("pandas") @@ -743,6 +745,12 @@ def convert_to_offset(x): ["bigquery", "snowflake", "sqlite", "exasol", "mssql"], raises=com.OperationNotDefinedError, ), + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkConnectGrpcException, + 
reason="arrow conversion breaks", + ), pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError), pytest.mark.notimpl( ["duckdb"], @@ -792,6 +800,12 @@ def convert_to_offset(x): raises=com.OperationNotDefinedError, reason="Some wonkiness in sqlglot generation.", ), + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkConnectGrpcException, + reason="arrow conversion breaks", + ), ], ), ], @@ -1691,6 +1705,7 @@ def test_integer_cast_to_timestamp_scalar(alltypes, df): @pytest.mark.notyet( ["pyspark"], reason="PySpark doesn't handle big timestamps", + condition=not IS_SPARK_REMOTE, raises=pd.errors.OutOfBoundsDatetime, ) @pytest.mark.notimpl(["flink"], raises=ArrowInvalid) @@ -1753,7 +1768,9 @@ def test_timestamp_date_comparison(backend, alltypes, df, left_fn, right_fn): @pytest.mark.notimpl( ["clickhouse"], reason="returns incorrect results", raises=AssertionError ) -@pytest.mark.notimpl(["pyspark"], raises=pd.errors.OutOfBoundsDatetime) +@pytest.mark.notimpl( + ["pyspark"], condition=not IS_SPARK_REMOTE, raises=pd.errors.OutOfBoundsDatetime +) @pytest.mark.notimpl(["polars"], raises=AssertionError, reason="returns NaT") @pytest.mark.notyet( ["flink"], diff --git a/ibis/backends/tests/test_udf.py b/ibis/backends/tests/test_udf.py index 3713a4cd3058..f27397dc256d 100644 --- a/ibis/backends/tests/test_udf.py +++ b/ibis/backends/tests/test_udf.py @@ -6,7 +6,8 @@ import ibis.common.exceptions as com from ibis import _, udf -from ibis.backends.tests.errors import Py4JJavaError +from ibis.backends.tests.errors import Py4JJavaError, PySparkPythonException +from ibis.conftest import IS_SPARK_REMOTE no_python_udfs = mark.notimpl( [ @@ -151,6 +152,12 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type raises=NotImplementedError, reason="postgres only supports Python-native UDFs", ) +@mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkPythonException, + reason="remote udfs not yet tested due to environment complexities", +) @mark.parametrize( "add_one", [ diff --git a/ibis/backends/tests/test_vectorized_udf.py b/ibis/backends/tests/test_vectorized_udf.py index 9a3ffb6120f9..cbcfce905c1e 100644 --- a/ibis/backends/tests/test_vectorized_udf.py +++ b/ibis/backends/tests/test_vectorized_udf.py @@ -6,12 +6,25 @@ import ibis import ibis.common.exceptions as com import ibis.expr.datatypes as dt +from ibis.backends.tests.errors import PySparkPythonException +from ibis.conftest import IS_SPARK_REMOTE from ibis.legacy.udf.vectorized import analytic, elementwise, reduction np = pytest.importorskip("numpy") pd = pytest.importorskip("pandas") -pytestmark = pytest.mark.notimpl(["druid", "oracle", "risingwave"]) +pytestmark = [ + pytest.mark.notimpl(["druid", "oracle", "risingwave"]), + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkPythonException, + # TODO(cpcloud): this API is deprecated in 10.0.0, no use copypasting a + # bunch of markers just for two passing tests + strict=False, + reason="remote udfs not yet tested due to environment complexities", + ), +] def _format_udf_return_type(func, result_formatter): diff --git a/ibis/backends/tests/test_window.py b/ibis/backends/tests/test_window.py index 0942160ad1c0..c57bacb06698 100644 --- a/ibis/backends/tests/test_window.py +++ b/ibis/backends/tests/test_window.py @@ -18,8 +18,10 @@ Py4JJavaError, PyDruidProgrammingError, PyODBCProgrammingError, + PySparkPythonException, SnowflakeProgrammingError, ) +from ibis.conftest import IS_SPARK_REMOTE from 
ibis.legacy.udf.vectorized import analytic, reduction np = pytest.importorskip("numpy") @@ -368,6 +370,12 @@ def test_grouped_bounded_expanding_window( ], raises=com.OperationNotDefinedError, ), + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkPythonException, + reason="remote udfs not yet tested due to environment complexities", + ), ], ), ], @@ -550,6 +558,12 @@ def test_grouped_bounded_preceding_window( ], raises=com.OperationNotDefinedError, ), + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkPythonException, + reason="remote udfs not yet tested due to environment complexities", + ), ], ), ], @@ -708,6 +722,12 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): ], raises=com.OperationNotDefinedError, ), + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkPythonException, + reason="remote udfs not yet tested due to environment complexities", + ), ], ), param( @@ -736,6 +756,12 @@ def test_simple_ungrouped_window_with_scalar_order_by(alltypes): ], raises=com.OperationNotDefinedError, ), + pytest.mark.notyet( + ["pyspark"], + condition=IS_SPARK_REMOTE, + raises=PySparkPythonException, + reason="remote udfs not yet tested due to environment complexities", + ), ], ), # Analytic ops diff --git a/ibis/conftest.py b/ibis/conftest.py index 40d309c47d04..e0af794b5c03 100644 --- a/ibis/conftest.py +++ b/ibis/conftest.py @@ -17,6 +17,8 @@ WINDOWS = platform.system() == "Windows" ARM64 = platform.machine() in ("arm64", "aarch64") CI = os.environ.get("CI") is not None +SPARK_REMOTE = os.environ.get("SPARK_REMOTE") +IS_SPARK_REMOTE = bool(SPARK_REMOTE) @pytest.fixture(autouse=True) diff --git a/pyproject.toml b/pyproject.toml index ca83f20f971d..bf4eebd93637 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -295,6 +295,7 @@ filterwarnings = [ "error", # pyspark uses a deprecated pandas API "ignore:is_datetime64tz_dtype is deprecated and will be removed in a future version:DeprecationWarning", + "ignore:is_categorical_dtype is deprecated .+:DeprecationWarning", # pyspark and impala leave sockets open "ignore:Exception ignored in:", # pandas
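
For orientation, a small sketch (not part of the patch) of the marker pattern this PR applies throughout the backend test suite, using the `IS_SPARK_REMOTE` flag added to `ibis/conftest.py`. The test name and body here are hypothetical; the marker arguments mirror those used in the diff.

```python
# Illustrative only: guard a UDF-dependent test when running against Spark
# Connect, mirroring the markers added throughout this PR.
import pytest

from ibis.backends.tests.errors import PySparkPythonException
from ibis.conftest import IS_SPARK_REMOTE


@pytest.mark.notyet(
    ["pyspark"],
    condition=IS_SPARK_REMOTE,
    raises=PySparkPythonException,
    reason="remote udfs not yet tested due to environment complexities",
)
def test_some_python_udf(backend, alltypes):  # hypothetical test
    ...
```
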