diff --git a/pytorch_toolbelt/losses/functional.py b/pytorch_toolbelt/losses/functional.py index 1ef62f055..daae4a052 100644 --- a/pytorch_toolbelt/losses/functional.py +++ b/pytorch_toolbelt/losses/functional.py @@ -70,7 +70,7 @@ def focal_loss_with_logits( if reduced_threshold is None: focal_term = (1.0 - pt).pow(gamma) else: - focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma) + focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow(gamma) #the focal term continuity breaks when reduced_threshold not equal to 0.5. At pt equal to reduced_threshold, the value of piecewise function of focal term should be 1 from both sides . focal_term = torch.masked_fill(focal_term, pt < reduced_threshold, 1) loss = focal_term * ce_loss