Skip to content

Commit cfc3f74

Browse files
authored
Collect recursively and filter GPU tests using jax_test_gpu tag (#1091)
1 parent 0832fac commit cfc3f74

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

.github/container/test-jax.sh

+12-3
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,13 @@ for t in $*; do
112112
BAZEL_TARGET="${BAZEL_TARGET} $t"
113113
done
114114

115+
TEST_TAG_FILTER_ARRAY=()
116+
TEST_TAG_FILTER_ARRAY+=('-multiaccelerator')
117+
115118
COMMON_FLAGS=$(cat << EOF
116119
--@local_config_cuda//:enable_cuda
117120
--cache_test_results=${CACHE_TEST_RESULTS}
118121
--test_timeout=600
119-
--test_tag_filters=-multiaccelerator
120122
--test_env=JAX_SKIP_SLOW_TESTS=1
121123
--test_env=JAX_ACCELERATOR_COUNT=${NGPUS}
122124
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
@@ -138,7 +140,11 @@ case "${BATTERY}" in
138140
JOBS_PER_GPU=8
139141
JOBS=$((NGPUS * JOBS_PER_GPU))
140142
EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
141-
BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests"
143+
# collect from all tests subdirectories recursively,
144+
# use jax_test_gpu tag generated by jax_multiplatform_test rule:
145+
# https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
146+
TEST_TAG_FILTER_ARRAY+=('jax_test_gpu')
147+
BAZEL_TARGET="${BAZEL_TARGET} //tests/..."
142148
;;
143149
backend-independent)
144150
JOBS_PER_GPU=4
@@ -157,6 +163,8 @@ case "${BATTERY}" in
157163
;;
158164
esac
159165

166+
TEST_TAG_FILTERS=$(IFS=, ; echo "--test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]}")
167+
160168
print_var NCPUS
161169
print_var NGPUS
162170
print_var BATTERY
@@ -165,6 +173,7 @@ print_var JOBS_PER_GPU
165173
print_var JOBS
166174
print_var BUILD_JAXLIB
167175
print_var BAZEL_TARGET
176+
print_var TEST_TAG_FILTERS
168177
print_var COMMON_FLAGS
169178
print_var EXTRA_FLAGS
170179

@@ -182,4 +191,4 @@ pip install matplotlib
182191

183192
cd `jax_source_dir`
184193
python build/build.py --configure_only
185-
bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS}
194+
bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${COMMON_FLAGS} ${EXTRA_FLAGS}

0 commit comments

Comments
 (0)