Skip to content

Commit 465264b

Browse files
authored
Merge branch 'main' into sbosisio/support_axlearn
2 parents faf0b83 + 8edb63e commit 465264b

File tree

8 files changed

+1093
-36
lines changed

8 files changed

+1093
-36
lines changed

.github/container/Dockerfile.maxtext

+9-1
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,18 @@ for pattern in \
2222
"s|absl-py|absl-py>=2.1.0|g" \
2323
"s|protobuf==3.20.3|protobuf>=3.19.0|g" \
2424
"s|tensorflow-datasets|tensorflow-datasets>=4.8.0|g" \
25+
"s|sentencepiece==0.1.97|sentencepiece>=0.2|g" \
2526
; do
2627
sed -i "${pattern}" ${SRC_PATH_MAXTEXT}/requirements.txt;
2728
done
28-
echo -e "\ntensorflow-metadata>=1.15.0" >> ${SRC_PATH_MAXTEXT}/requirements.txt
29+
# add new line in case requirements.txt does not end with a new line
30+
echo >> ${SRC_PATH_MAXTEXT}/requirements.txt
31+
for requirement in \
32+
"tensorflow-metadata>=1.15.0" \
33+
"seqio@git+https://github.com/google/seqio.git" \
34+
; do
35+
echo "${requirement}" >> ${SRC_PATH_MAXTEXT}/requirements.txt
36+
done
2937
EOF
3038

3139
###############################################################################

.github/container/nsys_jax/nsys_jax/data_loaders.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,9 @@ def _load_nvtx_gpu_proj_trace_single(
230230
mod_id_names = df.loc[mod_ids, "Name"]
231231
assert mod_ids.shape == mod_id_names.shape
232232
# Get a mask in mod_id_names of entries where ModuleId in the original
233-
# Thunk is not referring to a Module. If it's not a module, it should
234-
# be a thunk.
233+
# Thunk is not referring to a Module yet. Intermediate levels of the
234+
# hierarchy can be other thunks (e.g. an individual graph node may
235+
# have a thunk representing the whole graph as a parent).
235236
mask = ~mod_id_names.str.startswith(module_prefix)
236237
assert (mask == mod_id_names.str.startswith(thunk_prefix)).all()
237238
assert mask.shape == mod_ids.shape

.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py

+2
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,8 @@ def gather_source_files(
574574
if src_file == "<string>":
575575
# This can appear due to python -c "...", for example.
576576
continue
577+
if src_file == "<frozen runpy>":
578+
continue
577579
assert osp.isabs(src_file), f"{src_file} is not absolute"
578580
output_queue.put(("sources" + src_file, src_file, COMPRESS_DEFLATE))
579581
print(f"{archive_name}: gathered source code in {time.time() - start:.2f}s")

0 commit comments

Comments
 (0)