Skip to content

Commit aeaf2a0

Browse files
committed
Add --lost-dist-impl argument to pick different distributed loss implementations
1 parent 9451ba8 commit aeaf2a0

File tree

3 files changed

+30
-17
lines changed

3 files changed

+30
-17
lines changed

src/open_clip/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,9 @@ def create_loss(args):
448448
return SigLipLoss(
449449
rank=args.rank,
450450
world_size=args.world_size,
451+
dist_impl=args.loss_dist_impl, # siglip has multiple distributed implementations to choose from
451452
)
453+
452454
return ClipLoss(
453455
local_loss=args.local_loss,
454456
gather_with_grad=args.gather_with_grad,

src/open_clip/loss.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import torch
24
import torch.nn as nn
35
from torch.nn import functional as F
@@ -102,8 +104,14 @@ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
102104
def get_logits(self, image_features, text_features, logit_scale):
103105
if self.world_size > 1:
104106
all_image_features, all_text_features = gather_features(
105-
image_features, text_features,
106-
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
107+
image_features,
108+
text_features,
109+
local_loss=self.local_loss,
110+
gather_with_grad=self.gather_with_grad,
111+
rank=self.rank,
112+
world_size=self.world_size,
113+
use_horovod=self.use_horovod,
114+
)
107115

108116
if self.local_loss:
109117
logits_per_image = logit_scale * image_features @ all_text_features.T
@@ -158,12 +166,11 @@ def __init__(
158166
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
159167

160168
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
161-
162-
clip_loss = torch.tensor(0)
163-
164169
if self.clip_loss_weight:
165170
clip_loss = super().forward(image_features, text_features, logit_scale)
166171
clip_loss = self.clip_loss_weight * clip_loss
172+
else:
173+
clip_loss = torch.tensor(0, device=logits.device)
167174

168175
caption_loss = self.caption_loss(
169176
logits.permute(0, 2, 1),
@@ -316,19 +323,17 @@ class SigLipLoss(nn.Module):
316323
"""
317324
def __init__(
318325
self,
319-
cache_labels=False,
320-
rank=0,
321-
world_size=1,
322-
use_horovod=False,
323-
impl='bidir',
326+
cache_labels: bool = False,
327+
rank: int = 0,
328+
world_size: int = 1,
329+
dist_impl: Optional[str] = None,
324330
):
325331
super().__init__()
326332
self.cache_labels = cache_labels
327333
self.rank = rank
328334
self.world_size = world_size
329-
assert not use_horovod # FIXME need to look at hvd ops for ring transfers
330-
self.use_horovod = use_horovod
331-
self.impl = impl
335+
self.dist_impl = dist_impl or 'bidir' # default to bidir exchange for now, this will likely change
336+
assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather')
332337

333338
# cache state FIXME cache not currently used, worthwhile?
334339
self.prev_num_logits = 0
@@ -361,7 +366,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
361366
loss = self._loss(image_features, text_features, logit_scale, logit_bias)
362367

363368
if self.world_size > 1:
364-
if self.impl == 'bidir':
369+
if self.dist_impl == 'bidir':
365370
right_rank = (self.rank + 1) % self.world_size
366371
left_rank = (self.rank - 1 + self.world_size) % self.world_size
367372
text_features_to_right = text_features_to_left = text_features
@@ -396,7 +401,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
396401
logit_bias,
397402
negative_only=True,
398403
)
399-
elif self.impl == "shift":
404+
elif self.dist_impl == "shift":
400405
right_rank = (self.rank + 1) % self.world_size
401406
left_rank = (self.rank - 1 + self.world_size) % self.world_size
402407
text_features_to_right = text_features
@@ -414,7 +419,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
414419
negative_only=True,
415420
)
416421
text_features_to_right = text_features_from_left
417-
elif self.impl == "reduce":
422+
elif self.dist_impl == "reduce":
418423
for i in range(self.world_size):
419424
text_from_other = torch.distributed.nn.all_reduce(
420425
text_features * (self.rank == i),
@@ -427,7 +432,7 @@ def forward(self, image_features, text_features, logit_scale, logit_bias, output
427432
logit_bias,
428433
negative_only=True,
429434
)
430-
elif self.impl == "gather":
435+
elif self.dist_impl == "gather":
431436
all_text = torch.distributed.nn.all_gather(text_features)
432437
for i in range(self.world_size):
433438
loss += float(i != self.rank) * self._loss(

src/open_clip_train/params.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,12 @@ def parse_args(args):
469469
action="store_true",
470470
help='Use SigLip (sigmoid) loss.'
471471
)
472+
parser.add_argument(
473+
"--loss-dist-impl",
474+
default=None,
475+
type=str,
476+
help='A string to specify a specific distributed loss implementation.'
477+
)
472478

473479
args = parser.parse_args(args)
474480

0 commit comments

Comments
 (0)