diff --git a/i6_models/parts/best_rq/mask.py b/i6_models/parts/best_rq/mask.py
index 2e88479b..62b09ae3 100644
--- a/i6_models/parts/best_rq/mask.py
+++ b/i6_models/parts/best_rq/mask.py
@@ -45,7 +45,7 @@ def forward(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         ndim_batch, ndim_time, _ = tensor.size()
 
-        mask = torch.zeros((ndim_batch, ndim_time), dtype=torch.bool)
+        mask = torch.ones((ndim_batch, ndim_time), dtype=torch.bool)
 
         mask_idcs = []
         for i in range(ndim_batch):
@@ -67,7 +67,7 @@ def forward(
             mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)
 
             for j in mask_idc:
-                mask[i, j : j + self.mask_length] = True
+                mask[i, j : j + self.mask_length] = False
 
         tensor[mask] = self.mask_emb.to(tensor.device)
diff --git a/i6_models/parts/best_rq/quantizer.py b/i6_models/parts/best_rq/quantizer.py
index 8633eb42..ce4b77d7 100644
--- a/i6_models/parts/best_rq/quantizer.py
+++ b/i6_models/parts/best_rq/quantizer.py
@@ -12,7 +12,7 @@ class RandomProjectionQuantizer(nn.Module):
     """
     implement the fixed random projection quantizer from BestRQ
     C.f. https://arxiv.org/pdf/2202.01855 for theoretic background
-    code adapted from https://github.com/speechbrain/speechbrain/blob/16b6420d4ff23210cfca2e888be8853264e0cb17/speechbrain/nnet/quantisers.py#L127
+    code adapted from https://github.com/speechbrain/speechbrain/blob/7edb1397d8f92bb4fcaf17eb08e366e5b639ae88/speechbrain/nnet/quantisers.py#L127
     """
 
     def __init__(self, input_dim, codebook_dim, codebook_num_vars):
@@ -33,5 +33,5 @@ def __init__(self, input_dim, codebook_dim, codebook_num_vars):
         self.register_buffer("CB", F.normalize(torch.randn(codebook_num_vars, codebook_dim)))
 
     def forward(self, x: torch.tensor) -> torch.tensor:
-        x = F.normalize(x @ self.P)
+        x = F.normalize(x @ self.P, dim=2)
         return vector_norm((self.CB.unsqueeze(1) - x.unsqueeze(1)), dim=-1).argmin(dim=1)
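For reference, a minimal standalone sketch of the masking logic as it reads after this change (all sizes and the zero mask embedding are made up for illustration; the real module derives `num_mask` from its mask probability and length settings, which are not visible in this hunk). It demonstrates the flipped polarity: the boolean mask now starts out all True, the sampled spans are set to False, and `tensor[mask]` therefore writes the embedding to the frames where the mask remained True.

```python
import numpy as np
import torch

# Toy sizes, assumed for illustration only.
batch, time, feat_dim = 2, 10, 4
mask_length, num_mask = 3, 2

tensor = torch.randn(batch, time, feat_dim)
mask_emb = torch.zeros(feat_dim)  # stand-in for self.mask_emb

# New polarity: mask starts all True; sampled spans are flipped to False.
mask = torch.ones((batch, time), dtype=torch.bool)
for i in range(batch):
    mask_idc = np.random.choice(time - mask_length, num_mask, replace=False)
    for j in mask_idc:
        mask[i, j : j + mask_length] = False

# Boolean indexing replaces exactly the frames where mask is True,
# i.e. the frames outside the sampled spans under the new polarity.
tensor[mask] = mask_emb
```

Note the consequence for consumers of the returned mask: downstream code that selected the sampled spans with `mask` before this change selects them with `~mask` afterwards.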
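The `dim=2` fix in the quantizer matters because `F.normalize` defaults to `dim=1`, which for a `(batch, time, codebook_dim)` projection normalizes along the time axis instead of normalizing each frame's projected vector. A small self-contained sketch of the forward pass showing the difference (toy sizes; the projection init is a placeholder, since `self.P`'s initialization is not shown in this hunk):

```python
import torch
import torch.nn.functional as F

# Toy sizes, assumed for illustration only.
batch, time, input_dim, codebook_dim, codebook_num_vars = 2, 5, 8, 4, 16

x = torch.randn(batch, time, input_dim)
P = torch.empty(input_dim, codebook_dim)
torch.nn.init.xavier_uniform_(P)  # stand-in for the frozen random projection
CB = F.normalize(torch.randn(codebook_num_vars, codebook_dim))  # unit-norm codebook rows

proj = x @ P                      # (batch, time, codebook_dim)
wrong = F.normalize(proj)         # default dim=1: unit norm along *time*, per feature
fixed = F.normalize(proj, dim=2)  # unit norm per frame vector, as the fix intends

print(fixed.norm(dim=2))  # all ones: each frame lies on the unit sphere
print(wrong.norm(dim=2))  # generally not ones

# Nearest-codebook lookup as in the module:
# distances (batch, vars, time) -> indices (batch, time)
idx = torch.linalg.vector_norm(CB.unsqueeze(1) - fixed.unsqueeze(1), dim=-1).argmin(dim=1)
print(idx.shape)  # torch.Size([2, 5])
```

With both the projected frames and the codebook rows on the unit sphere, the Euclidean nearest neighbour is equivalent to a cosine-similarity argmax, which is the behaviour the BestRQ paper describes for its random-projection quantizer.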