
k-NN clusters as labels or pairs? #741

Open
stanleyjs opened this issue Feb 7, 2025 · 1 comment

Comments

@stanleyjs

Hello Kevin,
Thank you for your work on pytorch_metric_learning. I'm looking to reimplement some of my own experiments using your package. I'm not sure of the best way to go about this, so I'm wondering if you could give some pointers.

We are training a contrastive model similar to SpectralNet (https://github.com/shaham-lab/SpectralNet/tree/main). We're doing domain adaptation in which points come in pairs, one from each domain, and, critically, we have the ground-truth k-nearest neighbors for one domain. That is, for a pair of corresponding points (x_i, y_i), we know x_i's k nearest neighbors kNN(x_i). We would like to learn two embedding functions g_X and g_Y such that kNN(g_X(x_i)) and kNN(g_Y(y_i)) approximate kNN(x_i).

The way I do this right now is to take a batch of paired xs and ys, and for each point x_i (which is now an anchor) I look up its pre-computed kNN(x_i). Those neighbors are the positive examples; the negative examples are simply the rest of the batch (for each x_i). This works, but it is slow and perhaps suboptimal.

I think the current strategy I just described can be done just using the indices_tuple argument to many of the losses in this library.
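For reference, here is one way the strategy described above could be turned into an explicit indices_tuple. This is a sketch, not the library's own code: it assumes the (a1, p, a2, n) pair layout that pair-based losses in pytorch_metric_learning accept, and the `knn` dict mapping each in-batch index to its ground-truth neighbors is hypothetical. The lists would be converted to tensors (e.g. with `torch.tensor`) before being passed to a loss.

```python
def build_pair_indices(knn, batch_size):
    """Build (a1, p, a2, n) pair index lists from precomputed kNN.

    knn: dict mapping each in-batch index i to a list of the in-batch
    indices of its ground-truth nearest neighbors, kNN(x_i).
    Returns four parallel lists: anchor/positive pairs come from the
    precomputed neighbors, anchor/negative pairs from the rest of the
    batch. (Assumption: the (a1, p, a2, n) convention used by
    pair-based losses in pytorch_metric_learning.)
    """
    a1, p, a2, n = [], [], [], []
    for i in range(batch_size):
        pos = set(knn[i])
        # Every precomputed neighbor of i is a positive for anchor i.
        for j in knn[i]:
            a1.append(i)
            p.append(j)
        # Every other point in the batch is a negative for anchor i.
        for j in range(batch_size):
            if j != i and j not in pos:
                a2.append(i)
                n.append(j)
    return a1, p, a2, n
```

In this scheme no labels are needed at all: the pair structure carries all the supervision, which sidesteps the multi-label problem described below.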

However, I would like to implement this using a miner. What I'm unclear about is how to define labels in this context. Essentially, every point is its own class centroid (as defined by kNN), but it also belongs to up to batch_size - 1 other classes. Of course, k is taken to be small, so in practice the overlap is limited, but the point is that this is a multi-label situation.

Do you have any recommendations for how to approach this?

@KevinMusgrave
Owner

KevinMusgrave commented Feb 14, 2025

Apologies for the late response.

I can't think of a way to construct the labels for the miner. As you say, it's a multi-label situation.

Are you already able to construct the indices_tuple? If so, is there something you want to do with the miner that you can't do with the indices_tuple you already have?

If you want to filter or weight the loss for specific pairs, maybe you could write a custom reducer, or pass reducer=DoNothingReducer() to the loss function and filter/weight each pair's loss before computing the average.
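To make that last suggestion concrete, here is a minimal sketch of the filter/weight step, assuming the unreduced per-pair losses have already been extracted from the output that DoNothingReducer leaves unreduced. Plain Python lists stand in for tensors so the sketch stays self-contained, and the weighting scheme itself (e.g. weighting by a pair's rank in the ground-truth kNN graph) is hypothetical:

```python
def weighted_mean_loss(per_pair_losses, weights):
    """Combine unreduced per-pair losses with custom weights.

    per_pair_losses: the per-pair loss values that a loss function
    returns when reducer=DoNothingReducer() is used (lists here;
    in practice these would be tensors).
    weights: one weight per pair; a weight of 0 filters the pair out.
    Returns the weighted mean of the kept pairs.
    """
    kept = [(l, w) for l, w in zip(per_pair_losses, weights) if w > 0]
    if not kept:
        return 0.0
    total_weight = sum(w for _, w in kept)
    return sum(l * w for l, w in kept) / total_weight
```

The same pattern works inside a custom reducer; the only difference is where in the pipeline the weighting happens.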
