Skip to content

flip the mask in BestRQ #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions i6_models/parts/best_rq/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def forward(
) -> Tuple[torch.Tensor, torch.Tensor]:
ndim_batch, ndim_time, _ = tensor.size()

mask = torch.zeros((ndim_batch, ndim_time), dtype=torch.bool)
mask = torch.ones((ndim_batch, ndim_time), dtype=torch.bool)

mask_idcs = []
for i in range(ndim_batch):
Expand All @@ -67,7 +67,7 @@ def forward(
mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)

for j in mask_idc:
mask[i, j : j + self.mask_length] = True
mask[i, j : j + self.mask_length] = False

tensor[mask] = self.mask_emb.to(tensor.device)

Expand Down
4 changes: 2 additions & 2 deletions i6_models/parts/best_rq/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class RandomProjectionQuantizer(nn.Module):
"""
implement the fixed random projection quantizer from BestRQ
C.f. https://arxiv.org/pdf/2202.01855 for theoretic background
code adapted from https://github.com/speechbrain/speechbrain/blob/16b6420d4ff23210cfca2e888be8853264e0cb17/speechbrain/nnet/quantisers.py#L127
code adapted from https://github.com/speechbrain/speechbrain/blob/7edb1397d8f92bb4fcaf17eb08e366e5b639ae88/speechbrain/nnet/quantisers.py#L127
"""

def __init__(self, input_dim, codebook_dim, codebook_num_vars):
Expand All @@ -33,5 +33,5 @@ def __init__(self, input_dim, codebook_dim, codebook_num_vars):
self.register_buffer("CB", F.normalize(torch.randn(codebook_num_vars, codebook_dim)))

def forward(self, x: torch.tensor) -> torch.tensor:
x = F.normalize(x @ self.P)
x = F.normalize(x @ self.P, dim=2)
return vector_norm((self.CB.unsqueeze(1) - x.unsqueeze(1)), dim=-1).argmin(dim=1)