feat: Added support of Mutual Channel Loss (#100)
* feat: Added implementation of Mutual Channel Loss

* docs: Updated documentation

* test: Updated unittests

* docs: Updated readme

* style: Fixed lint

* test: Expanded MC Loss unittest
frgfm authored Oct 26, 2020
1 parent bcc3ea1 commit d41610b
Showing 5 changed files with 146 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -57,7 +57,7 @@ conda install -c frgfm pylocron
##### Main features

- Activation: [SiLU/Swish](https://arxiv.org/abs/1606.08415), [Mish](https://arxiv.org/abs/1908.08681), [HardMish](https://github.com/digantamisra98/H-Mish), [NLReLU](https://arxiv.org/abs/1908.03682), [FReLU](https://arxiv.org/abs/2007.11824)
- Loss: [Focal Loss](https://arxiv.org/abs/1708.02002), MultiLabelCrossEntropy, [LabelSmoothingCrossEntropy](https://arxiv.org/pdf/1706.03762.pdf), [MixupLoss](https://arxiv.org/pdf/1710.09412.pdf), [ClassBalancedWrapper](https://arxiv.org/abs/1901.05555), [ComplementCrossEntropy](https://arxiv.org/abs/2009.02189)
- Loss: [Focal Loss](https://arxiv.org/abs/1708.02002), MultiLabelCrossEntropy, [LabelSmoothingCrossEntropy](https://arxiv.org/pdf/1706.03762.pdf), [MixupLoss](https://arxiv.org/pdf/1710.09412.pdf), [ClassBalancedWrapper](https://arxiv.org/abs/1901.05555), [ComplementCrossEntropy](https://arxiv.org/abs/2009.02189), [MutualChannelLoss](https://arxiv.org/abs/2002.04264)
- Convolutions: [NormConv2d](https://arxiv.org/pdf/2005.05274v2.pdf), [Add2d](https://arxiv.org/pdf/1912.13200.pdf), [SlimConv2d](https://arxiv.org/pdf/2003.07469.pdf), [PyConv2d](https://arxiv.org/abs/2006.11538)
- Regularization: [DropBlock](https://arxiv.org/abs/1810.12890)
- Pooling: [BlurPool2d](https://arxiv.org/abs/1904.11486), [SPP](https://arxiv.org/abs/1406.4729)
3 changes: 3 additions & 0 deletions docs/source/nn.rst
@@ -30,6 +30,9 @@ Loss functions

.. autoclass:: ComplementCrossEntropy

.. autoclass:: MutualChannelLoss


Loss wrappers
--------------

69 changes: 67 additions & 2 deletions holocron/nn/functional.py
@@ -1,12 +1,12 @@
from math import floor
from math import floor, ceil
import torch
from torch import Tensor
import torch.nn.functional as F
from typing import Optional, Callable, Union, Tuple, List


__all__ = ['silu', 'mish', 'hard_mish', 'nl_relu', 'focal_loss', 'multilabel_cross_entropy', 'ls_cross_entropy',
           'complement_cross_entropy', 'norm_conv2d', 'add2d', 'dropblock2d']
           'complement_cross_entropy', 'mutual_channel_loss', 'norm_conv2d', 'add2d', 'dropblock2d']


def silu(x: Tensor) -> Tensor:
@@ -315,6 +315,71 @@ def complement_cross_entropy(
    return F.cross_entropy(x, target, weight, ignore_index=ignore_index, reduction=reduction) + gamma * loss


def mutual_channel_loss(
    x: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    ignore_index: int = -100,
    reduction: str = 'mean',
    chi: int = 2,
    alpha: float = 1.
) -> Tensor:
    """Implements the mutual channel loss from
    `"The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification"
    <https://arxiv.org/pdf/2002.04264.pdf>`_.

    Args:
        x (torch.Tensor[N, K, ...]): input tensor
        target (torch.Tensor[N, ...]): target tensor
        weight (torch.Tensor[K], optional): manual rescaling of each class
        ignore_index (int, optional): specifies a target value that is ignored and does not contribute to the gradient
        reduction (str, optional): reduction method
        chi (int, optional): number of features per class
        alpha (float, optional): diversity factor

    Returns:
        torch.Tensor: loss reduced with `reduction` method
    """

    # Flatten the spatial dimensions and group channels by class
    b, c = x.shape[:2]
    spatial_dims = x.shape[2:]
    cnum = c // chi
    x = x.view(b, cnum, chi, -1)

    # CWA (channel-wise attention): randomly keep ceil(chi / 2) of the chi feature channels of each class
    base_mask = torch.zeros(chi, device=x.device)
    base_mask[:ceil(chi / 2)] = 1
    chan_mask = torch.zeros((cnum, chi), device=x.device)
    for idx in range(cnum):
        chan_mask[idx] = base_mask[torch.randperm(chi)]
    discr_out = x * chan_mask.view(1, cnum, chi, 1)
    # CCMP (cross-channel max pooling)
    discr_out = discr_out.max(dim=2).values
    discr_out = discr_out.view(b, cnum, *spatial_dims)
    # Weight casting
    if isinstance(weight, torch.Tensor) and weight.type() != x.data.type():
        weight = weight.type_as(x.data)

    # Discriminality component
    discr_loss = F.cross_entropy(discr_out, target, weight, ignore_index=ignore_index, reduction=reduction)

    # Softmax over the flattened spatial dimension
    div_out = F.softmax(x, dim=-1)
    # CCMP (cross-channel max pooling)
    div_out = div_out.max(dim=2).values

    # Diversity component
    diversity_loss = div_out.mean(dim=1)

    if reduction == 'sum':
        diversity_loss = diversity_loss.sum()
    elif reduction == 'mean':
        diversity_loss = diversity_loss.mean()
    else:
        diversity_loss = diversity_loss.view(b, *spatial_dims)

    return discr_loss - alpha * diversity_loss


def _xcorr2d(
    fn: Callable[[Tensor, Tensor], Tensor],
    x: Tensor,
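
For reference, here is a minimal usage sketch of the functional form added above (not part of the commit; the batch size, shapes and the `F` import alias are illustrative assumptions):

```python
import torch
from holocron.nn import functional as F

num_classes, chi = 10, 2
# The input is expected to carry chi * num_classes channels, grouped as chi features per class
logits = torch.rand(4, chi * num_classes, requires_grad=True)
target = torch.randint(0, num_classes, (4,))

loss = F.mutual_channel_loss(logits, target, chi=chi, alpha=1.)
loss.backward()
```
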
38 changes: 35 additions & 3 deletions holocron/nn/modules/loss.py
@@ -5,7 +5,7 @@
from .. import functional as F

__all__ = ['FocalLoss', 'MultiLabelCrossEntropy', 'LabelSmoothingCrossEntropy', 'ComplementCrossEntropy',
           'MixupLoss', 'ClassBalancedWrapper']
           'MixupLoss', 'ClassBalancedWrapper', 'MutualChannelLoss']


class _Loss(nn.Module):
@@ -79,8 +79,8 @@ class MultiLabelCrossEntropy(_Loss):
        reduction (str, optional): type of reduction to apply to the final loss
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, target: Tensor) -> Tensor:
        return F.multilabel_cross_entropy(x, target, self.weight, self.ignore_index, self.reduction)
@@ -194,3 +194,35 @@ def forward(self, x: Tensor, target: Tensor) -> Tensor:

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.criterion.__repr__()}, beta={self.beta})"


class MutualChannelLoss(_Loss):
    """Implements the mutual channel loss from
    `"The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification"
    <https://arxiv.org/pdf/2002.04264.pdf>`_.

    Args:
        weight (torch.Tensor[K], optional): class weight for loss computation
        ignore_index (int, optional): specifies a target value that is ignored and does not contribute to the gradient
        reduction (str, optional): type of reduction to apply to the final loss
        chi (int, optional): number of features per class
        alpha (float, optional): diversity factor
    """

    def __init__(
        self,
        weight: Optional[Union[float, List[float], Tensor]] = None,
        ignore_index: int = -100,
        reduction: str = 'mean',
        chi: int = 2,
        alpha: float = 1.,
    ) -> None:
        super().__init__(weight, ignore_index, reduction)
        self.chi = chi
        self.alpha = alpha

    def forward(self, x: Tensor, target: Tensor) -> Tensor:
        return F.mutual_channel_loss(x, target, self.weight, self.ignore_index, self.reduction, self.chi, self.alpha)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(reduction='{self.reduction}', chi={self.chi}, alpha={self.alpha})"
44 changes: 40 additions & 4 deletions test/test_nn.py
@@ -30,21 +30,22 @@ def _test_activation_function(self, name, input_shape):
        if kwargs.get('inplace', False):
            self.assertEqual(x.data_ptr(), out.data_ptr())

    def _test_loss_function(self, name, same_loss=0, multi_label=False):
    def _test_loss_function(self, name, same_loss=0., multi_label=False):

        num_batches = 2
        num_classes = 4
        # 4 classes
        x = torch.ones(num_batches, num_classes, requires_grad=True)
        x[:, 0, ...] = 10

        loss_fn = F.__dict__[name]

        # Identical target
        if multi_label:
            target = torch.zeros_like(x)
            target[:, 0] = 1.
        else:
            target = torch.zeros(num_batches, dtype=torch.long)
        loss_fn = F.__dict__[name]
        self.assertAlmostEqual(loss_fn(x, target).item(), same_loss, places=3)
        self.assertTrue(torch.allclose(loss_fn(x, target, reduction='none'),
                                       same_loss * torch.ones(num_batches, dtype=x.dtype),
@@ -130,6 +131,40 @@ def test_multilabel_cross_entropy(self):
        self.assertAlmostEqual(F.multilabel_cross_entropy(x, target).item(),
                               nn.functional.cross_entropy(x, target.argmax(dim=1)).item(), places=5)

    def test_mc_loss(self):

        num_batches = 2
        num_classes = 4
        chi = 2
        # 4 classes, chi features per class
        x = torch.ones(num_batches, chi * num_classes)
        x[:, 0, ...] = 10
        target = torch.zeros(num_batches, dtype=torch.long)

        mod = nn.Linear(chi * num_classes, chi * num_classes)

        # Check backprop
        for reduction in ['mean', 'sum', 'none']:
            for p in mod.parameters():
                p.grad = None
            train_loss = F.mutual_channel_loss(mod(x), target, ignore_index=0, reduction=reduction)
            if reduction == 'none':
                self.assertEqual(train_loss.shape, (num_batches,))
                train_loss = train_loss.sum()
            train_loss.backward()
            self.assertIsInstance(mod.weight.grad, torch.Tensor)

        # Check type casting of weights
        for p in mod.parameters():
            p.grad = None
        class_weights = torch.ones(num_classes, dtype=torch.float16)
        ignore_index = 0

        criterion = loss.MutualChannelLoss(weight=class_weights, ignore_index=ignore_index, chi=chi)
        train_loss = criterion(mod(x), target)
        train_loss.backward()
        self.assertIsInstance(mod.weight.grad, torch.Tensor)

    def _test_activation_module(self, name, input_shape):
        module = activation.__dict__[name]

@@ -156,11 +191,12 @@ def _test_loss_module(self, name, fn_name, multi_label=False):

        num_batches = 2
        num_classes = 4
        x = torch.rand(num_batches, num_classes, 20, 20)
        x_class_factor = 2 if fn_name == 'mutual_channel_loss' else 1
        x = torch.rand(num_batches, x_class_factor * num_classes, 20, 20)

        # Identical target
        if multi_label:
            target = torch.rand(x.shape)
            target = torch.rand(num_batches, num_classes, 20, 20)
        else:
            target = (num_classes * torch.rand(num_batches, 20, 20)).to(torch.long)
