@@ -112,11 +112,13 @@ for t in $*; do
112
112
BAZEL_TARGET=" ${BAZEL_TARGET} $t "
113
113
done
114
114
115
+ TEST_TAG_FILTER_ARRAY=()
116
+ TEST_TAG_FILTER_ARRAY+=(' -multiaccelerator' )
117
+
115
118
COMMON_FLAGS=$( cat << EOF
116
119
--@local_config_cuda//:enable_cuda
117
120
--cache_test_results=${CACHE_TEST_RESULTS}
118
121
--test_timeout=600
119
- --test_tag_filters=-multiaccelerator
120
122
--test_env=JAX_SKIP_SLOW_TESTS=1
121
123
--test_env=JAX_ACCELERATOR_COUNT=${NGPUS}
122
124
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
@@ -141,7 +143,8 @@ case "${BATTERY}" in
141
143
# collect from all tests subdirectories recursively,
142
144
# use jax_test_gpu tag generated by jax_multiplatform_test rule:
143
145
# https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
144
- BAZEL_TARGET=" ${BAZEL_TARGET} //tests/... --test_tag_filters=jax_test_gpu"
146
+ TEST_TAG_FILTER_ARRAY+=(' jax_test_gpu' )
147
+ BAZEL_TARGET=" ${BAZEL_TARGET} //tests/...
145
148
;;
146
149
backend-independent)
147
150
JOBS_PER_GPU=4
@@ -160,6 +163,8 @@ case "${BATTERY}" in
160
163
;;
161
164
esac
162
165
166
+ TEST_TAG_FILTERS=$( IFS=, echo " --test_tag_filters=${TEST_TAG_FILTER_ARRAY[*]} " )
167
+
163
168
print_var NCPUS
164
169
print_var NGPUS
165
170
print_var BATTERY
@@ -168,6 +173,7 @@ print_var JOBS_PER_GPU
168
173
print_var JOBS
169
174
print_var BUILD_JAXLIB
170
175
print_var BAZEL_TARGET
176
+ print_var TEST_TAG_FILTERS
171
177
print_var COMMON_FLAGS
172
178
print_var EXTRA_FLAGS
173
179
@@ -185,4 +191,4 @@ pip install matplotlib
185
191
186
192
cd ` jax_source_dir`
187
193
python build/build.py --configure_only
188
- bazel test ${BAZEL_TARGET} ${COMMON_FLAGS} ${EXTRA_FLAGS}
194
+ bazel test ${BAZEL_TARGET} ${TEST_TAG_FILTERS} ${ COMMON_FLAGS} ${EXTRA_FLAGS}
0 commit comments