add transformer class for review #11491
base: main
Conversation
self.d_model = d_model
self.non_linearity = nn.GELU(approximate="tanh")
self.proj = ConvNorm(d_model, d_model * 4, bias=bias, kernel_size=kernel_size, is_causal=is_causal)
Should the FFN size be a configuration parameter instead of being hardcoded to `4 * d_model`?
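A minimal sketch of that suggestion, with a hypothetical `ffn_factor` argument; `ConvNorm` is the projection module from this PR and is assumed to be in scope:

```python
import torch.nn as nn

class PositionwiseConvFF(nn.Module):  # hypothetical name for the FFN block above
    def __init__(self, d_model, kernel_size, bias=True, is_causal=False, ffn_factor=4):
        super().__init__()
        self.d_model = d_model
        self.non_linearity = nn.GELU(approximate="tanh")
        # ffn_factor replaces the hardcoded 4 while keeping 4 as the default
        self.proj = ConvNorm(d_model, d_model * ffn_factor, bias=bias,
                             kernel_size=kernel_size, is_causal=is_causal)
```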
Paarth changed this.
q = self.q_net(query).reshape(Bq, Tq, self.n_heads, self.d_head)
kv = self.kv_net(memory).reshape(Bkv, Tkv, 2, self.n_heads, self.d_head)
if self.pos_emb_name == 'rope':
    q, kv = self.rope(q, kv)
elif self.pos_emb_name == 'alibi':
    alibi_slopes = self.m[:, 0, 0]
q = q[~query_mask].reshape(-1, self.n_heads, self.d_head)
kv = kv[~memory_mask].reshape(-1, 2, self.n_heads, self.d_head)
lengths_q = (~query_mask).sum(1)
lengths_k = (~memory_mask).sum(1)
Code like this will be a lot easier to read if we replace `.reshape()` with einops `rearrange()`, and add comments with the output shapes for operations that are not reshape/rearrange.
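For illustration, a sketch of how the hunk above might look with `einops` (assuming `n_heads` and `d_head` as in the PR; boolean indexing already flattens the batch/time dimensions, so the later reshapes drop out):

```python
from einops import rearrange

# (B, Tq, n_heads * d_head) -> (B, Tq, n_heads, d_head)
q = rearrange(self.q_net(query), "b t (h d) -> b t h d", h=self.n_heads)
# (B, Tkv, 2 * n_heads * d_head) -> (B, Tkv, 2, n_heads, d_head)
kv = rearrange(self.kv_net(memory), "b t (kv h d) -> b t kv h d", kv=2, h=self.n_heads)

# indexing with a (B, T) boolean mask already yields flattened tensors:
q = q[~query_mask]        # (num_unmasked_q, n_heads, d_head)
kv = kv[~memory_mask]     # (num_unmasked_kv, 2, n_heads, d_head)
```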
if self.has_xattn:
    self.cross_attention.reset_cache(use_cache)

def forward(self, x, x_mask, cond, cond_mask, dump_attention=False, attn_prior=None, idx=None):
If `cond` and `cond_mask` are optional, we should default them to None. Should we throw an error if `cond` is provided but `self.has_xattn` is False? Or if `cond` is not provided but `self.has_xattn` is True?
We can default them to None, but I wouldn't raise an error if `has_xattn` is True and `cond` is None. I use that feature to pretrain the decoder with the context set to None while keeping the same architecture and parameters, so it can later serve as the pretrained T5 decoder for TTS.
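A minimal sketch of the agreed change (defaults only, no error when `cond` is None):

```python
def forward(self, x, x_mask, cond=None, cond_mask=None,
            dump_attention=False, attn_prior=None, idx=None):
    # cond/cond_mask default to None so the decoder can run without context,
    # e.g. when pretraining it before reuse as the T5 decoder for TTS
    ...
```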
p_dropout=p_dropout,
is_causal=False,
is_self_attention=False,
d_memory=params['d_heads'],
Should rename `d_heads` in params to `d_memory` here. `d_memory` is supposed to be the dimension of the context information for cross attention; `d_heads` refers to the size of each attention head, which this code hardcodes to `d_memory // n_heads`.
We no longer use a params dict, so this should no longer happen.
use_flash_self_attention=True,
use_flash_x_attention=True,
deterministic=False,
pos_emb={"name": "learnable"},
Can we make the pos_emb argument more structured? Either a dataclass, or, similar to xattn, flatten it into parameters like `pos_emb_name`, `pos_emb_base`, `pos_emb_kwargs`, etc.
Nitpick: mutable objects like dictionaries should not be used as default arguments.
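For illustration, a minimal sketch of the dataclass option combined with a None default that avoids the mutable-default pitfall; all names here are hypothetical:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class PosEmbConfig:
    name: str = "learnable"
    base: Optional[int] = None                   # e.g. RoPE base frequency
    kwargs: dict = field(default_factory=dict)   # factory instead of a shared dict literal

class TransformerStack:  # stand-in name for the PR's class
    def __init__(self, d_model: int, pos_emb: Optional[PosEmbConfig] = None):
        # a None default avoids sharing one mutable dictionary across instances
        self.d_model = d_model
        self.pos_emb = pos_emb if pos_emb is not None else PosEmbConfig()
```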
Vote for `@dataclass` to group the configs.
We got rid of this parameter and changed it to a bool.
has_xattn,
xa_d_memory=None,
xa_n_heads=None,
xa_pos_emb=None,
xa_max_length_causal_mask=None,
Should we group these into a `CrossAttentionConfig` dataclass, to make it clear which arguments are related/optional? Then we can check whether the config is None rather than checking the `has_xattn` flag.
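A minimal sketch of that grouping, with field names mirroring the `xa_*` arguments above (class and argument names are hypothetical):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class CrossAttentionConfig:
    d_memory: int
    n_heads: int
    pos_emb: Optional[dict] = None
    max_length_causal_mask: Optional[int] = None

class DecoderLayer:  # stand-in for the layer that currently takes has_xattn
    def __init__(self, xattn: Optional[CrossAttentionConfig] = None):
        # the presence of the config replaces the separate has_xattn flag
        self.has_xattn = xattn is not None
        self.xattn = xattn
```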
Went over the full implementation and left notes about bugfixes and recommendations on refactoring for readability.
[1]: Attention scores used for CTC loss (only in naive attention).
"""

y, attn_prob = self.attn_naive(query, query_mask, memory, memory_mask, attn_prior)
I think we can move the code from `attn_naive()` into `forward()` now that there is no flash attention branch.
x_mask <bool mask> (B, T1): True where ignoring is required
cond <torch tensor> (B, T2, C): Conditioning tensor
cond_mask <bool mask> (B, T2): True where ignoring is required
Should we reverse this to the standard mask convention, where you provide False for indices that should be ignored? The only place the inverted mask (True = ignore) is used in the code is when attention fills masked positions with -inf before the softmax.
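For illustration, a toy sketch of the standard convention (True = keep), deriving the inverted mask only at the pre-softmax fill:

```python
import torch

B, Tq, Tk = 2, 3, 4
scores = torch.randn(B, Tq, Tk)                   # toy attention scores
key_mask = torch.ones(B, Tk, dtype=torch.bool)    # standard convention: True = attend
key_mask[:, -1] = False                           # e.g. last key position is padding

# the inversion is only needed here, so it can stay local to the attention op
scores = scores.masked_fill(~key_mask[:, None, :], float("-inf"))
attn = scores.softmax(dim=-1)
```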
We changed to the standard mask convention.
if self.use_cache and self.cache['memory'] is not None:
    memory = self.cache['memory']
else:
    memory = self.norm_xattn_memory(cond) if self.apply_norm_to_cond else cond
    if self.use_cache:
        self.cache['memory'] = memory
What is the purpose of caching "cond"? Isn't it a constant (e.g. context information)?
LGTM. Need to address the minor bugs before approval.
I added 3 commits to address the comments I left myself, so it is approved from my side. Please have a look again before merging.
It requires that `xa_d_memory` and `xa_n_heads` are specified when `has_xattn` is True.
CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
Added a unit test; this is the necessary test to ensure that the forward pass of the Transformer class succeeds. The multiple-conditions case from different encoders failed the tests (pls run …).
beep boop 🤖: 🚨 The following files must be fixed before merge! Your code was analyzed with PyLint. The following annotations have been identified:
Mitigation guide:
By applying these rules, we reduce the occurrence of this message in the future. Thank you for improving NeMo's documentation!
@XuesongYang We need to pass … FYI, I tested this new transformer code for training and inference with t5tts locally, and it seems to be working fine. Also, for a fixed set of weights, the transformer implementation in …
Added the transformer stack currently being used in T5TTS. Identify unused code paths, clean up the code, and see what modules can be reused.