Similarity matrix shape does not match the shape of the mask #60

Open
hugofigueiras opened this issue Feb 16, 2023 · 2 comments

@hugofigueiras

Hello,

I was testing the implementation when an error occurred: The shape of the mask [512, 512] at index 0 does not match the shape of the indexed tensor [2, 2] at index 0.
My batch size is 256.

The error occurs in this part of the code:
similarity_matrix = torch.matmul(features, features.T)
mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
labels = labels[~mask].view(labels.shape[0], -1)
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

I'm wondering whether this is something I'm doing wrong, and how I can make the tensor shapes match.

Thanks in advance!

@laiyingxin2

same question~

@sarda-devesh

sarda-devesh commented Jan 18, 2024

The issue is caused by the following line of code:

labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)

which assumes that every batch contains exactly batch_size samples, but this is not always true: the final batch is smaller whenever the dataset size is not a multiple of the batch size.

For example, consider a dataset with only 100 elements and a batch size of 256. In that case the labels matrix ends up with shape (512, 512) even though it should only be (200, 200) (assuming you are using n_views of 2). The way to resolve this is by updating the above line to:

labels = torch.cat([torch.arange(features.size(0) // self.args.n_views) for i in range(self.args.n_views)], dim=0)
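
To check the shapes end to end, here is a minimal sketch of the idea, not the repository's actual code. The helper names `make_labels` and `drop_diagonal` are hypothetical, the feature dimension of 128 is arbitrary, and everything runs on CPU. It assumes `features` stacks the `n_views` augmented copies of the same samples one view after another, as the `labels` line above implies.

```python
import torch


def make_labels(features: torch.Tensor, n_views: int) -> torch.Tensor:
    """Build the pairwise label matrix from the actual batch, not the configured batch size."""
    # Assumption: features holds n_views stacked copies of the same samples.
    samples_in_batch = features.size(0) // n_views
    labels = torch.cat([torch.arange(samples_in_batch) for _ in range(n_views)], dim=0)
    # Entry (i, j) is 1.0 when rows i and j come from the same source sample.
    return (labels.unsqueeze(0) == labels.unsqueeze(1)).float()


def drop_diagonal(features: torch.Tensor, labels: torch.Tensor):
    """Same masking steps as in the question: remove each row's self-similarity entry."""
    similarity_matrix = torch.matmul(features, features.T)
    mask = torch.eye(labels.shape[0], dtype=torch.bool)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
    return similarity_matrix, labels


n_views = 2
# A short final batch: 100 samples instead of the configured 256.
features = torch.randn(100 * n_views, 128)
labels = make_labels(features, n_views)
similarity_matrix, labels_no_diag = drop_diagonal(features, labels)
print(labels.shape, similarity_matrix.shape, labels_no_diag.shape)
```

Because the labels are sized from `features`, the mask and the similarity matrix always agree, whatever the size of the last batch. An alternative that leaves the labels line untouched is to pass `drop_last=True` to `torch.utils.data.DataLoader`, which discards the incomplete final batch; that only helps when the dataset holds at least one full batch.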
