Skip to content

Commit 62dccd1

Browse files
committed
Refactor to make more testable, add tests/CI.
1 parent 32e2e8b commit 62dccd1

File tree

10 files changed

+1288
-688
lines changed

10 files changed

+1288
-688
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .main import main
2+
3+
__all__ = ["main"]
+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import argparse
2+
import datetime
3+
import getpass
4+
import os
5+
import pathlib
6+
import tempfile
7+
8+
9+
def parse_args():
10+
parser = argparse.ArgumentParser(
11+
description="""
12+
Triage failures in JAX/XLA-related tests. The expectation is that the given
13+
test command is failing in recent versions, but that it passed in the past. The
14+
script first triages the regression with a search of the nightly containers,
15+
and then refines the search to a particular commit of JAX or XLA.""",
16+
)
17+
18+
container_search_args = parser.add_argument_group(
19+
title="Container-level search",
20+
description="""
21+
First, it is verified that the test command fails on the given end date, unless
22+
both --end-date and --skip-precondition-checks were passed. Then, the program
23+
searches backwards to find a container when the given test did pass. The
24+
--start-date option can be used to speed up this search, if you already know a
25+
date on which the test was passing. The earliest failure is located to within
26+
--threshold-days days.""",
27+
)
28+
commit_search_args = parser.add_argument_group(
29+
title="Commit-level search",
30+
description="""
31+
Second, the failure is localised to a commit of JAX or XLA by re-building and
32+
re-testing inside the earliest container that demonstrates the failure. At each
33+
point, the oldest JAX commit that is newer than XLA is used.""",
34+
)
35+
parser.add_argument(
36+
"--container",
37+
help="""
38+
Container to use. Example: jax, pax, triton. Used to construct the URLs of
39+
nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""",
40+
required=True,
41+
)
42+
parser.add_argument(
43+
"--output-prefix",
44+
default=datetime.datetime.now().strftime("triage-%Y-%m-%d-%H-%M-%S"),
45+
help="""
46+
Prefix for output log and JSON files. Default: triage-YYYY-MM-DD-HH-MM-SS.
47+
An INFO-and-above log is written as PREFIX.log, a DEBUG-and-above log is
48+
written as PREFIX-debug.log, and a JSON summary is written as
49+
PREFIX-summary.json""",
50+
type=pathlib.Path,
51+
)
52+
parser.add_argument(
53+
"--skip-precondition-checks",
54+
action="store_true",
55+
help="""
56+
Skip checks that should pass by construction. This saves time, but may yield
57+
incorrect results if you are not careful. Specifically this means that the test
58+
is assumed to fail on --end-date (if specified), pass on --start-date (if
59+
specified), and fail after recompilation in the earliest-known-failure
60+
container. Careful use of this option, along with --start-date, --end-date and
61+
--threshold-days, allows the container-level search to be skipped.""",
62+
)
63+
parser.add_argument(
64+
"test_command",
65+
nargs="+",
66+
help="""
67+
Command to execute inside the container. This should be as targeted as
68+
possible.""",
69+
)
70+
container_search_args.add_argument(
71+
"--end-date",
72+
help="""
73+
Initial estimate of the earliest nightly container date where the test case
74+
fails. Defaults to the newest available nightly container date. If this and
75+
--skip-precondition-checks are both set then it will not be verified that the
76+
test case fails on this date.""",
77+
type=lambda s: datetime.date.fromisoformat(s),
78+
)
79+
container_search_args.add_argument(
80+
"--start-date",
81+
help="""
82+
Initial estimate of the latest nightly container date where the test case
83+
passes. Defaults to the day before --end-date, but setting this to a date
84+
further in the past may lead to faster convergence of the initial backwards
85+
search for a date when the test case passed. If this and
86+
--skip-precondition-checks are both set then the test case *must* pass on
87+
this date, which will *not* be verified.""",
88+
type=lambda s: datetime.date.fromisoformat(s),
89+
)
90+
container_search_args.add_argument(
91+
"--threshold-days",
92+
default=1,
93+
help="""
94+
Convergence threshold. Ideally, the container-level search will continue while
95+
the number of days separating the last known success and first known failure is
96+
smaller than this value. The minimum, and default, value is 1. Note that in
97+
case of nightly build failures the search may finish without reaching this
98+
threshold.""",
99+
type=int,
100+
)
101+
commit_search_args.add_argument(
102+
"--bazel-cache",
103+
default=os.path.join(
104+
tempfile.gettempdir(), f"{getpass.getuser()}-bazel-triage-cache"
105+
),
106+
help="""
107+
Bazel cache to use when [re-]building JAX/XLA during the fine search. This can
108+
be a remote cache server or a local directory. Using a persistent cache can
109+
significantly speed up the commit-level search. By default, uses a temporary
110+
directory including the name of the current user.""",
111+
)
112+
return parser.parse_args()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import logging
2+
import pathlib
3+
import subprocess
4+
import typing
5+
6+
7+
class DockerContainer:
8+
def __init__(
9+
self,
10+
url: str,
11+
*,
12+
logger: logging.Logger,
13+
mounts: typing.List[typing.Tuple[pathlib.Path, pathlib.Path]],
14+
):
15+
self._logger = logger
16+
self._mount_args = []
17+
for src, dst in mounts:
18+
self._mount_args += ["-v", f"{src}:{dst}"]
19+
self._url = url
20+
21+
def __enter__(self):
22+
result = subprocess.run(
23+
[
24+
"docker",
25+
"run",
26+
"--detach",
27+
# Otherwise bazel shutdown hangs.
28+
"--init",
29+
"--gpus=all",
30+
"--shm-size=1g",
31+
]
32+
+ self._mount_args
33+
+ [
34+
self._url,
35+
"sleep",
36+
"infinity",
37+
],
38+
check=True,
39+
encoding="utf-8",
40+
stderr=subprocess.PIPE,
41+
stdout=subprocess.PIPE,
42+
)
43+
self._id = result.stdout.strip()
44+
return self
45+
46+
def __exit__(self, *exc_info):
47+
subprocess.run(
48+
["docker", "stop", self._id],
49+
check=True,
50+
stderr=subprocess.PIPE,
51+
stdout=subprocess.PIPE,
52+
)
53+
54+
def exec(
55+
self, command: typing.List[str], workdir=None
56+
) -> subprocess.CompletedProcess:
57+
"""
58+
Run a command inside a persistent container.
59+
"""
60+
workdir = [] if workdir is None else ["--workdir", workdir]
61+
return subprocess.run(
62+
["docker", "exec"] + workdir + [self._id] + command,
63+
encoding="utf-8",
64+
stderr=subprocess.PIPE,
65+
stdout=subprocess.PIPE,
66+
)
67+
68+
def check_exec(
69+
self, cmd: typing.List[str], **kwargs
70+
) -> subprocess.CompletedProcess:
71+
result = self.exec(cmd, **kwargs)
72+
if result.returncode != 0:
73+
self._logger.fatal(
74+
f"{' '.join(cmd)} exited with return code {result.returncode}"
75+
)
76+
self._logger.fatal(result.stdout)
77+
self._logger.fatal(result.stderr)
78+
result.check_returncode()
79+
return result

0 commit comments

Comments
 (0)