From 59687b3513d26df4426eb6018f5abf771c2efe46 Mon Sep 17 00:00:00 2001 From: crusaderky <crusaderky@gmail.com> Date: Fri, 24 Jan 2025 19:29:58 +0000 Subject: [PATCH 1/3] DNM array-api-compat git tip --- pixi.lock | 86 ++++++++++++++++++++++++-------------------------- pyproject.toml | 5 +-- 2 files changed, 45 insertions(+), 46 deletions(-) diff --git a/pixi.lock b/pixi.lock index 00c8feef..f74068dc 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9,7 +9,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.12.14-hbcca054_0.conda - conda: https://prefix.dev/conda-forge/linux-64/ld_impl_linux-64-2.43-h712a8e2_2.conda @@ -30,9 +29,9 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/readline-8.2-h8228510_1.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.12.14-hf0a4a13_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.6.4-h286801f_0.conda @@ -46,9 +45,9 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/readline-8.2-h92ec313_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.12.14-h56e8100_0.conda - conda: https://prefix.dev/conda-forge/win-64/libexpat-2.6.4-he0c23c2_0.conda @@ -64,6 +63,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h5fd82a7_24.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-h6356254_24.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hfef2bbc_24.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . dev: channels: @@ -75,7 +75,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.8-py312h7900ff3_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -324,10 +323,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py312hef9b889_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.8-py312h81bd7bf_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -566,11 +565,11 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py312h15fbf35_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.8-py312h2e8e312_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -784,6 +783,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py312h7606c53_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . dev-cuda: channels: @@ -795,7 +795,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.8-py312h7900ff3_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -1111,10 +1110,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py312hef9b889_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.8-py312h81bd7bf_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -1353,11 +1352,11 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py312h15fbf35_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.8-py312h2e8e312_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -1589,6 +1588,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py312h7606c53_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . docs: channels: @@ -1600,7 +1600,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/linux-64/brotli-python-1.1.0-py312h2ec8cdc_2.conda @@ -1683,10 +1682,10 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py312hef9b889_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/brotli-python-1.1.0-py312hde4cb15_2.conda @@ -1761,10 +1760,10 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py312h15fbf35_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/win-64/brotli-python-1.1.0-py312h275cf98_2.conda @@ -1841,6 +1840,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py312h7606c53_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . lint: channels: @@ -1852,7 +1852,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.8-py312h7900ff3_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda @@ -1962,10 +1961,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py312hef9b889_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.8-py312h81bd7bf_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda @@ -2066,10 +2065,10 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py312h15fbf35_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.8-py312h2e8e312_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda @@ -2172,6 +2171,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py312h7606c53_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . tests: channels: @@ -2182,7 +2182,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.12.14-hbcca054_0.conda @@ -2227,9 +2226,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.12.14-hf0a4a13_0.conda @@ -2264,9 +2263,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.12.14-h56e8100_0.conda @@ -2305,6 +2304,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h5fd82a7_24.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-h6356254_24.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hfef2bbc_24.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . tests-backends: channels: @@ -2315,7 +2315,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/aws-c-auth-0.8.1-h205f482_0.conda - conda: https://prefix.dev/conda-forge/linux-64/aws-c-cal-0.8.1-h1a47875_3.conda @@ -2498,9 +2497,9 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py310ha39cb0e_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/aws-c-auth-0.8.1-hfc2798a_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/aws-c-cal-0.8.1-hc8a0bd2_3.conda @@ -2674,10 +2673,10 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py310h2665a74_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/aws-c-auth-0.8.1-hd11252f_0.conda - conda: https://prefix.dev/conda-forge/win-64/aws-c-cal-0.8.1-h099ea23_3.conda @@ -2828,6 +2827,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py310he5e10e1_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . tests-cuda: channels: @@ -2838,7 +2838,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/aws-c-auth-0.8.1-h205f482_0.conda @@ -3088,9 +3087,9 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py310ha39cb0e_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/aws-c-auth-0.8.1-hfc2798a_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/aws-c-cal-0.8.1-hc8a0bd2_3.conda @@ -3264,10 +3263,10 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py310h2665a74_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/aws-c-auth-0.8.1-hd11252f_0.conda - conda: https://prefix.dev/conda-forge/win-64/aws-c-cal-0.8.1-h099ea23_3.conda @@ -3436,6 +3435,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py310he5e10e1_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . tests-py310: channels: @@ -3446,7 +3446,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.12.14-hbcca054_0.conda @@ -3486,9 +3485,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.12.14-hf0a4a13_0.conda @@ -3522,9 +3521,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.12.14-h56e8100_0.conda @@ -3562,6 +3561,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h5fd82a7_24.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-h6356254_24.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hfef2bbc_24.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . tests-py313: channels: @@ -3572,7 +3572,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.12.14-hbcca054_0.conda @@ -3612,9 +3611,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.12.14-hf0a4a13_0.conda @@ -3650,9 +3649,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.12.14-h56e8100_0.conda @@ -3692,6 +3691,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h5fd82a7_24.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-h6356254_24.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hfef2bbc_24.conda + - pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 - pypi: . packages: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 @@ -3753,23 +3753,21 @@ packages: - pkg:pypi/alabaster?source=hash-mapping size: 18684 timestamp: 1733750512696 -- conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - sha256: c98308dcf035a413a635317c69b48143cdf4c5895853457062780395e5ea4633 - md5: e399bc184553ca13cb068d272a995f48 - depends: - - python >=3.9 - license: MIT - license_family: MIT - purls: - - pkg:pypi/array-api-compat?source=hash-mapping - size: 38442 - timestamp: 1735201429468 +- pypi: git+https://github.com/data-apis/array-api-compat#fa558f21884dbccc0102f6c9a2a34d0b149100b5 + name: array-api-compat + version: 1.10.1.dev0 + requires_dist: + - cupy ; extra == 'cupy' + - dask ; extra == 'dask' + - jax ; extra == 'jax' + - numpy ; extra == 'numpy' + - pytorch ; extra == 'pytorch' + - sparse>=0.15.1 ; extra == 'sparse' + requires_python: '>=3.9' - pypi: . name: array-api-extra version: 0.6.1.dev0 - sha256: bb6cd89a7f100a73d3f853de571b2f4fff0e70de8df0d113f2f5c1559744e6b6 - requires_dist: - - array-api-compat>=1.10.0,<2 + sha256: c806efb7ff2f1885f9194e51fbc67f00cc07a32927570d7e630459a6930b395e requires_python: '>=3.10' editable: true - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda diff --git a/pyproject.toml b/pyproject.toml index d15aba84..33bbb3cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Typing :: Typed", ] dynamic = ["version"] -dependencies = ["array-api-compat>=1.10.0,<2"] +# dependencies = ["array-api-compat>=1.11.0,<2"] # DNM [project.urls] Homepage = "https://github.com/data-apis/array-api-extra" @@ -48,10 +48,11 @@ platforms = ["linux-64", "osx-arm64", "win-64"] [tool.pixi.dependencies] python = ">=3.10,<3.14" -array-api-compat = ">=1.10.0,<2" +# array-api-compat = ">=1.11.0,<2" # DNM [tool.pixi.pypi-dependencies] array-api-extra = { path = ".", editable = true } +array-api-compat = { git = "https://github.com/data-apis/array-api-compat" } # DNM [tool.pixi.feature.lint.dependencies] typing-extensions = "*" From 93cc03547dccc54df3b08bd3b4303a331b5373e2 Mon Sep 17 00:00:00 2001 From: crusaderky <crusaderky@gmail.com> Date: Fri, 24 Jan 2025 20:25:55 +0000 Subject: [PATCH 2/3] WIP ENH: setdiff1d for Dask and jax.jit --- src/array_api_extra/_lib/_funcs.py | 100 ++++++++++++++++++-- src/array_api_extra/_lib/_utils/_helpers.py | 61 +----------- tests/test_funcs.py | 6 +- tests/test_utils.py | 41 +------- 4 files changed, 96 insertions(+), 112 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index f7eb8c88..b6620de6 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -11,7 +11,12 @@ from ._at import at from ._utils import _compat, _helpers -from ._utils._compat import array_namespace, is_jax_array +from ._utils._compat import ( + array_namespace, + is_dask_namespace, + is_jax_array, + is_jax_namespace, +) from ._utils._helpers import asarrays from ._utils._typing import Array @@ -547,6 +552,7 @@ def setdiff1d( /, *, assume_unique: bool = False, + fill_value: object | None = None, xp: ModuleType | None = None, ) -> Array: """ @@ -563,6 +569,11 @@ def setdiff1d( assume_unique : bool If ``True``, the input arrays are both assumed to be unique, which can speed up the calculation. Default is ``False``. + fill_value : object, optional + Pad the output array with this value. + + This is exclusively used for JAX arrays when running inside ``jax.jit``, + where all array shapes need to be known in advance. xp : array_namespace, optional The standard-compatible namespace for `x1` and `x2`. Default: infer. @@ -587,13 +598,86 @@ def setdiff1d( xp = array_namespace(x1, x2) x1, x2 = asarrays(x1, x2, xp=xp) - if assume_unique: - x1 = xp.reshape(x1, (-1,)) - x2 = xp.reshape(x2, (-1,)) - else: - x1 = xp.unique_values(x1) - x2 = xp.unique_values(x2) - return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] + x1 = xp.reshape(x1, (-1,)) + x2 = xp.reshape(x2, (-1,)) + if x1.shape == (0,) or x2.shape == (0,): + return x1 + + def _x1_not_in_x2(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01 + """For each element of x1, return True if it is not also in x2.""" + # Even when assume_unique=True, there is no provision for x to be sorted + x2 = xp.sort(x2) + idx = xp.searchsorted(x2, x1) + + # FIXME at() is faster but needs JAX jit support for bool mask + # idx = at(idx, idx == x2.shape[0]).set(0) + idx = xp.where(idx == x2.shape[0], xp.zeros_like(idx), idx) + + return xp.take(x2, idx, axis=0) != x1 + + def _generic_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01 + """Generic implementation (including eager JAX).""" + # Note: there is no provision in the Array API for xp.unique_values to sort + if not assume_unique: + # Call unique_values early to speed up the algorithm + x1 = xp.unique_values(x1) + x2 = xp.unique_values(x2) + mask = _x1_not_in_x2(x1, x2) + x1 = x1[mask] + return x1 if assume_unique else xp.sort(x1) + + def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01 + """ + Dask implementation. + + Works around unique_values returning unknown shapes. + """ + # Do not call unique_values yet, as it would make array shapes unknown + mask = _x1_not_in_x2(x1, x2) + x1 = x1[mask] + # Note: da.unique_values sorts + return x1 if assume_unique else xp.unique_values(x1) + + def _jax_jit_impl( + x1: Array, x2: Array, fill_value: object | None + ) -> Array: # numpydoc ignore=PR01,RT01 + """ + JAX implementation inside jax.jit. + + Works around unique_values requiring a size= parameter + and not being able to filter by a boolean mask. + Returns array the same size as x1, padded with fill_value. + """ + # unique_values inside jax.jit is not supported unless it's got a fixed size + mask = _x1_not_in_x2(x1, x2) + + if fill_value is None: + fill_value = xp.zeros((), dtype=x1.dtype) + else: + fill_value = xp.asarray(fill_value, dtype=x1.dtype) + if cast(Array, fill_value).ndim != 0: + msg = "`fill_value` must be a scalar." + raise ValueError(msg) + + x1 = xp.where(mask, x1, fill_value) + # Note: jnp.unique_values sorts + return xp.unique_values(x1, size=x1.size, fill_value=fill_value) + + if is_dask_namespace(xp): + return _dask_impl(x1, x2) + + if is_jax_namespace(xp): + import jax + + try: + return _generic_impl(x1, x2) # eager mode + except ( + jax.errors.ConcretizationTypeError, + jax.errors.NonConcreteBooleanIndexError, + ): + return _jax_jit_impl(x1, x2, fill_value) # inside jax.jit + + return _generic_impl(x1, x2) def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index b32a1081..62ff0116 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -10,66 +10,7 @@ from ._compat import is_array_api_obj, is_numpy_array from ._typing import Array -__all__ = ["in1d", "mean"] - - -def in1d( - x1: Array, - x2: Array, - /, - *, - assume_unique: bool = False, - invert: bool = False, - xp: ModuleType | None = None, -) -> Array: # numpydoc ignore=PR01,RT01 - """ - Check whether each element of an array is also present in a second array. - - Returns a boolean array the same length as `x1` that is True - where an element of `x1` is in `x2` and False otherwise. - - This function has been adapted using the original implementation - present in numpy: - https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758 - """ - if xp is None: - xp = _compat.array_namespace(x1, x2) - - # This code is run to make the code significantly faster - if x2.shape[0] < 10 * x1.shape[0] ** 0.145: - if invert: - mask = xp.ones(x1.shape[0], dtype=xp.bool, device=_compat.device(x1)) - for a in x2: - mask &= x1 != a - else: - mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=_compat.device(x1)) - for a in x2: - mask |= x1 == a - return mask - - rev_idx = xp.empty(0) # placeholder - if not assume_unique: - x1, rev_idx = xp.unique_inverse(x1) - x2 = xp.unique_values(x2) - - ar = xp.concat((x1, x2)) - device_ = _compat.device(ar) - # We need this to be a stable sort. - order = xp.argsort(ar, stable=True) - reverse_order = xp.argsort(order, stable=True) - sar = xp.take(ar, order, axis=0) - ar_size = _compat.size(sar) - assert ar_size is not None, "xp.unique*() on lazy backends raises" - if ar_size >= 1: - bool_ar = sar[1:] != sar[:-1] if invert else sar[1:] == sar[:-1] - else: - bool_ar = xp.asarray([False]) if invert else xp.asarray([True]) - flag = xp.concat((bool_ar, xp.asarray([invert], device=device_))) - ret = xp.take(flag, reverse_order, axis=0) - - if assume_unique: - return ret[: x1.shape[0]] - return xp.take(ret, rev_idx, axis=0) +__all__ = ["mean"] def mean( diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 2c265b23..de8a6fd0 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -35,8 +35,7 @@ lazy_xp_function(kron, static_argnames="xp") lazy_xp_function(nunique, static_argnames="xp") lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp")) -# FIXME calls in1d which calls xp.unique_values without size -lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp")) +lazy_xp_function(setdiff1d, static_argnames=("assume_unique", "xp")) # FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238 lazy_xp_function(sinc, jax_jit=False, static_argnames="xp") @@ -576,8 +575,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType): assert padded.shape == (4, 4) -@pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort") -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray") +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no sort") class TestSetDiff1D: @pytest.mark.skip_xp_backend( Backend.TORCH, reason="index_select not implemented for uint32" diff --git a/tests/test_utils.py b/tests/test_utils.py index d9f50362..ab5f3589 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,49 +4,10 @@ import pytest from array_api_extra._lib import Backend -from array_api_extra._lib._testing import xp_assert_equal -from array_api_extra._lib._utils._compat import device as get_device -from array_api_extra._lib._utils._helpers import asarrays, in1d -from array_api_extra._lib._utils._typing import Device -from array_api_extra.testing import lazy_xp_function +from array_api_extra._lib._utils._helpers import asarrays # mypy: disable-error-code=no-untyped-usage -# FIXME calls xp.unique_values without size -lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp")) - - -class TestIn1D: - @pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort") - @pytest.mark.skip_xp_backend( - Backend.SPARSE, reason="no unique_inverse, no device kwarg in asarray" - ) - # cover both code paths - @pytest.mark.parametrize("n", [9, 15]) - def test_no_invert_assume_unique(self, xp: ModuleType, n: int): - x1 = xp.asarray([3, 8, 20]) - x2 = xp.arange(n) - expected = xp.asarray([True, True, False]) - actual = in1d(x1, x2) - xp_assert_equal(actual, expected) - - @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray") - def test_device(self, xp: ModuleType, device: Device): - x1 = xp.asarray([3, 8, 20], device=device) - x2 = xp.asarray([2, 3, 4], device=device) - assert get_device(in1d(x1, x2)) == device - - @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="explicit xp") - @pytest.mark.skip_xp_backend( - Backend.SPARSE, reason="no arange, no device kwarg in asarray" - ) - def test_xp(self, xp: ModuleType): - x1 = xp.asarray([1, 6]) - x2 = xp.arange(5) - expected = xp.asarray([True, False]) - actual = in1d(x1, x2, xp=xp) - xp_assert_equal(actual, expected) - @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype") @pytest.mark.parametrize( From 2a1554f3d5d0bea8b4a805539afa64e2bc116b02 Mon Sep 17 00:00:00 2001 From: crusaderky <crusaderky@gmail.com> Date: Fri, 24 Jan 2025 20:38:15 +0000 Subject: [PATCH 3/3] Design 2->4 --- src/array_api_extra/_lib/_funcs.py | 32 ++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index b6620de6..ee3064b0 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -552,6 +552,7 @@ def setdiff1d( /, *, assume_unique: bool = False, + size: int | None = None, fill_value: object | None = None, xp: ModuleType | None = None, ) -> Array: @@ -569,11 +570,16 @@ def setdiff1d( assume_unique : bool If ``True``, the input arrays are both assumed to be unique, which can speed up the calculation. Default is ``False``. - fill_value : object, optional - Pad the output array with this value. + size : int, optional + The size of the output array. This is exclusively used inside the JAX JIT, and + only for as long as JAX does not support arrays of unknown size inside it. In + all other cases, it is disregarded. + Returned elements will be clipped if they are more than size, and padded with + `fill_value` if they are less. Default: raise if inside ``jax.jit``. - This is exclusively used for JAX arrays when running inside ``jax.jit``, - where all array shapes need to be known in advance. + fill_value : object, optional + Pad the output array with this value. This is exclusively used for JAX arrays + when running inside ``jax.jit``. Default: 0. xp : array_namespace, optional The standard-compatible namespace for `x1` and `x2`. Default: infer. @@ -639,7 +645,7 @@ def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01 return x1 if assume_unique else xp.unique_values(x1) def _jax_jit_impl( - x1: Array, x2: Array, fill_value: object | None + x1: Array, x2: Array, size: int | None, fill_value: object | None ) -> Array: # numpydoc ignore=PR01,RT01 """ JAX implementation inside jax.jit. @@ -648,9 +654,9 @@ def _jax_jit_impl( and not being able to filter by a boolean mask. Returns array the same size as x1, padded with fill_value. """ - # unique_values inside jax.jit is not supported unless it's got a fixed size - mask = _x1_not_in_x2(x1, x2) - + if size is None: + msg = "`size` is mandatory when running inside `jax.jit`." + raise ValueError(msg) if fill_value is None: fill_value = xp.zeros((), dtype=x1.dtype) else: @@ -659,9 +665,13 @@ def _jax_jit_impl( msg = "`fill_value` must be a scalar." raise ValueError(msg) + # unique_values inside jax.jit is not supported unless it's got a fixed size + mask = _x1_not_in_x2(x1, x2) x1 = xp.where(mask, x1, fill_value) - # Note: jnp.unique_values sorts - return xp.unique_values(x1, size=x1.size, fill_value=fill_value) + # Move fill_value to the right + x1 = xp.take(x1, xp.argsort(~mask, stable=True)) + x1 = x1[:size] + x1 = xp.unique_values(x1, size=size, fill_value=fill_value) if is_dask_namespace(xp): return _dask_impl(x1, x2) @@ -675,7 +685,7 @@ def _jax_jit_impl( jax.errors.ConcretizationTypeError, jax.errors.NonConcreteBooleanIndexError, ): - return _jax_jit_impl(x1, x2, fill_value) # inside jax.jit + return _jax_jit_impl(x1, x2, size, fill_value) # inside jax.jit return _generic_impl(x1, x2)