
add transformer class for review #11491

Draft · wants to merge 44 commits into base: main

Conversation

paarthneekhara (Collaborator):

Added the transformer stack currently used in T5TTS. To do: identify unused code paths, clean up the code, and see which modules can be reused.

@github-actions bot added the TTS label Dec 5, 2024
@blisc requested review from XuesongYang and rlangman December 9, 2024 18:33

self.d_model = d_model
self.non_linearity = nn.GELU(approximate="tanh")
self.proj = ConvNorm(d_model, d_model * 4, bias=bias, kernel_size=kernel_size, is_causal=is_causal)

Collaborator:

Should the FFN size be a configuration instead of hardcoded to 4 * d_model?
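
For illustration, a minimal sketch of how this could be exposed, assuming a hypothetical d_ffn constructor argument that falls back to the current 4 * d_model behavior (ConvNorm and the other names are taken from the snippet above):

# Hypothetical sketch: accept d_ffn in __init__ instead of hardcoding 4 * d_model.
def __init__(self, d_model, kernel_size, bias, is_causal, d_ffn=None):
    super().__init__()
    d_ffn = d_ffn if d_ffn is not None else 4 * d_model  # preserve the current default
    self.d_model = d_model
    self.non_linearity = nn.GELU(approximate="tanh")
    self.proj = ConvNorm(d_model, d_ffn, bias=bias, kernel_size=kernel_size, is_causal=is_causal)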

Collaborator:

Paarth changed this

Comment on lines 268 to 277
q = self.q_net(query).reshape(Bq, Tq, self.n_heads, self.d_head)
kv = self.kv_net(memory).reshape(Bkv, Tkv, 2, self.n_heads, self.d_head)
if self.pos_emb_name == 'rope':
    q, kv = self.rope(q, kv)
elif self.pos_emb_name == 'alibi':
    alibi_slopes = self.m[:, 0, 0]
q = q[~query_mask].reshape(-1, self.n_heads, self.d_head)
kv = kv[~memory_mask].reshape(-1, 2, self.n_heads, self.d_head)
lengths_q = (~query_mask).sum(1)
lengths_k = (~memory_mask).sum(1)

Collaborator:

Code like this will be a lot easier to read if we replace .reshape() with einops rearrange() and add comments with the output shapes for operations that are not reshape/rearrange.
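
For example, a sketch of the rearrange version (assuming d_model = n_heads * d_head and the attribute names from the snippet above):

from einops import rearrange

# (B, Tq, n_heads * d_head) -> (B, Tq, n_heads, d_head)
q = rearrange(self.q_net(query), 'b t (h d) -> b t h d', h=self.n_heads)
# (B, Tkv, 2 * n_heads * d_head) -> (B, Tkv, 2, n_heads, d_head)
kv = rearrange(self.kv_net(memory), 'b t (two h d) -> b t two h d', two=2, h=self.n_heads)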

if self.has_xattn:
    self.cross_attention.reset_cache(use_cache)

def forward(self, x, x_mask, cond, cond_mask, dump_attention=False, attn_prior=None, idx=None):

Collaborator:

If cond and cond_mask are optional, we should default them to None.

Should we throw an error if cond is provided, but self.has_xattn is False? Or if cond is not provided, but self.has_xattn is True?

Collaborator (Author):

We can default them to None, but I wouldn't raise an error if has_xattn is True and cond is None. I use that feature to pretrain the decoder with the context set to None, while keeping the same architecture and parameters when using it as the pretrained T5 decoder for TTS.
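
For reference, a sketch of what the defaulted signature could look like; the validation shown (erroring only when cond is supplied to a block built without cross-attention) is illustrative, not the PR's final behavior:

def forward(self, x, x_mask, cond=None, cond_mask=None, dump_attention=False, attn_prior=None, idx=None):
    # cond may legitimately be None even when has_xattn is True,
    # e.g. when pretraining the decoder without context.
    if cond is not None and not self.has_xattn:
        raise ValueError("cond was provided, but this block was built without cross-attention.")
    ...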

p_dropout=p_dropout,
is_causal=False,
is_self_attention=False,
d_memory=params['d_heads'],

Collaborator:

Should rename d_heads in params to d_memory here. d_memory is supposed to be the dimension of the context information for cross attention; d_heads refers to the size of each attention head, which this code hardcodes to be d_memory // n_heads.

Collaborator:

We no longer use a params dict, so this should no longer happen.

@blisc self-requested a review December 17, 2024 19:04
use_flash_self_attention=True,
use_flash_x_attention=True,
deterministic=False,
pos_emb={"name": "learnable"},

Collaborator:

Can we make the pos_emb argument more structured? Either a dataclass or, similar to xattn, flattened into parameters like pos_emb_name, pos_emb_base, pos_emb_kwargs, etc.

Nitpick: Mutable objects like dictionaries should not be used as default arguments.
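
As an illustration of the dataclass option (the names below are hypothetical, not the PR's API), grouping the options also avoids the mutable default argument:

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PositionalEmbeddingConfig:
    name: str = "learnable"                      # e.g. 'learnable', 'rope', 'alibi'
    base: Optional[int] = None                   # e.g. RoPE base frequency
    kwargs: dict = field(default_factory=dict)   # default_factory avoids a shared mutable default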

Collaborator:

vote for @dataclass to group the configs.

Collaborator:

We got rid of this parameter and changed it to a bool.

Comment on lines 376 to 380
has_xattn,
xa_d_memory=None,
xa_n_heads=None,
xa_pos_emb=None,
xa_max_length_causal_mask=None,

Collaborator:

Should we group these into a CrossAttentionConfig dataclass, to make it clear which arguments are related/optional? Then we can check whether the config is None rather than the has_xattn flag.
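
A sketch of what that grouping could look like; the field names mirror the xa_* arguments above, but the class itself is hypothetical:

from dataclasses import dataclass
from typing import Optional


@dataclass
class CrossAttentionConfig:
    d_memory: int
    n_heads: int
    pos_emb: Optional[dict] = None
    max_length_causal_mask: Optional[int] = None


# The layer could then accept `xattn: Optional[CrossAttentionConfig] = None`
# and enable cross-attention when the config is not None, replacing has_xattn.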

@XuesongYang (Collaborator) left a comment:

Went over the full implementation and left notes about bug fixes and recommendations on refactoring for readability.

@blisc previously approved these changes Jan 7, 2025
[1]: Attention scores used for CTC loss (only in naive attention).
"""

y, attn_prob = self.attn_naive(query, query_mask, memory, memory_mask, attn_prior)

Collaborator:

I think we can move the code from attn_naive() into forward() now that there is no flash attention branch.

Comment on lines 493 to 495
x_mask <bool mask> (B, T1): True where ignoring is required
cond <torch tensor> (B, T2, C): Conditioning tensor
cond_mask <bool mask> (B, T2): True where ignoring is required

Collaborator:

Should we reverse this to the standard mask convention, where you provide False for indices that should be ignored? The only time the inverted mask (True means ignore) is used in the code is when attention fills masked values with -inf before softmax.
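
For context, a sketch of that single use under the standard convention (True = keep); the tensor names here are illustrative, not the PR code:

import torch

# attn_score: (B, n_heads, Tq, Tkv); memory_mask: (B, Tkv) with True for valid positions
attn_score = attn_score.masked_fill(~memory_mask[:, None, None, :], float('-inf'))
attn_prob = torch.softmax(attn_score, dim=-1)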

@blisc (Collaborator), Jan 18, 2025:

We changed to the standard mask convention.

if self.use_cache and self.cache['memory'] is not None:
    memory = self.cache['memory']
else:
    memory = self.norm_xattn_memory(cond) if self.apply_norm_to_cond else cond
    if self.use_cache:
        self.cache['memory'] = memory

Collaborator:

What is the purpose of caching "cond"? Isn't it a constant (e.g. context information)?

@XuesongYang self-requested a review January 9, 2025 09:36

@XuesongYang (Collaborator) left a comment:

LGTM. Need to address the minor bugs before approval.

@XuesongYang previously approved these changes Jan 9, 2025

@XuesongYang (Collaborator) left a comment:

I added 3 commits to address the comments I left myself, so it's approved from my side. Please have a look again before merging.

@github-advanced-security bot left a comment:

CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@github-actions bot removed the ASR label Jan 18, 2025
@XuesongYang removed the audio label Jan 18, 2025

@XuesongYang (Collaborator) commented Jan 18, 2025:

Added a unit test. This is the necessary test to ensure the forward pass of the Transformer class succeeds.

The multiple-conditions case (from different encoders) fails the test test_forward_causal_self_attn_and_has_xattn. It seems a list of tensors is not supported. @paarthneekhara could you please verify?

Please run pytest -s -vvv tests/collections/tts/modules/test_tts_new_transformer.py locally to test the code.

Contributor:

beep boop 🤖: 🚨 The following files must be fixed before merge!


Your code was analyzed with PyLint. The following annotations have been identified:

************* Module nemo.collections.tts.modules.transformer_2412
nemo/collections/tts/modules/transformer_2412.py:26:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:85:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:94:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:138:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:171:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:191:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:195:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:280:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:340:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:398:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:471:4: C0116: Missing function or method docstring (missing-function-docstring)
nemo/collections/tts/modules/transformer_2412.py:535:0: C0115: Missing class docstring (missing-class-docstring)
nemo/collections/tts/modules/transformer_2412.py:626:4: C0116: Missing function or method docstring (missing-function-docstring)

-----------------------------------
Your code has been rated at 9.47/10

Mitigation guide:

  • Add sensible and useful docstrings to functions and methods
  • For trivial methods like getter/setters, consider adding # pylint: disable=C0116 inside the function itself
  • To disable multiple functions/methods at once, put a # pylint: disable=C0116 before the first and a # pylint: enable=C0116 after the last (see the example below).
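
For example (an illustrative getter/setter pair, not code from this PR):

# pylint: disable=C0116
def get_sample_rate(self):
    return self._sample_rate

def set_sample_rate(self, sample_rate):
    self._sample_rate = sample_rate
# pylint: enable=C0116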

By applying these rules, we reduce the occurrence of this message in the future.

Thank you for improving NeMo's documentation!

@paarthneekhara (Collaborator, Author):

@XuesongYang We need to pass multi_encoder_mapping to the forward function as well for the multi-encoder case. I have updated the test case with some comments (that still need to be incorporated) and fixes. I also corrected the x = (x + x_) * x_mask.unsqueeze(-1) bug, which I believe was introduced when the mask convention was flipped.

FYI, I tested training and inference with this new transformer code in t5tts locally, and it seems to be working fine. Also, for a fixed set of weights, the transformer implementations in experimentalt5tts and this branch give the same output, so I think we should be good.
