Skip to content

Commit e9d6a90

Browse files
committed
more efficient implementation
1 parent b7d9eb6 commit e9d6a90

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

src/pytorch_metric_learning/losses/generic_pair_loss.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,4 @@ def pair_based_loss(self, mat, indices_tuple):
4242

4343
@staticmethod
4444
def _assert_either_pos_or_neg(pos_mask, neg_mask):
45-
pos_indices = set(pos_mask.flatten().nonzero().flatten().tolist())
46-
neg_indices = set(neg_mask.flatten().nonzero().flatten().tolist())
47-
assert (
48-
pos_indices.isdisjoint(neg_indices)
49-
), "Each pair should be either be positive or negative"
45+
assert not torch.any((pos_mask != 0) & (neg_mask != 0)), "Each pair should be either be positive or negative"

0 commit comments

Comments
 (0)