Skip to content

Commit d9f5c18

Browse files
authored
Merge branch 'main' into vkozlov/jetstream-4-maxtext
2 parents ab21dae + e8043a5 commit d9f5c18

27 files changed

+423
-176
lines changed

.github/container/Dockerfile.maxtext.amd64

-7
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
2020
echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
2121
EOF
2222

23-
###############################################################################
24-
## Apply patch
25-
###############################################################################
26-
27-
ADD maxtext-mha.patch /opt
28-
RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff
29-
3023
###############################################################################
3124
## Add test script to the path
3225
###############################################################################

.github/container/Dockerfile.maxtext.arm64

-7
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,6 @@ git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
5858
echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in
5959
EOF
6060

61-
###############################################################################
62-
## Apply patch
63-
###############################################################################
64-
65-
ADD maxtext-mha.patch /opt
66-
RUN cd "${SRC_PATH_MAXTEXT}" && patch -p1 < /opt/maxtext-mha.patch && git diff
67-
6861
###############################################################################
6962
## Add test script to the path
7063
###############################################################################

.github/container/Dockerfile.pax.amd64

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ for src in ${SRC_PATH_PAXML} ${SRC_PATH_PRAXIS}; do
2929
pushd ${src}
3030
sed -i "s| @ git+https://github.com/google/flax||g" requirements.in
3131
sed -i "s| @ git+https://github.com/google/jax||g" requirements.in
32+
## we pin etils because newer etils versions are not compatible with the
33+
## version of TFDS required by Pax
34+
sed -i "s/etils/etils==1.7.0/g" requirements.in
3235
if git diff --quiet; then
3336
echo "URL specs no longer present in select dependencies for ${src}"
3437
exit 1

.github/container/install-nsight.sh

+6-5
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ apt-get clean
1717

