Skip to content

Commit 0f9ce62

Browse files
authored
Collect recursively and filter GPU tests using jax_test_gpu tag
1 parent d1ff3c8 commit 0f9ce62

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

.github/container/test-jax.sh

+4-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,10 @@ case "${BATTERY}" in
138138
JOBS_PER_GPU=8
139139
JOBS=$((NGPUS * JOBS_PER_GPU))
140140
EXTRA_FLAGS="--local_test_jobs=${JOBS} --test_env=JAX_TESTS_PER_ACCELERATOR=${JOBS_PER_GPU} --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow"
141-
BAZEL_TARGET="${BAZEL_TARGET} //tests:gpu_tests"
141+
# collect from all tests subdirectories recursively,
142+
# use jax_test_gpu tag generated by jax_multiplatform_test rule:
143+
# https://github.com/jax-ml/jax/blob/d36afe4f7fe01fe5db16069d796600090db5a3ce/jaxlib/jax.bzl#L265
144+
BAZEL_TARGET="${BAZEL_TARGET} //tests/... --test_tag_filters=jax_test_gpu"
142145
;;
143146
backend-independent)
144147
JOBS_PER_GPU=4

0 commit comments

Comments
 (0)