@@ -112,11 +112,13 @@ for t in $*; do
112
112
BAZEL_TARGET=" ${BAZEL_TARGET} $t "
113
113
done
114
114
115
+ TEST_TAG_FILTER_ARRAY=()
116
+ TEST_TAG_FILTER_ARRAY+=(' -multiaccelerator' )
117
+
115
118
COMMON_FLAGS=$( cat << EOF
116
119
--@local_config_cuda//:enable_cuda
117
120
--cache_test_results=${CACHE_TEST_RESULTS}
118
121
--test_timeout=600
119
- --test_tag_filters=-multiaccelerator
120
122
--test_env=JAX_SKIP_SLOW_TESTS=1
121
123
--test_env=JAX_ACCELERATOR_COUNT=${NGPUS}
122
124
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
@@ -138,7 +140,11 @@ case "${BATTERY}" in
138
140
JOBS_PER_GPU=8
139
141
JOBS=$(( NGPUS * JOBS_PER_GPU))
140
142
EXTRA_FLAGS=" --local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
141
- BAZEL_TARGET=" ${BAZEL_TARGET} //tests:gpu_tests"
143
+ # collect from all tests subdirectories recursively,
144
+ # use jax_test_gpu tag generated by jax_multiplatform_test rule:
145
+ # https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
146
+ TEST_TAG_FILTER_ARRAY+=(' jax_test_gpu' )
147
+ BAZEL_TARGET=" ${BAZEL_TARGET} //tests/..."
142
148
;;
143
149
backend-independent)
144
150
JOBS_PER_GPU=4
@@ -157,6 +163,8 @@ case "${BATTERY}" in
157
163
;;
158
164
esac
159
165
166
+ TEST_TAG_FILTERS=$( IFS=, ; echo " --test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]} " )
167
+
160
168
print_var NCPUS
161
169
print_var NGPUS
162
170
print_var BATTERY
@@ -165,6 +173,7 @@ print_var JOBS_PER_GPU
165
173
print_var JOBS
166
174
print_var BUILD_JAXLIB
167
175
print_var BAZEL_TARGET
176
+ print_var TEST_TAG_FILTERS
168
177
print_var COMMON_FLAGS
169
178
print_var EXTRA_FLAGS
170
179
@@ -182,4 +191,4 @@ pip install matplotlib
182
191
183
192
cd ` jax_source_dir`
184
193
python build/build.py --configure_only
185
- bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS}
194
+ bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${ COMMON_FLAGS} ${EXTRA_FLAGS}
0 commit comments