a solution to solve memory issues (but slows down training a bit) #269

Open
fawazsammani opened this issue Dec 27, 2023 · 0 comments

fawazsammani commented Dec 27, 2023

Not really an issue, just a solution that requires a lot less memory (18x less). I think it would be helpful for lots of people, so I'll post it:

Multi-crop training eats a lot of GPU memory because instead of keeping 1 computation graph, you end up keeping 18 (18 is n_loss_terms in the code below when n_local_crops = 8). So just run every crop separately through the student and backprop with loss.backward() right away (don't update the weights with optimizer.step() yet; instead accumulate gradients over all global-local pairs). This computes the gradients for each global-local pair and frees its computation graph before starting the next pair. After accumulating the gradients for all pairs, run optimizer.step(). This implementation saves a lot of memory; I was able to use a large batch size and train on a single GPU.

import torch
import torch.nn as nn
import torch.nn.functional as F


class DINOLoss(nn.Module):
    def __init__(self, out_dim=65536, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student, student_feats, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        student_feats is a list of crop tensors with len(student_feats) = n_local_crops + 2.
        The student is called inside the loss so that each crop's graph can be freed
        right after its backward pass. (epoch is unused here; the teacher temperature
        warmup of the original code is omitted.)
        """
        teacher_out = F.softmax((teacher_output - self.center) / self.teacher_temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)   # one chunk per global view
        self.update_center(teacher_output)
        # all (teacher global view, student crop) pairs except the two matching global views
        n_loss_terms = (len(teacher_out) * len(student_feats)) - len(teacher_out)
        total_loss = 0

        for iq, q in enumerate(teacher_out):
            for v, crop in enumerate(student_feats):
                if iq == v:
                    continue   # skip pairs where teacher and student see the same global view
                student_output = student(crop)        # forward: builds the graph for this crop only
                student_output = student_output / self.student_temp
                loss = torch.sum(-q * F.log_softmax(student_output, dim=-1), dim=-1)
                loss = loss.mean() / n_loss_terms
                loss.backward()                       # accumulate grads, then free this graph
                total_loss += loss.detach()           # detached running total, for printing only

        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        # EMA update of the center used to avoid collapse (single-GPU version,
        # i.e. without the all_reduce of the original multi-GPU DINO code)
        batch_center = teacher_output.mean(dim=0, keepdim=True)
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)

dino_loss = DINOLoss()
teacher_feats = torch.cat(student_feats[:2]).clone().detach() 
teacher_output = teacher(teacher_feats)  # only the 2 global views pass through the teacher
loss = dino_loss(student, student_feats, teacher_output, epoch)

Note that in the code, student_feats are the images (they are named "feats" for another reason).
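
For completeness, here is a rough sketch of how this sits inside a full training step; it is only a sketch, and data_loader, optimizer, num_epochs and the EMA momentum value are assumptions on my side, not part of the snippet above:

for epoch in range(num_epochs):
    for images, _ in data_loader:                            # images: list of 2 global + n local crops
        student_feats = [im.cuda(non_blocking=True) for im in images]

        with torch.no_grad():                                # only the 2 global views go through the teacher
            teacher_output = teacher(torch.cat(student_feats[:2]))

        optimizer.zero_grad()
        # per-pair forward/backward happens inside the loss, so gradients accumulate here
        loss = dino_loss(student, student_feats, teacher_output, epoch)
        optimizer.step()                                     # single update after all pairs

        with torch.no_grad():                                # standard DINO EMA update of the teacher
            m = 0.996                                        # momentum (scheduled per-iteration in the original code)
            for ps, pt in zip(student.parameters(), teacher.parameters()):
                pt.data.mul_(m).add_(ps.detach().data, alpha=1 - m)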
Hope it helps :)

fawazsammani changed the title a solution to solve memory issues (but slows down training) a solution to solve memory issues (slows down training a tiny bit) Dec 27, 2023
fawazsammani changed the title a solution to solve memory issues (slows down training a tiny bit) a solution to solve memory issues (slows down training a bit) Dec 27, 2023
fawazsammani changed the title a solution to solve memory issues (slows down training a bit) a solution to solve memory issues (but slows down training a bit) Dec 27, 2023
fawazsammani reopened this Feb 14, 2024