
a suggestion about video generation model finetune #242

Open
zhangvia opened this issue Jan 24, 2025 · 1 comment

Comments

@zhangvia

zhangvia commented Jan 24, 2025

We know that training a video generation model is quite different from training an LLM: the largest share of GPU memory is consumed by intermediate activations, whereas in LLM training the model parameters tend to dominate memory usage.
However, if we checkpoint every block in the model, training slows down significantly. So maybe we could add features like selective activation checkpointing and CPU-offloaded activation checkpointing, which allow trading off speed against VRAM cost.

Something like this:

from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    offload_wrapper,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

def apply_selective_checkpointing(model, block_types, p, use_deepspeed_ac):
    '''
    block_types: a tuple of nn.Module classes (or a single class) to be checkpointed
    p: the fraction of all matching blocks to be checkpointed (e.g. 1/3 or "1/3")
    use_deepspeed_ac: if True, use DeepSpeed's CPU-offloaded activation checkpointing
    '''
    # checkpoint roughly a fraction p of the matching blocks, spaced evenly
    block_idx = 0
    cut_off = 1 / 2
    # when p is passed as a fraction on the command line (e.g. "1/3"), it
    # arrives as a string in argv, so we eval it here to get a float
    p = eval(p) if isinstance(p, str) else p

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, block_types):
            block_idx += 1
            if block_idx * p >= cut_off:
                cut_off += 1

                return True
        return False
    
    def count_total_blocks(model, block_types):
        total_blocks = 0

        def count_blocks(module):
            nonlocal total_blocks
            if isinstance(module, block_types):
                total_blocks += 1

        model.apply(count_blocks)
        return total_blocks
    
    if use_deepspeed_ac:
        from deepspeed.runtime.activation_checkpointing import checkpointing
        total_block_num = count_total_blocks(model, block_types)
        num_checkpoints = round(p * total_block_num)
        checkpointing.configure(
            mpu_=None,
            deepspeed_config=None,
            partition_activations=False,
            contiguous_checkpointing=False,
            num_checkpoints=num_checkpoints,
            checkpoint_in_cpu=True,
            synchronize=False,
            profile=True,
        )
        checkpoint_fn = checkpointing.checkpoint
        checkpointing_wrapper = partial(checkpoint_wrapper, checkpoint_fn=checkpoint_fn)
    else:
        # checkpointing_wrapper = partial(checkpoint_wrapper,checkpoint_impl=CheckpointImpl.NO_REENTRANT,)
        checkpointing_wrapper = offload_wrapper

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpointing_wrapper,
        check_fn=selective_checkpointing,
    )
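
For reference, a minimal usage sketch (the toy model and block class below are hypothetical; in practice you would pass the transformer block classes of the model being fine-tuned):

import torch.nn as nn

# Hypothetical toy model: a stack of residual feed-forward blocks standing in
# for real transformer blocks.
class Block(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)

model = nn.Sequential(*[Block() for _ in range(8)])

# Checkpoint roughly one third of the blocks using torch's CPU-offload wrapper.
apply_selective_checkpointing(model, block_types=(Block,), p="1/3", use_deepspeed_ac=False)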

I've already used this for LoRA training of HunyuanVideo. Unfortunately, DeepSpeed's activation checkpointing function may modify the block's input variables, which causes an error when training HunyuanVideo. Besides that, when I use torch's CPU offload checkpointing, I still get an OOM in the CPU offload activation checkpoint hook around F.scaled_dot_product_attention; I'm still trying to figure out what is happening.

@a-r-r-o-w
Owner

Thanks for the awesome suggestion! I'm working on SAC here:

class CheckpointType(str, Enum):
    FULL = "full"
    OPS = "ops"
    BLOCK_SKIP = "block_skip"


_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
}

I just need to do some more tests with multiple models and enable modification via the CLI
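
For anyone following along, here is a minimal sketch (not this repo's actual implementation) of how an op-level SAC policy like the set above can be wired up with torch.utils.checkpoint.create_selective_checkpoint_contexts from recent PyTorch releases; the op list and the block_forward_with_sac helper are assumptions for illustration:

from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Outputs of these ops are saved during forward; everything else is recomputed
# in backward, mirroring the _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS idea above.
_OPS_TO_SAVE = [
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
]

def _policy_fn(ctx, op, *args, **kwargs):
    if op in _OPS_TO_SAVE:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def block_forward_with_sac(block, *inputs):
    # Checkpoint the block, but keep the expensive matmul/attention outputs so
    # that only the cheaper ops are recomputed during backward.
    context_fn = partial(create_selective_checkpoint_contexts, _policy_fn)
    return checkpoint(block, *inputs, use_reentrant=False, context_fn=context_fn)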
