@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
@@ -102,8 +104,14 @@ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
     def get_logits(self, image_features, text_features, logit_scale):
         if self.world_size > 1:
             all_image_features, all_text_features = gather_features(
-                image_features, text_features,
-                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+                image_features,
+                text_features,
+                local_loss=self.local_loss,
+                gather_with_grad=self.gather_with_grad,
+                rank=self.rank,
+                world_size=self.world_size,
+                use_horovod=self.use_horovod,
+            )

             if self.local_loss:
                 logits_per_image = logit_scale * image_features @ all_text_features.T
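This hunk only switches the gather_features call over to keyword arguments. For context, here is a minimal sketch of what the torch.distributed path of that helper does (an illustration under assumed behavior, not the commit's code; the local_loss and horovod variations are omitted):

import torch
import torch.distributed
import torch.distributed.nn


def gather_features_sketch(image_features, text_features, gather_with_grad=False):
    # Gather per-rank feature shards so each rank can compute logits
    # against the full global batch.
    if gather_with_grad:
        # differentiable all_gather: gradients flow back to every rank
        all_image = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
        all_text = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
    else:
        world_size = torch.distributed.get_world_size()
        gathered_image = [torch.zeros_like(image_features) for _ in range(world_size)]
        gathered_text = [torch.zeros_like(text_features) for _ in range(world_size)]
        torch.distributed.all_gather(gathered_image, image_features)
        torch.distributed.all_gather(gathered_text, text_features)
        all_image = torch.cat(gathered_image, dim=0)
        all_text = torch.cat(gathered_text, dim=0)
    return all_image, all_text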
@@ -158,12 +166,11 @@ def __init__(
         self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

     def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
-
-        clip_loss = torch.tensor(0)
-
         if self.clip_loss_weight:
             clip_loss = super().forward(image_features, text_features, logit_scale)
             clip_loss = self.clip_loss_weight * clip_loss
+        else:
+            clip_loss = torch.tensor(0, device=logits.device)

         caption_loss = self.caption_loss(
             logits.permute(0, 2, 1),
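Two things happen in this hunk: the zero clip_loss is now created on logits.device, and only when the CLIP branch is skipped. The context lines also show the caption loss fed permuted logits, which matters because nn.CrossEntropyLoss expects the class dimension second. A self-contained illustration (the shapes are assumptions, following the usual (batch, time, vocab) decoder layout):

import torch
import torch.nn as nn

B, T, V, pad_id = 2, 5, 11, 0
logits = torch.randn(B, T, V)           # decoder output: (batch, time, vocab)
labels = torch.randint(1, V, (B, T))    # target token ids: (batch, time)
caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
# CrossEntropyLoss wants (batch, classes, time), hence the permute
loss = caption_loss(logits.permute(0, 2, 1), labels)
print(loss.item())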
@@ -316,19 +323,17 @@ class SigLipLoss(nn.Module):
316323 """
317324 def __init__ (
318325 self ,
319- cache_labels = False ,
320- rank = 0 ,
321- world_size = 1 ,
322- use_horovod = False ,
323- impl = 'bidir' ,
326+ cache_labels : bool = False ,
327+ rank : int = 0 ,
328+ world_size : int = 1 ,
329+ dist_impl : Optional [str ] = None ,
324330 ):
325331 super ().__init__ ()
326332 self .cache_labels = cache_labels
327333 self .rank = rank
328334 self .world_size = world_size
329- assert not use_horovod # FIXME need to look at hvd ops for ring transfers
330- self .use_horovod = use_horovod
331- self .impl = impl
335+ self .dist_impl = dist_impl or 'bidir' # default to bidir exchange for now, this will likely change
336+ assert self .dist_impl in ('bidir' , 'shift' , 'reduce' , 'gather' )
332337
333338 # cache state FIXME cache not currently used, worthwhile?
334339 self .prev_num_logits = 0
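A quick usage sketch of the reworked constructor (not part of the commit; it assumes SigLipLoss is importable from open_clip's loss module as in the current package layout): passing dist_impl=None falls back to 'bidir', and any string outside the four supported modes trips the assert.

from open_clip.loss import SigLipLoss

loss_fn = SigLipLoss(rank=0, world_size=1)  # dist_impl=None -> 'bidir'
assert loss_fn.dist_impl == 'bidir'

loss_fn = SigLipLoss(rank=0, world_size=8, dist_impl='gather')

# SigLipLoss(dist_impl='ring')  # AssertionError: not one of the four modes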
@@ -361,7 +366,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
         loss = self._loss(image_features, text_features, logit_scale, logit_bias)

         if self.world_size > 1:
-            if self.impl == 'bidir':
+            if self.dist_impl == 'bidir':
                 right_rank = (self.rank + 1) % self.world_size
                 left_rank = (self.rank - 1 + self.world_size) % self.world_size
                 text_features_to_right = text_features_to_left = text_features
@@ -396,7 +401,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
                         logit_bias,
                         negative_only=True,
                     )
-            elif self.impl == "shift":
+            elif self.dist_impl == "shift":
                 right_rank = (self.rank + 1) % self.world_size
                 left_rank = (self.rank - 1 + self.world_size) % self.world_size
                 text_features_to_right = text_features
@@ -414,7 +419,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
                         negative_only=True,
                     )
                     text_features_to_right = text_features_from_left
-            elif self.impl == "reduce":
+            elif self.dist_impl == "reduce":
                 for i in range(self.world_size):
                     text_from_other = torch.distributed.nn.all_reduce(
                         text_features * (self.rank == i),
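The masking trick in the 'reduce' branch is easy to miss: text_features * (self.rank == i) zeroes the tensor on every rank except rank i, so the SUM all_reduce effectively broadcasts rank i's features to all ranks. A single-process analogue of that arithmetic (illustration only):

import torch

world_size, src = 4, 2
features = [torch.randn(3, 8) for _ in range(world_size)]  # one shard per rank
# what each rank contributes: zeros unless it is rank `src`
masked = [feats * (rank == src) for rank, feats in enumerate(features)]
# all_reduce(SUM) is an elementwise sum over ranks
summed = torch.stack(masked).sum(dim=0)
assert torch.equal(summed, features[src])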
@@ -427,7 +432,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
                         logit_bias,
                         negative_only=True,
                     )
-            elif self.impl == "gather":
+            elif self.dist_impl == "gather":
                 all_text = torch.distributed.nn.all_gather(text_features)
                 for i in range(self.world_size):
                     loss += float(i != self.rank) * self._loss(
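Finally, a single-process smoke test of the loss (a sketch under the same import assumption as above; with world_size=1 none of the distributed branches run, and the scale/bias values are arbitrary placeholders):

import torch
import torch.nn.functional as F
from open_clip.loss import SigLipLoss

loss_fn = SigLipLoss()  # world_size=1, dist_impl defaults to 'bidir'
img = F.normalize(torch.randn(4, 16), dim=-1)
txt = F.normalize(torch.randn(4, 16), dim=-1)
# SigLIP scales and shifts the pairwise similarity logits before the sigmoid
loss = loss_fn(img, txt, logit_scale=torch.tensor(10.0), logit_bias=torch.tensor(-10.0))
print(loss.item())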