1818
rm -rf /var/lib/apt/lists/*
1919

20-
NSYS202451=/opt/nvidia/nsight-systems-cli/2024.5.1
21-
if [[ -d "${NSYS202451}" ]]; then
22-
# * can match at least sbsa-armv8 and x86
23-
(cd ${NSYS202451}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
24-
fi
20+
for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do
21+
if [[ -d "${NSYS}" ]]; then
22+
# * can match at least sbsa-armv8 and x86
23+
(cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
24+
fi
25+
done
2526

2627
# Install extra dependencies needed for `nsys recipe ...` commands. These are
2728
# used by the nsys-jax wrapper script.

.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py

+46-8
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pathlib
77
from typing import Any
88

9-
from .protobuf import HloProto, xla_module_metadata
9+
from .protobuf import HloProto, _host_memory_space, xla_module_metadata
1010
from .utils import make_child_mask, ProfilerData
1111

1212
pd.options.mode.copy_on_write = True
@@ -38,6 +38,11 @@ def align_profiler_data_timestamps(
3838
# Determine which collective size will be used for the alignment
3939
num_profiled_devices = len(comm_df.index.get_level_values("Device").unique())
4040
max_collective_size = comm_df["CollectiveSize"].max()
41+
if max_collective_size == 1:
42+
print(
43+
f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
44+
)
45+
return frames, {}
4146
assert (
4247
num_profiled_devices == max_collective_size
4348
), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
@@ -193,13 +198,51 @@ def _get_message_size(
193198
"all-to-all",
194199
"collective-broadcast",
195200
"collective-permute-start",
201+
"dynamic-slice",
202+
"dynamic-update-slice",
196203
"reduce-scatter",
197204
}
198205
), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
206+
207+
def _byte_size(inst) -> int:
208+
size_bits = math.prod(
209+
inst.shape.dimensions,
210+
start=element_type_width(inst.shape.element_type),
211+
)
212+
size_bytes, rem = divmod(size_bits, 8)
213+
assert rem == 0
214+
return size_bytes
215+
199216
if comm_inst.opcode == "collective-permute-start":
200217
# See https://openxla.org/xla/operation_semantics#collectivepermute, which
201218
# generates pair-wise send+recv between devices
202219
collective_size = 2
220+
elif comm_inst.opcode in {"dynamic-slice", "dynamic-update-slice"}:
221+
# Label host-device transfers orchestrated by dynamic[-update]-slice as single
222+
# device collectives.
223+
collective_size = 1
224+
if comm_inst.opcode == "dynamic-update-slice":
225+
# For dynamic-update-slice the second operand is the one being copied
226+
_, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[1])
227+
transfer_size = _byte_size(src_inst.proto())
228+
else:
229+
# For dynamic-slice the return type size is the transfer size
230+
assert comm_inst.opcode == "dynamic-slice"
231+
_, src_inst = module_proto.find_instruction_by_id(comm_inst.operand_ids[0])
232+
transfer_size = _byte_size(comm_inst)
233+
dest_on_host = _host_memory_space(comm_inst)
234+
src_on_host = _host_memory_space(src_inst.proto())
235+
assert src_on_host != dest_on_host, (
236+
'dynamic[-update]-slice is only considered is only "communication" if it '
237+
"represents a host-device transfer"
238+
)
239+
return (
240+
transfer_size,
241+
"device-to-host" if dest_on_host else "host-to-device",
242+
1, # collective size
243+
1.0, # bw_correction
244+
1.0, # bus_correction
245+
)
203246
else:
204247
# replica_groups is something like {{0,1},{4,5},{2,3},{6,7}}, if there are 8
205248
# devices that are doing pair-wise collectives
@@ -220,17 +263,12 @@ def _get_message_size(
220263
total_msg_size = 0
221264
for operand_id in comm_inst.operand_ids:
222265
_, operand = module_proto.find_instruction_by_id(operand_id)
223-
msg_size_bits = math.prod(
224-
operand.proto().shape.dimensions,
225-
start=element_type_width(operand.proto().shape.element_type),
226-
)
266+
msg_size_bytes = _byte_size(operand.proto())
227267
if comm_inst.opcode == "reduce-scatter":
228268
# NCCL's convention is that the message size of a reduce-scatter is the size of output buffer:
229269
# https://github.com/NVIDIA/nccl/blob/ab2b89c4c339bd7f816fbc114a4b05d386b66290/src/collectives.cc#L122
230-
msg_size_bits, rem = divmod(msg_size_bits, collective_size)
270+
msg_size_bytes, rem = divmod(msg_size_bytes, collective_size)
231271
assert rem == 0
232-
msg_size_bytes, rem = divmod(msg_size_bits, 8)
233-
assert rem == 0
234272
total_msg_size += msg_size_bytes
235273

236274
collective = comm_inst.opcode.removesuffix("-start")

.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ def is_communication(row):
103103
return _calculate_overlap(thunk_df)
104104

105105

106+
compile_prefix = "XlaCompile:#module="
107+
108+
106109
def _load_nvtx_gpu_proj_trace_single(
107110
prefix: pathlib.Path,
108111
file: pathlib.Path,
@@ -305,10 +308,21 @@ def _load_nvtx_gpu_proj_trace_single(
305308
unique_pid_tid_pairs = module_df.loc[:, ("PID", "TID")].drop_duplicates()
306309
if len(unique_pid_tid_pairs) == 1:
307310
main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
311+
# If the profile only includes N>1 modules, we may still be able to identify the
312+
# main thread as the one responsible for XlaCompile ranges projected onto the GPU
313+
# timeline
314+
compile_ranges = df.loc[~all_thunks, "Name"].str.startswith(
315+
tsl_prefix + compile_prefix
316+
)
317+
compile_range_ids = compile_ranges[compile_ranges].index
318+
unique_pid_tid_pairs = df.loc[compile_range_ids, ("PID", "TID")].drop_duplicates()
319+
if len(unique_pid_tid_pairs) == 1:
320+
main_pid_tid_candidates.add(tuple(unique_pid_tid_pairs.iloc[0]))
308321
assert len(main_pid_tid_candidates) < 2
309322
if len(main_pid_tid_candidates) == 1:
310323
# Possibly not correct if len(device_by_pid_tid) > 1
311324
assert len(device_by_pid_tid) > 0
325+
# Associate the main thread with the 0th device in device_by_pid_tid
312326
main_thread_df = device_by_pid_tid.iloc[:1]
313327
main_thread_df.index = pd.MultiIndex.from_tuples(
314328
main_pid_tid_candidates, names=["PID", "TID"]
@@ -425,16 +439,13 @@ def _load_nvtx_gpu_proj_trace(
425439
return output
426440

427441

428-
compile_prefix = "TSL:XlaCompile:#module="
429-
430-
431442
def _splice_parallel_ranges(compile_df: pd.DataFrame) -> pd.DataFrame:
432443
# When parallel compilation is enabled, we end up with worker threads that
433444
# emit NVTX ranges but which are not accounted for in the RangeStack tree.
434445
# Splice these in under the relevant XlaCompile ranges in the RangeStack tree and
435446
# drop everything else.
436447
retain_mask = pd.Series(False, index=compile_df.index)
437-
compile_mask = compile_df["Name"].str.startswith(compile_prefix)
448+
compile_mask = compile_df["Name"].str.startswith("TSL:" + compile_prefix)
438449
for compile_range in compile_df[compile_mask].itertuples():
439450
# Identify the slice of `compile_df` that overlaps in time with this XlaCompile
440451
# range

.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py

+71-19
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from collections import defaultdict
21
import functools
32
import lzma
43
import pathlib
54
import typing
65

76

7+
def _host_memory_space(inst):
8+
return inst.shape.layout.memory_space == 5
9+
10+
811
class StackFrame(typing.NamedTuple):
912
column: int
1013
file: str
@@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto):
2528
# proto representing the actual collective, which will be different if the
2629
# async launch is handled by an async-start op
2730
# TODO: can any of copy-start, custom-call, recv, send represent communication?
31+
# This also aims to identify, and (for now) flag as communication, kernels that
32+
# implement device-to-host and host-to-device copies for memory offloading.
33+
# For example, a device-to-host offload might look like
34+
# computation {
35+
# ...
36+
# ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...)
37+
# }
38+
# async_computation {
39+
# ...
40+
# ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation
41+
# }
42+
# start = (...) async-start(...), calls=async_computation
43+
# where the :S(5) annotation shows that a buffer is in host memory.
44+
# A host-to-device load might look like
45+
# computation {
46+
# param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
47+
# ...
48+
# ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...)
49+
# }
50+
# async_computation {
51+
# param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
52+
# ...
53+
# ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation
54+
# }
55+
# start = (...) async-start(...), calls=async_computation
56+
# where the :S(5) memory space annotation is in a parameter instead of in the
57+
# return value.
58+
# For now, handling host-device kernels as single-device "collective"
59+
# communication should be sufficient.
2860
self._comm_proto = None
2961
comm_opcodes = {
3062
"all-gather",
@@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto):
3971
"all-reduce-start",
4072
"collective-permute-start",
4173
}
74+
75+
def _is_offloading_instruction(inst):
76+
host_dest = _host_memory_space(inst)
77+
78+
def _host_operand(i):
79+
_, op = wrapped_hlo_proto.find_instruction_by_id(inst.operand_ids[i])
80+
return _host_memory_space(op.proto())
81+
82+
if inst.opcode == "dynamic-slice" and host_dest != _host_operand(0):
83+
return True
84+
elif (
85+
inst.opcode == "dynamic-update-slice"
86+
and host_dest == _host_operand(0)
87+
and host_dest != _host_operand(1)
88+
):
89+
return True
90+
return False
91+
4292
if self._proto.opcode in comm_opcodes | comm_start_opcodes:
4393
self._comm_proto = self._proto
44-
elif self._proto.opcode == "async-start":
94+
elif self._proto.opcode in {"async-start", "fusion"}:
95+
# fusion example:
96+
# computation {
97+
# param_0 = f32[...]{...:S(5)} parameter(0)
98+
# ...
99+
# ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...)
100+
# }
101+
# inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation
45102
# This might be thinly wrapping an opcode in `comm_opcodes`
46-
other_opcodes = defaultdict(int)
47-
for called_id in self._proto.called_computation_ids:
48-
for called_inst in wrapped_hlo_proto.find_computation(
49-
called_id
50-
).instructions:
51-
if called_inst.opcode in comm_opcodes:
103+
def _visit_computation(computation_id):
104+
computation = wrapped_hlo_proto.find_computation(computation_id)
105+
for called_inst in computation.instructions:
106+
for called_id in called_inst.called_computation_ids:
107+
_visit_computation(called_id)
108+
if called_inst.opcode in comm_opcodes or _is_offloading_instruction(
109+
called_inst
110+
):
52111
assert (
53112
self._comm_proto is None
54113
), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
55114
self._comm_proto = called_inst
56-
else:
57-
other_opcodes[called_inst.opcode] += 1
58-
assert (
59-
other_opcodes.keys() == {"parameter"}
60-
), f"async-start op {self._proto.name} wrapped too many opcode types ({dict(other_opcodes)}) in addition to {self._comm_proto}"
115+
116+
for called_id in self._proto.called_computation_ids:
117+
_visit_computation(called_id)
61118

62119
def communication_proto(self):
63120
return self._comm_proto
@@ -68,12 +125,7 @@ def is_communication(self) -> bool:
68125
a little more complicated than you might hope, because async communications are
69126
not handled uniformly.
70127
"""
71-
if self._comm_proto is None:
72-
return False
73-
assert (
74-
self._comm_proto.channel_id != 0
75-
), f"Got channel_id={self._comm_proto.channel_id} for {self._comm_proto.name}"
76-
return True
128+
return self._comm_proto is not None
77129

78130
def proto(self):
79131
"""

.github/container/manifest.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,8 @@ orbax-checkpoint:
177177
tracking_ref: main
178178
latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f
179179
mode: pip-vcs
180+
pathwaysutils:
181+
url: https://github.com/google/pathways-utils.git
182+
tracking_ref: main
183+
latest_verified_commit: 359776d454940ffaa337c36d1df16308d44a95a9
184+
mode: pip-vcs

.github/container/maxtext-mha.patch

-15
This file was deleted.

.github/container/test-jax.sh

-3
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,6 @@ else
109109
fi
110110

111111
for t in $*; do
112-
if [[ "$t" != "//tests:"* ]]; then
113-
t="//tests:${t}"
114-
fi
115112
BAZEL_TARGET="${BAZEL_TARGET} $t"
116113
done
117114

.github/workflows/_ci.yaml

+3-2
Original file line numberDiff line numberDiff line change
@@ -528,13 +528,14 @@ jobs:
528528
STATISTICS_SCRIPT: |
529529
summary_line=$(tail -n1 test-te.log)
530530
errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}')
531-
passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "CollectReport" and .outcome == "passed") | .outcome' | wc -l)
532-
failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "CollectReport" and .outcome == "failed") | .outcome' | wc -l)
531+
passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l)
532+
failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l)
533533
total_tests=$((failed_tests + passed_tests))
534534
echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
535535
echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
536536
echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
537537
echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
538+
TIMEOUT_MINUTES: 120
538539
ARTIFACTS: |
539540
test-te.log
540541
pytest-report.jsonl

0 commit comments

Comments
 (0)