surface attn_implementation option (#11873)
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa authored Jan 19, 2025
1 parent fcd4807 commit 102bac6
Showing 1 changed file with 7 additions and 1 deletion.
@@ -44,6 +44,7 @@ def __init__(
         trust_remote_code=False,
         default_dtype=torch.bfloat16,
         load_in_4bit=False,
+        attn_implementation="sdpa",
     ):
         super().__init__()
         self.save_hyperparameters()
@@ -58,6 +59,7 @@ def __init__(
         self.trust_remote_code = trust_remote_code
         self.default_dtype = default_dtype
         self.load_in_4bit = load_in_4bit
+        self.attn_implementation = attn_implementation

     @property
     def tokenizer(self):
@@ -82,14 +84,18 @@ def configure_model(self):
                 torch_dtype='auto',
                 trust_remote_code=self.trust_remote_code,
                 load_in_4bit=self.load_in_4bit,
+                attn_implementation=self.attn_implementation,
             )
         else:
             from transformers import AutoConfig

             config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code)
             dtype = getattr(config, 'torch_dtype', self.default_dtype)
             self.model = AutoModelForCausalLM.from_config(
-                config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code
+                config,
+                torch_dtype=dtype,
+                trust_remote_code=self.trust_remote_code,
+                attn_implementation=self.attn_implementation,
             )

         # Apply FSDP2 and TP to the model
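The surfaced keyword is forwarded verbatim to Hugging Face transformers, which uses it to select the attention backend: "eager", "sdpa", or "flash_attention_2", where "sdpa" routes through torch.nn.functional.scaled_dot_product_attention. A minimal sketch of the same call outside the wrapper; the checkpoint name below is illustrative, not one used by this commit:

import torch
from transformers import AutoModelForCausalLM

# Load a causal LM with an explicit attention backend, mirroring the
# from_pretrained call in the diff. "gpt2" is an illustrative checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)

Exposing the argument on the wrapper lets callers opt into a faster backend (e.g. flash attention on supported GPUs) without touching the model-construction code, while "sdpa" remains the default.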
