Add NGen3 #36901
base: main
Conversation
@ArthurZucker and @Rocketknight1
class Block(nn.Module):
    def __init__(self, config: NGEN3Config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = MoEMLP(config) if config.use_moe else MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection for attention
        residual = x
        x = self.ln1(x)
        x = self.attn(x)
        x = residual + x
        # Residual connection for feedforward
        residual = x
        x = self.ln2(x)
        x = self.mlp(x)
        return residual + x
can you try to align variable names / layer names with LlamaDecoderLayer here?
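For reference, a rough sketch of what the aligned naming could look like (input_layernorm / self_attn / post_attention_layernorm / mlp, and hidden_states instead of x, mirroring LlamaDecoderLayer). It reuses the CausalSelfAttention / MLP / MoEMLP classes and config fields from this PR; the NGen3DecoderLayer class name is only illustrative, not the PR's actual code:

import torch
import torch.nn as nn

class NGen3DecoderLayer(nn.Module):
    # Sketch only: same logic as Block above, names aligned with LlamaDecoderLayer.
    def __init__(self, config: NGEN3Config):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.n_embd)
        self.self_attn = CausalSelfAttention(config)
        self.post_attention_layernorm = nn.LayerNorm(config.n_embd)
        self.mlp = MoEMLP(config) if config.use_moe else MLP(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention block with residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states)
        hidden_states = residual + hidden_states
        # Feed-forward block with residual connection
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return residual + hidden_states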
class CausalSelfAttention(nn.Module):
this should be completely equivalent to a variation of LlamaAttention with fused qkv. Nothing very new, and the mask should not be saved there
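For illustration, a minimal sketch of the kind of fused-qkv attention the comment describes, with causal masking applied at call time via torch's scaled_dot_product_attention rather than stored on the module. Only n_embd / n_head come from the config in the snippet above; the NGen3Attention name and bias-free projections are assumptions, not the PR's actual implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NGen3Attention(nn.Module):
    # Sketch only: fused qkv projection; no mask buffer is registered,
    # causal masking is handled by scaled_dot_product_attention(is_causal=True).
    def __init__(self, config: NGEN3Config):
        super().__init__()
        self.num_heads = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.qkv_proj = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, _ = hidden_states.shape
        # Single projection produces q, k, v in one matmul, then split
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)
        # Reshape to (bsz, num_heads, seq_len, head_dim)
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Causal mask is built internally here; nothing is saved on the module
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1)
        return self.o_proj(attn_out)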
@Thishyaketh we need a proper description of the PR with:
This PR fixes issues in uploading NGen 3 to Transformers.