
Two forward recomputations occur in a single backward when using FSDP with activation checkpointing #740

Open
mingyuanw-mt opened this issue Oct 21, 2024 · 1 comment
System Info

PyTorch version: 2.2.0
Is debug build: False
CUDA used to build PyTorch: 11.8
GPU: A100 PCIe * 4
transformers: 4.45.2

Information

  • The official example scripts
  • My own modified scripts

🐛 Describe the bug

Hi, I am trying to fine-tune Llama2-7B using FSDP, but I found that two forward recomputations occur in a single backward when using FSDP with activation checkpointing, whereas there should normally be only one. This strange behavior decreases the model's training throughput.

As you can see in the profiler trace, there are two flash_attention operations per backward.
(screenshot: profiler timeline showing the duplicated flash_attention kernels)

Here is my tracing file:
fsdp_ac_cuda_profile.zip

It seems that enabling both FSDP's and transformers' activation checkpointing causes this problem:

model.gradient_checkpointing_enable()
apply_fsdp_checkpointing(model)

FSDP's activation checkpointing uses the non-reentrant version of torch.utils.checkpoint, while transformers uses the reentrant version; please correct me if my understanding is incorrect.
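The extra recomputation can be reproduced outside FSDP/transformers with a small sketch that nests torch.utils.checkpoint twice, mimicking the double wrapping above (the Block module and the call counter here are illustrative, not taken from the original code):

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        calls["n"] += 1  # count every execution of this forward
        return self.lin(x)

def run(double_wrap: bool) -> int:
    """Return how many times Block.forward ran for one forward+backward pass."""
    calls["n"] = 0
    block = Block()
    x = torch.randn(2, 4, requires_grad=True)
    if double_wrap:
        # inner wrap, as transformers' gradient checkpointing would add ...
        inner = lambda t: checkpoint(block, t, use_reentrant=True)
        # ... plus an outer wrap, as an activation-checkpoint wrapper would add
        y = checkpoint(inner, x, use_reentrant=True)
    else:
        y = checkpoint(block, x, use_reentrant=True)
    y.sum().backward()
    return calls["n"]

single = run(double_wrap=False)  # one original forward + one recompute
double = run(double_wrap=True)   # nesting adds a further recompute
print(single, double)
```

With a single checkpoint the forward runs twice (once normally, once recomputed in backward); with the nested wrapping, the outer recompute re-enters the inner checkpoint, which then recomputes again, so the forward runs a third time.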

When I delete model.gradient_checkpointing_enable(), the backward behavior becomes normal; train_epoch_time drops from 248.88 s before the change to 185.07 s after it.
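A minimal sketch of keeping only one layer of checkpointing, using PyTorch's apply_activation_checkpointing with the non-reentrant wrapper (which is what FSDP-style checkpointing builds on). The DecoderLayer toy module and the check_fn are illustrative assumptions, not the project's actual code:

```python
import functools
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# Toy stand-in for a transformer decoder layer (illustrative only).
class DecoderLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.lin(x))

model = torch.nn.Sequential(DecoderLayer(), DecoderLayer())

# Apply ONLY this; do not also call model.gradient_checkpointing_enable(),
# or each layer ends up wrapped in torch.utils.checkpoint twice.
non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
)
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=lambda m: isinstance(m, DecoderLayer),
)

# Each DecoderLayer is now replaced by a CheckpointWrapper around it.
wrapped = [type(m).__name__ for m in model.modules()]
print(wrapped)
```

A forward/backward pass through the wrapped model works as usual; only the activation storage and recompute behavior change.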

Error logs

Please check the attachment above

Expected behavior

One backward operation corresponds to one recomputation

@wukaixingxp wukaixingxp assigned wukaixingxp and mreso and unassigned wukaixingxp Oct 21, 2024

mreso commented Oct 28, 2024

Hi @mingyuanw-mt, thanks for flagging this. I think they are using reentrant as well, but in any case this wraps the layers with torch.utils.checkpoint twice, which leads to the trace you see. I'll be creating a related PR this week and can disable the second call in it.
