Not really an issue, just a solution that requires a lot less memory (roughly 18x less). I think it would be helpful for lots of people, so I'll post it:
Multi-crop training eats a lot of GPU memory because instead of keeping 1 computation graph you end up keeping 18 of them (18 is the n_loss_terms in the code below when n_local_crops = 8: with 2 global crops, that is 2 * (8 + 2) - 2 = 18). So just run every crop separately through the student and backprop with loss.backward() (don't update the weights with optimizer.step() yet; instead, accumulate gradients over all teacher-student view pairs). This computes the gradients for each pair and frees its computation graph before starting the next one. After accumulating the gradients for all pairs, run optimizer.step(). This implementation saves a lot of memory; I was able to use a large batch size and train on a single GPU.
import torch
import torch.nn as nn
import torch.nn.functional as F

class DINOLoss(nn.Module):
    def __init__(self, out_dim=65536, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student, student_feats, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        student_feats is a list of crop tensors with len(student_feats) = n_local_crops + 2.
        epoch is kept for API compatibility (the original DINO schedules teacher_temp by epoch).
        """
        teacher_out = F.softmax((teacher_output - self.center) / self.teacher_temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)  # one chunk per global view
        self.update_center(teacher_output)  # EMA update of the center (see the sketch below)

        # 2 * (n_local_crops + 2) - 2 = 18 terms when n_local_crops = 8
        n_loss_terms = (len(teacher_out) * len(student_feats)) - len(teacher_out)
        total_loss = 0
        for iq, q in enumerate(teacher_out):
            for v, chunk in enumerate(student_feats):
                if iq == v:
                    # skip the case where teacher and student see the same global view
                    continue
                student_output = student(chunk)  # forward pass builds the graph for this crop only
                student_output = student_output / self.student_temp
                loss = torch.sum(-q * F.log_softmax(student_output, dim=-1), dim=-1)
                loss = loss.mean() / n_loss_terms
                loss.backward()  # accumulate grads, then the graph for this pair is freed
                total_loss += loss.detach()  # keep only the value, for printing
        return total_loss
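One detail: the class above calls self.update_center, which isn't reproduced in the snippet. A minimal single-GPU sketch of that method, following the centering update from the original DINO implementation (the official code additionally all-reduces the batch center across GPUs), could look like this:

    @torch.no_grad()
    def update_center(self, teacher_output):
        # EMA update of the center used to stabilize the teacher softmax
        batch_center = teacher_output.mean(dim=0, keepdim=True)  # single-GPU sketch; no dist.all_reduce
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)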
# usage:
dino_loss = DINOLoss()
teacher_feats = torch.cat(student_feats[:2]).clone().detach()  # the 2 global crops
teacher_output = teacher(teacher_feats)  # only the 2 global views pass through the teacher
loss = dino_loss(student, student_feats, teacher_output, epoch)
Note that in the code, student_feats are actually the crop images (they are named feats for another reason).
Hope it helps :)
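For the full picture, here is a rough sketch of what one training step could look like with this loss. The names num_epochs, data_loader, images, student, teacher and optimizer are placeholders (not from the post); the key point is that optimizer.step() runs only after dino_loss has already accumulated the gradients through its internal loss.backward() calls:

    for epoch in range(num_epochs):
        for images in data_loader:  # images: list of (n_local_crops + 2) crop tensors for one batch
            student_feats = [im.cuda(non_blocking=True) for im in images]

            optimizer.zero_grad()
            teacher_feats = torch.cat(student_feats[:2]).clone().detach()
            # assuming the teacher's parameters have requires_grad=False, as in the original DINO code
            teacher_output = teacher(teacher_feats)

            # runs the student crop by crop and calls loss.backward() internally,
            # so gradients are fully accumulated when it returns
            loss = dino_loss(student, student_feats, teacher_output, epoch)
            optimizer.step()
            # (the usual DINO momentum/EMA update of the teacher weights would follow here)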