Skip to content

Commit

Permalink
Add alpha parameter to DiceFocalLoss (Project-MONAI#7841)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#7682.

### Description

This PR introduces the `alpha` parameter from `FocalLoss` into
`DiceFocalLoss`.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Kyle Harrington <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
4 people authored Jun 29, 2024
1 parent 06cbd70 commit 2f62b81
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
13 changes: 10 additions & 3 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ class DiceFocalLoss(_Loss):
The details of Focal Loss is shown in ``monai.losses.FocalLoss``.
``gamma`` and ``lambda_focal`` are only used for the focal loss.
``include_background``, ``weight`` and ``reduction`` are used for both losses
``include_background``, ``weight``, ``reduction``, and ``alpha`` are used for both losses,
and other parameters are only used for dice loss.
"""
Expand All @@ -837,6 +837,7 @@ def __init__(
weight: Sequence[float] | float | int | torch.Tensor | None = None,
lambda_dice: float = 1.0,
lambda_focal: float = 1.0,
alpha: float | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -871,7 +872,8 @@ def __init__(
Defaults to 1.0.
lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
Defaults to 1.0.
alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in
[0, 1]. Defaults to None.
"""
super().__init__()
weight = focal_weight if focal_weight is not None else weight
Expand All @@ -890,7 +892,12 @@ def __init__(
weight=weight,
)
self.focal = FocalLoss(
include_background=include_background, to_onehot_y=False, gamma=gamma, weight=weight, reduction=reduction
include_background=include_background,
to_onehot_y=False,
gamma=gamma,
weight=weight,
alpha=alpha,
reduction=reduction,
)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
Expand Down
29 changes: 29 additions & 0 deletions tests/test_dice_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,35 @@ def test_script(self):
test_input = torch.ones(2, 1, 8, 8)
test_script_save(loss, test_input, test_input)

@parameterized.expand(
[
("sum_None_0.5_0.25", "sum", None, 0.5, 0.25),
("sum_weight_0.5_0.25", "sum", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
("sum_weight_tuple_0.5_0.25", "sum", (3, 2.0, 1), 0.5, 0.25),
("mean_None_0.5_0.25", "mean", None, 0.5, 0.25),
("mean_weight_0.5_0.25", "mean", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
("mean_weight_tuple_0.5_0.25", "mean", (3, 2.0, 1), 0.5, 0.25),
("none_None_0.5_0.25", "none", None, 0.5, 0.25),
("none_weight_0.5_0.25", "none", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
("none_weight_tuple_0.5_0.25", "none", (3, 2.0, 1), 0.5, 0.25),
]
)
def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha):
size = [3, 3, 5, 5]
label = torch.randint(low=0, high=2, size=size)
pred = torch.randn(size)

common_params = {"include_background": True, "to_onehot_y": False, "reduction": reduction, "weight": weight}

dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params)
dice = DiceLoss(**common_params)
focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params)

result = dice_focal(pred, label)
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)

np.testing.assert_allclose(result, expected_val, err_msg=f"Failed on case: {name}")


if __name__ == "__main__":
    # Allow running this test module directly, e.g. `python tests/test_dice_focal_loss.py`.
    unittest.main()

0 comments on commit 2f62b81

Please sign in to comment.