Commit 5d5d8bc

Add max shard size to transformers save_pretrained (#1648)
1 parent 83c1afd commit 5d5d8bc

File tree

1 file changed: +8, -2

llmfoundry/callbacks/hf_checkpointer.py

@@ -742,7 +742,10 @@ def tensor_hook(
                 ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
                 )
                 with context_manager:
-                    new_model_instance.save_pretrained(temp_save_dir)
+                    new_model_instance.save_pretrained(
+                        temp_save_dir,
+                        max_shard_size='1GB',
+                    )
                 if original_tokenizer is not None:
                     assert isinstance(
                         original_tokenizer,
@@ -798,7 +801,10 @@ def tensor_hook(
                 new_model_instance = self.transform_model_pre_registration(
                     new_model_instance,
                 )
-                new_model_instance.save_pretrained(register_save_dir)
+                new_model_instance.save_pretrained(
+                    register_save_dir,
+                    max_shard_size='1GB',
+                )
                 if original_tokenizer:
                     original_tokenizer.save_pretrained(register_save_dir)
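
For context, max_shard_size is a standard argument of Hugging Face PreTrainedModel.save_pretrained: when the serialized weights exceed the limit, they are written as multiple shard files plus a JSON index rather than one large file, which keeps individual checkpoint files at a manageable size. Below is a minimal sketch of the behavior this commit opts into; the 'gpt2' model and 'checkpoint_dir' path are illustrative only and not taken from this commit.

    from transformers import AutoModelForCausalLM

    # Illustrative model; the callback saves the model instance it has just
    # reconstructed, not this one.
    model = AutoModelForCausalLM.from_pretrained('gpt2')

    # Same call shape as the updated lines in hf_checkpointer.py: weights
    # larger than 1 GB are split across multiple shard files, and an index
    # file maps each tensor name to the shard that contains it.
    model.save_pretrained(
        'checkpoint_dir',      # illustrative output directory
        max_shard_size='1GB',
    )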
