diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
index abe966229ffe..5f315397584b 100644
--- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
@@ -44,6 +44,7 @@ def __init__(
         trust_remote_code=False,
         default_dtype=torch.bfloat16,
         load_in_4bit=False,
+        attn_implementation="sdpa",
     ):
         super().__init__()
         self.save_hyperparameters()
@@ -58,6 +59,7 @@ def __init__(
         self.trust_remote_code = trust_remote_code
         self.default_dtype = default_dtype
         self.load_in_4bit = load_in_4bit
+        self.attn_implementation = attn_implementation

     @property
     def tokenizer(self):
@@ -82,6 +84,7 @@ def configure_model(self):
                 torch_dtype='auto',
                 trust_remote_code=self.trust_remote_code,
                 load_in_4bit=self.load_in_4bit,
+                attn_implementation=self.attn_implementation,
             )
         else:
             from transformers import AutoConfig
@@ -89,7 +92,10 @@
             config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code)
             dtype = getattr(config, 'torch_dtype', self.default_dtype)
             self.model = AutoModelForCausalLM.from_config(
-                config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code
+                config,
+                torch_dtype=dtype,
+                trust_remote_code=self.trust_remote_code,
+                attn_implementation=self.attn_implementation,
             )

         # Apply FSDP2 and TP to the model
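
A minimal usage sketch of the new knob follows; it is not part of the diff, and the import alias and the model_name value are illustrative assumptions. attn_implementation is forwarded verbatim to transformers' from_pretrained / from_config, so the accepted values are the attention backends transformers supports, e.g. "eager", "sdpa" (the default chosen here), and "flash_attention_2".

# Hypothetical sketch -- assumes nemo.collections.llm exposes
# HFAutoModelForCausalLM and that it accepts a model_name keyword;
# neither appears in this diff.
from nemo.collections import llm

model = llm.HFAutoModelForCausalLM(
    model_name="meta-llama/Llama-3.2-1B",     # illustrative HF checkpoint
    attn_implementation="flash_attention_2",  # overrides the "sdpa" default; requires flash-attn
)
# configure_model() (invoked during trainer setup) now passes
# attn_implementation through to AutoModelForCausalLM.from_pretrained
# or .from_config unchanged.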