Remove low_cpu_mem_usage and _fast_init #36963

Merged: 61 commits into main on Mar 31, 2025

Conversation

@Cyrilvallez (Member) commented Mar 25, 2025

What does this PR do?

This PR removes the now-useless _fast_init and low_cpu_mem_usage arguments from from_pretrained, in order to simplify further and limit the number of code paths, ultimately making the loading logic much easier to maintain/debug. These 2 parameters should always be True anyway for optimized model loading.

Because a LOT of models have bad _init_weights() methods (i.e. they do not init ALL parameters), this may be an issue when loading a corrupted state dict (i.e. a state dict with missing weights, where one of the missing weights is not handled properly by _init_weights). However, this should not be an issue in general, as we don't expect too many corrupted state dicts on the Hub. Moreover, this bug is ALREADY PRESENT whenever loading such a model with a device_map, or with low_cpu_mem_usage=True (or any option that ends up activating low_cpu_mem_usage=True). This is because doing so forces the parameters to be loaded on the meta device, so weights initialized in the __init__ of a Layer or similar (which assumes instantiating the model on cpu) end up with wrong init values when moved back to cpu.
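To make that failure mode concrete, here is a minimal plain-Python sketch (no torch; the Tensor/Layer classes are illustrative stand-ins, not transformers code) of why a value written in __init__ cannot survive the meta-device path:

```python
class Tensor:
    def __init__(self, data=None, device="cpu"):
        self.device = device
        self.data = data  # a "meta" tensor carries no data at all

def materialize_on_cpu(t):
    # Moving a meta tensor to cpu allocates fresh, uninitialized storage:
    # anything written to the tensor in __init__ is irrecoverably gone.
    if t.device == "meta":
        return Tensor(data=0.0, device="cpu")  # stand-in for garbage memory
    return t

class Layer:
    def __init__(self, device="cpu"):
        # Anti-pattern: custom init value set in __init__ instead of _init_weights
        self.weight = Tensor(data=1.0, device=device)

cpu_layer = Layer("cpu")              # normal path: the value survives
assert cpu_layer.weight.data == 1.0

meta_layer = Layer("meta")            # device_map / low_cpu_mem_usage path
meta_layer.weight = materialize_on_cpu(meta_layer.weight)
assert meta_layer.weight.data != 1.0  # the __init__-time value was lost
```

Only an init routine that runs *after* materialization (i.e. _init_weights) can restore such values, which is why every parameter must be covered there.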

Nevertheless, it can be hard to debug and should not happen, so this PR already fixes some models' _init_weights. Jointly, #37070 adds a test to detect whenever a model's _init_weights misses a few parameters, and I will fix more models directly there (it relies on the fact that _fast_init and low_cpu_mem_usage are already gone).
Fun fact: even our faithful Llama has a bad _init_weights!! (missing the RMSNorm) 🤯

Most of the changed files simply remove old _fast_init tests (which were skipped anyway 🙃🙃), and fix weight initialization for a few models that were blocking general CI tests.


Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers.

@github-actions github-actions bot marked this pull request as draft March 25, 2025 13:58
@Cyrilvallez Cyrilvallez marked this pull request as ready for review March 25, 2025 13:59
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment

Mega nice! 🤗

Comment on lines +1401 to +1403
this check on each state dict at loading time (after the first loaded checkpoint, there is no way to initialize only the
mismatched weights, if any, without overwriting the previously loaded weights as well, because the whole module will be
initialized, not only the weights that are mismatched).
Collaborator

This can be False if we define, for example, a _linear_init and an _embedding_init, or a mapping from layer type to init func, inside _init_weights. Time for a refactor: if this can be simplified, at weight-loading time we would only have to call the func.
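A toy sketch of that suggestion (pure Python; _linear_init, _embedding_init, and the mapping are hypothetical names from this comment, not the transformers API):

```python
# Map each layer type to its own init function, so that at weight-loading
# time only the module owning a mismatched/missing parameter is re-initialized,
# instead of re-running _init_weights over the whole model via model.apply.

def _linear_init(module):
    module["weight"] = "normal(mean=0.0, std=0.02)"

def _embedding_init(module):
    module["weight"] = "normal(mean=0.0, std=0.02), padding row zeroed"

INIT_BY_TYPE = {
    "Linear": _linear_init,
    "Embedding": _embedding_init,
}

def init_single_param(module_type, module):
    # Called once per mismatched weight: no full-model traversal needed,
    # and previously loaded weights in other modules are left untouched.
    INIT_BY_TYPE[module_type](module)

module = {}
init_single_param("Linear", module)
assert module["weight"].startswith("normal")
```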

Collaborator

I don't like iterating twice, but this should happen less often than before, so all good

@Cyrilvallez (Member, Author) commented Mar 31, 2025

Yup, not a fan either. It's definitely something to keep in mind - doing the inits at the weight level would be super nice in the future!

@ArthurZucker (Collaborator)

FIX CIs please!

@Cyrilvallez (Member, Author)

Yup, it's just the test that I'm removing that got added to qwen3 - just removed it with fix-copies!

@ydshieh ydshieh merged commit f304318 into main Mar 31, 2025
17 of 21 checks passed
@ydshieh ydshieh deleted the remove-low-mem branch March 31, 2025 15:18
dmdaksh pushed a commit to dmdaksh/transformers that referenced this pull request Apr 2, 2025
* Remove low_cpu_mem_usage and _fast_init

* Update deepspeed.py

* Update modeling_utils.py

* remove the first 2 tests everywhere

* Update test_modeling_common.py

* remove what was remaining about fast_init

* fix logic and simplify

* mismatched keys logic update

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* fix 2 models init_weights

* extend to others

* remove grad

* Update modeling_fsmt.py

* init weights in tests

* style

* Update test_modeling_fsmt.py

* more old models

* fix more init_weights

* copies

* fix

* style

* Update modeling_lxmert.py

* fix inits

* more and more

* more

* should finalize

* style

* Update modeling_dinov2_with_registers.py

* fix

* Update modeling_encoder_decoder.py

* fix

* style

* Update modeling_lxmert.py

* post rebase cleanup

* Update modeling_informer.py

* back to start for device

* fix

* add test to detect all failing cases correctly

* Update test_modeling_common.py

* fix

* fix

* sam

* style

* Update modeling_maskformer_swin.py

* CIs

* CIs

* remove test - will add it on separate PR

* fix

* fix

* Update modeling_sam.py

* CIs

* CIs

* CIs

* convnext

* suggestions

* CIs

* fix copies after merge

---------

Co-authored-by: Yih-Dar <[email protected]>
zucchini-nlp pushed a commit to BakerBunker/transformers that referenced this pull request Apr 2, 2025
@sfc-gh-sbekman

sfc-gh-sbekman commented Apr 3, 2025

This commit broke Deepspeed ZeRO3.

[rank0]:   File "/code/users/stas/github/sf/arctictraining/arctic_training/model/hf_factory.py", line 32, in create_model
[rank0]:     return AutoModelForCausalLM.from_pretrained(
[rank0]:   File "/code/users/stas/github/transformers/src/transformers/models/auto/auto_factory.py", line 573, in from_pretrained
[rank0]:     return model_class.from_pretrained(
[rank0]:   File "/code/users/stas/github/transformers/src/transformers/modeling_utils.py", line 275, in _wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/code/users/stas/github/transformers/src/transformers/modeling_utils.py", line 4420, in from_pretrained
[rank0]:     ) = cls._load_pretrained_model(
[rank0]:   File "/code/users/stas/github/transformers/src/transformers/modeling_utils.py", line 4932, in _load_pretrained_model
[rank0]:     raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
[rank0]: RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
[rank0]:        While copying the parameter named "model.embed_tokens.weight", whose dimensions in the model are torch.Size([128256, 4096]) and whose dimensions in the checkpoint are torch.Size([128256, 4096]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).
[rank0]:        While copying the parameter named "model.layers.0.self_attn.q_proj.weight", whose dimensions in the model are torch.Size([4096, 4096]) and whose dimensions in the checkpoint are torch.Size([4096, 4096]), an exception occurred : ('Cannot copy out of meta tensor; no data!',).

[...] more of the above per weight

@sfc-gh-sbekman

If you run the deepspeed tests they fail:

$ pytest tests/deepspeed
tests/deepspeed/test_deepspeed.py::CoreIntegrationDeepSpeed::test_arange_bf16 PASSED                                                           [  0%]
tests/deepspeed/test_deepspeed.py::CoreIntegrationDeepSpeed::test_init_zero3 FAILED                                                            [  1%]
tests/deepspeed/test_deepspeed.py::CoreIntegrationDeepSpeed::test_init_zero3_fp16 FAILED                                                       [  1%]
tests/deepspeed/test_deepspeed.py::CoreIntegrationDeepSpeed::test_init_zero3_missing_params FAILED                                             [  2%]

cc: @ydshieh

@ydshieh (Collaborator)

ydshieh commented Apr 4, 2025

Our deepspeed CI job has some issues and failed to detect this. I opened a PR to resolve that, but we can address this issue directly.

@BenjaminBossan (Member)

Hey, just a small heads up @Cyrilvallez

The line:

init_contexts = [no_init_weights(), init_empty_weights()]

may fail if accelerate is not installed, since the init_empty_weights import is conditioned on accelerate being available.
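A sketch of the guard this implies (build_init_contexts and the availability flag are illustrative names; only init_empty_weights mirrors the real accelerate API):

```python
from contextlib import nullcontext

try:
    # init_empty_weights is only importable when accelerate is installed
    from accelerate import init_empty_weights
    accelerate_available = True
except ImportError:
    accelerate_available = False

def build_init_contexts(low_memory: bool):
    # Fall back to a no-op context instead of referencing a possibly
    # missing import unconditionally.
    contexts = []
    if low_memory and accelerate_available:
        contexts.append(init_empty_weights())
    else:
        contexts.append(nullcontext())
    return contexts

contexts = build_init_contexts(low_memory=True)
assert len(contexts) == 1
```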

@Cyrilvallez (Member, Author)

Cyrilvallez commented Apr 7, 2025

Hey @BenjaminBossan! Thanks a lot for checking in! We took care of this one already! 🤗

@BenjaminBossan (Member)

Great to hear, thanks for your quick work. Is the fix merged on main?

@Cyrilvallez (Member, Author)

Yes!

@farzadab

farzadab commented Apr 7, 2025

Hi there,
I'm having issues with the latest release of transformers, and by doing a git bisect I can see this is the PR/commit that introduced the issue:

import transformers
model = transformers.AutoModel.from_pretrained('fixie-ai/ultravox-v0_5-llama-3_2-1b', trust_remote_code=True)
print(model.language_model.device)
# with transformers==4.50.3 model is on cpu

# but on transformers==4.51, on main, and from this PR onwards
# model is on meta device

Any idea why?

For more info, this is likely related to the fact that .from_pretrained tries to load the language_model in turn, by calling AutoModelForCausalLM.from_pretrained inside UltravoxModel.__init__.
This behaviour worked fine before this PR. If that's no longer accepted, please advise what should be done instead.

Related issue: vllm-project/vllm#16149

@Cyrilvallez (Member, Author)

Hey @farzadab! Note that this was already broken before when using e.g. a device_map!
You should not rely on initializing weights anywhere other than inside the _init_weights method of a particular XXXPreTrainedModel. Here, they are silently skipped because you also added _keys_to_ignore_on_load_missing = ["audio_tower.*", "language_model.*"], so they are not treated at all (i.e. without it, they would be moved back to cpu and randomly initialized as missing weights, with a warning).
So TL;DR: do not skip keys, and initialize in _init_weights 🤗

@farzadab

farzadab commented Apr 8, 2025

Thanks for the answer.

I'll take a stab at using _init_weights instead.

Here's my current (and inefficient) workaround fwiw:
[screenshot of the workaround omitted]

@sssshhhhhh

Hi, this is causing all weights to be meta tensors when from_flax=True, for at least whisper and bert. This was already broken with low_cpu_mem_usage before, so not unexpected I guess.

from transformers import BertModel
model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)  # also 'openai/whisper-tiny'
assert model.state_dict()['embeddings.word_embeddings.weight'].is_meta

I'm only using this to convert from jax so not a big deal to use an older version. Sorry if you're already aware of this issue.

@Cyrilvallez (Member, Author)

Hey @sssshhhhhh! We were actually not aware of it; it slipped under the radar, as from_flax/from_tf are extremely rarely used! So rarely that apparently nobody ever reported that it was broken with a device_map (which implicitly used to activate low_cpu_mem_usage) 😳 I must say I'm quite surprised, as this has been in the codebase for quite a long time.

However, we will very soon start deprecating flax and tf. As loading with from_flax/from_tf uses the model architecture in the underlying library, this means we will also stop supporting the from_flax/from_tf flags in from_pretrained. As a result, I don't think loading with these flags will be fixed in the current state of the library (main).

As a result, I do think the easiest is indeed to use an older version to convert the weights to pytorch if needed, then resave them. Would that be an acceptable way to proceed for your use case, or would it cause too much friction/discomfort?

@farzadab

farzadab commented Apr 14, 2025

@Cyrilvallez I spent a lot of time trying to figure this out but I'm still left with no good solution.

Changing the behaviour of _init_weights as you suggested cannot solve this issue, because it assumes each inner module can be initialized separately (since it's called with model.apply), but that's not what I want. I want to be able to load sub-models (e.g. language_model and audio_tower) from checkpoints (e.g. the HF Hub).

The best I can think of is to override _load_pretrained_model, then somehow find the checkpoints for the sub-models and call language_model._load_pretrained_model on them.
What makes this extremely hard is that replicating from_pretrained's checkpoint-resolution behaviour is very difficult, since from_pretrained is not very modular (1300 lines).

@Cyrilvallez (Member, Author)

Hey! Yes, repos are expected to contain all their weights, so things would be much simpler if you added all the weights to your repo directly (i.e. including the weights of the sub-models). However, if you don't want to do that, I believe _init_weights can still be used with something along the lines of:

def _init_weights(self, module):
    if module is self.language_model:
        self.language_model = module.from_pretrained(...)
    elif module in self.language_model.modules():
        pass
    ...

but passing specific args (i.e. the same args as the outer call) to that inner from_pretrained will require a bit more of a hack

from_pretrained should be far fewer lines now as well, we simplified a lot 🤗 Are you sure you're looking at main?

Hope this solves your issue 🤗
