From cdf89c92d2090b6db8f8b05d144915979f669373 Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Wed, 9 Oct 2024 14:32:51 -0400 Subject: [PATCH 1/4] Collect recursively and filter GPU tests using `jax_test_gpu` tag --- .github/container/test-jax.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index 6afbaace1..880c3cabd 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -138,7 +138,10 @@ case "${BATTERY}" in JOBS_PER_GPU=8 JOBS=$((NGPUS * JOBS_PER_GPU)) EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow" - BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests" + # collect from all tests subdirectories recursively, + # use jax_test_gpu tag generated by jax_multiplatform_test rule: + # https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265 + BAZEL_TARGET="${BAZEL_TARGET} //tests/... --test_tag_filters=jax_test_gpu" ;; backend-independent) JOBS_PER_GPU=4 From a2666cae646f5f04d821fc39a8c6b4a5c35941b7 Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Thu, 10 Oct 2024 13:34:50 -0400 Subject: [PATCH 2/4] Populate `--test_tag_filters` through separate array variable --- .github/container/test-jax.sh | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index 880c3cabd..a9e8a3808 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -112,11 +112,13 @@ for t in $*; do BAZEL_TARGET="${BAZEL_TARGET} $t" done +TEST_TAG_FILTER_ARRAY=() +TEST_TAG_FILTER_ARRAY+=('-multiaccelerator') + COMMON_FLAGS=$(cat << EOF --@local_config_cuda//:enable_cuda --cache_test_results=${CACHE_TEST_RESULTS} --test_timeout=600 ---test_tag_filters=-multiaccelerator --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=JAX_ACCELERATOR_COUNT=${NGPUS} --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform @@ -141,7 +143,8 @@ case "${BATTERY}" in # collect from all tests subdirectories recursively, # use jax_test_gpu tag generated by jax_multiplatform_test rule: # https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265 - BAZEL_TARGET="${BAZEL_TARGET} //tests/... --test_tag_filters=jax_test_gpu" + TEST_TAG_FILTER_ARRAY+=('jax_test_gpu') + BAZEL_TARGET="${BAZEL_TARGET} //tests/... ;; backend-independent) JOBS_PER_GPU=4 @@ -160,6 +163,8 @@ case "${BATTERY}" in ;; esac +TEST_TAG_FILTERS=$(IFS=, echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}") + print_var NCPUS print_var NGPUS print_var BATTERY @@ -168,6 +173,7 @@ print_var JOBS_PER_GPU print_var JOBS print_var BUILD_JAXLIB print_var BAZEL_TARGET +print_var TEST_TAG_FILTERS print_var COMMON_FLAGS print_var EXTRA_FLAGS @@ -185,4 +191,4 @@ pip install matplotlib cd `jax_source_dir` python build/build.py --configure_only -bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS} +bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS} From ceff277d965a8ad0a0bc1c1071a2756daf951b2a Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Thu, 10 Oct 2024 18:48:35 -0400 Subject: [PATCH 3/4] Add missing double quote in test-jax.sh --- .github/container/test-jax.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index a9e8a3808..11bca9d3d 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -144,7 +144,7 @@ case "${BATTERY}" in # use jax_test_gpu tag generated by jax_multiplatform_test rule: # https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265 TEST_TAG_FILTER_ARRAY+=('jax_test_gpu') - BAZEL_TARGET="${BAZEL_TARGET} //tests/... + BAZEL_TARGET="${BAZEL_TARGET} //tests/..." ;; backend-independent) JOBS_PER_GPU=4 From 02e8499753ad1c411a02e9927044dd65f08b493a Mon Sep 17 00:00:00 2001 From: Andrey Portnoy Date: Mon, 14 Oct 2024 23:23:08 -0400 Subject: [PATCH 4/4] Add missing semicolon in test-jax.sh --- .github/container/test-jax.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index 11bca9d3d..d73ed6d7e 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -163,7 +163,7 @@ case "${BATTERY}" in ;; esac -TEST_TAG_FILTERS=$(IFS=, echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}") +TEST_TAG_FILTERS=$(IFS=, ; echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}") print_var NCPUS print_var NGPUS