Remove low_cpu_mem_usage and _fast_init #36963

Conversation
Mega nice! 🤗
this check on each state dict at loading time (after the first loaded checkpoint, there is no way to initialize only the mismatched weights, if any, without overwriting the previously loaded weights as well, because the whole module will be initialized, not only the weights that are mismatched).
This can be False if we define `_linear_init` and `_embedding_init`, or a mapping from layer type to init func in `_init_weights`, for example. Time for a refactor: if this can be simplified, we would only have to call that func at weight-loading time.
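For illustration, a rough sketch of that idea (the names here, like `_INIT_FUNCS` and `init_single_module`, are hypothetical, not the actual transformers API):

```python
import torch.nn as nn

# Hypothetical mapping from layer type to a dedicated init function, so that
# at weight-loading time only the mismatched module is re-initialized,
# leaving already-loaded sibling weights untouched.
_INIT_FUNCS = {
    nn.Linear: lambda m, std: m.weight.data.normal_(mean=0.0, std=std),
    nn.Embedding: lambda m, std: m.weight.data.normal_(mean=0.0, std=std),
    nn.LayerNorm: lambda m, std: m.weight.data.fill_(1.0),
}

def init_single_module(module: nn.Module, std: float) -> None:
    init_fn = _INIT_FUNCS.get(type(module))
    if init_fn is not None:
        init_fn(module, std)
```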
I don't like iterating twice but this should happen less than before so all good
Yup, not a fan either. It's definitely something to keep in mind - doing the inits at the weight level would be super nice in the future!
FIX CIs please!
Yup, it's just the test that I'm removing that got added to qwen3 - just removed it!
* Remove low_cpu_mem_usage and _fast_init
* Update deepspeed.py
* Update modeling_utils.py
* remove the first 2 tests everywhere
* Update test_modeling_common.py
* remove what was remaining about fast_init
* fix logic and simplify
* mismatched keys logic update
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* fix 2 models init_weights
* extend to others
* remove grad
* Update modeling_fsmt.py
* init weights in tests
* style
* Update test_modeling_fsmt.py
* more old models
* fix more init_weights
* copies
* fix
* style
* Update modeling_lxmert.py
* fix inits
* more and more
* more
* should finalize
* style
* Update modeling_dinov2_with_registers.py
* fix
* Update modeling_encoder_decoder.py
* fix
* style
* Update modeling_lxmert.py
* post rebase cleanup
* Update modeling_informer.py
* back to start for device
* fix
* add test to detect all failing cases correctly
* Update test_modeling_common.py
* fix
* fix
* sam
* style
* Update modeling_maskformer_swin.py
* CIs
* CIs
* remove test - will add it on separate PR
* fix
* fix
* Update modeling_sam.py
* CIs
* CIs
* CIs
* convnext
* suggestions
* CIs
* fix copies after merge

---------

Co-authored-by: Yih-Dar <[email protected]>
This commit broke DeepSpeed ZeRO-3.

If you run the DeepSpeed tests, they fail.

cc: @ydshieh
We have some issues with the DeepSpeed CI job, and it fails to detect any issue. I opened a PR to resolve it, but we can proceed with this issue directly.
This reverts commit f304318.
Hey, just a small heads up @Cyrilvallez! The line `init_contexts = [no_init_weights(), init_empty_weights()]` may fail if accelerate is not installed, since `init_empty_weights` comes from accelerate.
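For context, a guard along these lines avoids the hard dependency (a minimal sketch, assuming transformers' `is_accelerate_available` helper; the actual merged fix may differ):

```python
from transformers.modeling_utils import no_init_weights
from transformers.utils import is_accelerate_available

# Only pull in accelerate's init_empty_weights when accelerate is installed.
init_contexts = [no_init_weights()]
if is_accelerate_available():
    from accelerate import init_empty_weights

    init_contexts.append(init_empty_weights())
```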
Hey @BenjaminBossan! Thanks a lot for checking in! We took care of this one already! 🤗

Great to hear, thanks for your quick work. Is the fix merged on main?

Yes!
Hi there,

```python
import transformers

model = transformers.AutoModel.from_pretrained('fixie-ai/ultravox-v0_5-llama-3_2-1b', trust_remote_code=True)
print(model.language_model.device)
# with transformers==4.50.3 model is on cpu
# but on transformers==4.51, on main, and from this PR onwards
# model is on meta device
```

Any idea why? For more info, this is likely related to the fact that […]

Related issue: vllm-project/vllm#16149
Hey @farzadab! Note that this was already broken before when using e.g. a `device_map`.
Hi, this is causing weights to all be meta tensors when `from_flax=True`, for at least Whisper and BERT. This was already broken with `low_cpu_mem_usage` before, so not unexpected I guess.

```python
from transformers import BertModel

model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)  # also 'openai/whisper-tiny'
assert model.state_dict()['embeddings.word_embeddings.weight'].is_meta
```

I'm only using this to convert from JAX, so it's not a big deal to use an older version. Sorry if you're already aware of this issue.
Hey @sssshhhhhh! We were actually not aware of it; it got through the radars as […]. However, we will very soon start to deprecate Flax and TF. As loading […]. As a result, I do think the easiest is indeed to use an older version to convert to PyTorch if needed, then resave the weights. Would that be an acceptable way to proceed for your use-case? Or would that create too much friction/discomfort?
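A sketch of that workaround, assuming an older transformers release (e.g. 4.50.x) where `from_flax` loading still materializes real tensors:

```python
# Run with an older transformers version (e.g. 4.50.x) where from_flax works.
from transformers import BertModel

model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
# Resave as a regular PyTorch checkpoint, usable with newer transformers.
model.save_pretrained("./bert-base-uncased-pt")
```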
@Cyrilvallez I spent a lot of time trying to figure this out, but I'm still left with no good solution. Changing the behaviour of […]. The best I can think of is to overwrite […]
Hey! Yes, repos are expected to contain all their weights, so things would be much simpler if you added all weights to your repo directly (i.e. the weights of the submodels). However, if you don't want to do that, I believe `_init_weights` could do something like:

```python
def _init_weights(self, module):
    if module is self.language_model:
        self.language_model = module.from_pretrained(...)
    elif module in self.language_model.modules():
        pass
    ...
```

but passing specific args (i.e. the same args as the outer call) to that inner `from_pretrained` will require a bit more of a hack.

Hope this solves your issue 🤗
What does this PR do?
This PR removes the now useless `_fast_init` and `low_cpu_mem_usage` in `from_pretrained`, in order to simplify even further and limit the number of code paths, in the end making it much easier to maintain/debug. These 2 parameters should always be `True` anyway for optimized model loading.

Because a LOT of models have bad `_init_weights()` methods (i.e. they do not init ALL parameters), it might be an issue when loading a corrupted state dict (i.e. a state dict with missing weights, where one of the missing weights is not handled properly by `_init_weights`). However, this should not be an issue in general, as we don't expect too many corrupted state dicts on the Hub. Moreover, this bug is ALREADY PRESENT whenever loading such a model with a `device_map`, or with `low_cpu_mem_usage=True` (or whatever option ends up activating `low_cpu_mem_usage=True`). This is because doing so forces the parameters to be loaded on meta, so weights initialized in the `__init__` of a `Layer` or similar (which assumes instantiating the model on cpu) will end up with wrong values when moving back to cpu.
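To make the failure mode concrete, here is a minimal sketch (a hypothetical module, not from the codebase) of how `__init__`-time initialization is lost on the meta device:

```python
import torch
import torch.nn as nn

class LayerWithInitInCtor(nn.Module):
    def __init__(self):
        super().__init__()
        # Initialization done in __init__, assuming a real (cpu) device.
        self.weight = nn.Parameter(torch.ones(4))

# Instantiating on meta never materializes the ones().
with torch.device("meta"):
    layer = LayerWithInitInCtor()

# Moving back off meta only allocates storage; the intended values are gone.
layer.to_empty(device="cpu")
print(layer.weight)  # uninitialized memory, not ones
```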
Nevertheless, it can be hard to debug and should not happen, so this PR already fixes some models' `_init_weights`. Jointly, #37070 adds a test to always detect if a model's `_init_weights` is missing a few parameters, and I will fix more models directly in it (it relies on the fact that `_fast_init` and `low_cpu_mem_usage` are already gone).

Fun fact: even our faithful Llama has a bad `_init_weights`!! (missing the RMSNorm) 🤯

Most of the files changed simply remove old `_fast_init` tests (which were skipped anyway 🙃🙃), as well as fix weight initialization for a few models that were blocking general CI tests.
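For illustration, a minimal sketch of what a "complete" `_init_weights` could look like for a Llama-style model; the norm branch is the kind of case that was missing (the structure and std value are illustrative, not the exact transformers implementation):

```python
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm

def _init_weights(self, module):
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, LlamaRMSNorm):  # the branch Llama was missing
        module.weight.data.fill_(1.0)
```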