We know that training a video generation model is significantly different from training an LLM. When training video generation models, the largest share of GPU memory is consumed by intermediate activations, unlike training large language models, where the model parameters occupy more memory.
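To make that concrete, here is a back-of-envelope comparison. Every number below (model size, latent grid, per-block activation factor) is an illustrative assumption, not a measurement of any particular model:

```python
# All numbers below are illustrative assumptions, not measurements.
GiB = 1024 ** 3

# Hypothetical 13B-parameter DiT-style video model in bf16.
params = 13e9
param_mem = params * 2 / GiB  # 2 bytes per bf16 weight

# Hypothetical latent video: 32 latent frames over a 45 x 80 latent grid.
tokens = 32 * 45 * 80
hidden = 3072
layers = 40
# Rough rule of thumb: ~10 hidden-sized tensors saved per transformer
# block for backward (QKV, attention output, MLP intermediates, norms).
act_mem = tokens * hidden * 10 * 2 * layers / GiB

print(f"weights: {param_mem:.0f} GiB, activations: {act_mem:.0f} GiB")
# → weights: 24 GiB, activations: 264 GiB
```

Even with very rough assumptions, activations dominate parameters by an order of magnitude, which is why activation checkpointing matters so much more here than in LLM training.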
But if we checkpoint every block in the model, that significantly slows down training. So maybe we could add features such as selective activation checkpointing and CPU-offload activation checkpointing, which trade off speed against VRAM cost.
something like this:
```python
from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    offload_wrapper,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


def apply_selective_checkpointing(model, block_types, p, use_deepspeed_ac):
    """
    block_types: a tuple of nn.Module types to be checkpointed
    p: the fraction of all blocks to be checkpointed
    """
    block_idx = 0
    cut_off = 1 / 2
    # When p is passed as a fraction (e.g. 1/3) on the command line, it
    # arrives as a string in argv, so we need eval("1/3") for fractions.
    p = eval(p) if isinstance(p, str) else p

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, block_types):
            block_idx += 1
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False

    def count_total_blocks(model, block_types):
        total_blocks = 0

        def count_blocks(module):
            nonlocal total_blocks
            if isinstance(module, block_types):
                total_blocks += 1

        model.apply(count_blocks)
        return total_blocks

    if use_deepspeed_ac:
        from deepspeed.runtime.activation_checkpointing import checkpointing

        total_block_num = count_total_blocks(model, block_types)
        num_checkpoints = round(p * total_block_num)
        checkpointing.configure(
            mpu_=None,
            deepspeed_config=None,
            partition_activations=False,
            contiguous_checkpointing=False,
            num_checkpoints=num_checkpoints,
            checkpoint_in_cpu=True,
            synchronize=False,
            profile=True,
        )
        checkpoint_fn = checkpointing.checkpoint
        checkpointing_wrapper = partial(checkpoint_wrapper, checkpoint_fn=checkpoint_fn)
    else:
        # checkpointing_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
        checkpointing_wrapper = offload_wrapper

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpointing_wrapper,
        check_fn=selective_checkpointing,
    )
```
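To show how the fractional selection rule above (`block_idx * p >= cut_off`, bumping `cut_off` on each hit) spreads checkpoints evenly across the model, here is a tiny pure-Python sketch; the helper name is made up for illustration:

```python
# Standalone sketch of the fractional selection rule: for p = 1/3 over
# 12 blocks it checkpoints 4 blocks, evenly spaced.
def selected_blocks(num_blocks, p):
    chosen, cut_off = [], 1 / 2
    for block_idx in range(1, num_blocks + 1):
        if block_idx * p >= cut_off:
            cut_off += 1
            chosen.append(block_idx)
    return chosen

print(selected_blocks(12, 1 / 3))
# → [2, 5, 8, 11]
```

So instead of checkpointing the first `p * N` blocks back-to-back, the rule interleaves checkpointed and non-checkpointed blocks, which keeps peak activation memory more uniform across the forward pass.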
I've already used this for LoRA training of HunyuanVideo. Unfortunately, the DeepSpeed activation checkpointing function may mutate the block's input variables, which causes an error when training HunyuanVideo. Besides, when I use the torch CPU-offload checkpointing, it still OOMs inside a CPU-offload activation checkpoint hook on F.scaled_dot_product_attention; I'm still trying to figure out what is happening.