Skip to content

Commit

Permalink
Black
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Nov 21, 2024
1 parent 29dc3ea commit 66e48d7
Show file tree
Hide file tree
Showing 18 changed files with 27 additions and 28 deletions.
6 changes: 3 additions & 3 deletions pytorch_toolbelt/inference/ensembling.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ def __init__(self, models: List[nn.Module], reduction: str = "mean", outputs: Op
def forward(self, *input, **kwargs): # skipcq: PYL-W0221
outputs = [model(*input, **kwargs) for model in self.models]
output_is_dict = isinstance(outputs[0], dict)
output_is_list = isinstance(outputs[0], (list, tuple))
output_is_list = isinstance(outputs[0], (list, tuple)) # noqa

if self.return_some_outputs:
keys = self.outputs
elif isinstance(outputs[0], dict):
elif output_is_dict:
keys = outputs[0].keys()
elif isinstance(outputs[0], (list, tuple)):
elif output_is_list:
keys = list(range(len(outputs[0])))
elif torch.is_tensor(outputs[0]):
keys = None
Expand Down
1 change: 1 addition & 0 deletions pytorch_toolbelt/inference/tiles.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Implementation of tile-based inference allowing to predict huge images that does not fit into GPU memory entirely
in a sliding-window fashion and merging prediction mask back to full-resolution.
"""

import dataclasses
import math
from typing import List, Iterable, Tuple, Union, Sequence
Expand Down
1 change: 1 addition & 0 deletions pytorch_toolbelt/inference/tta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Despite this is called test-time augmentation, these method can be used at training time as well since all
transformation written in PyTorch and respect gradients flow.
"""

from collections import defaultdict
from functools import partial
from typing import Tuple, List, Optional, Union, Callable, Dict, Mapping
Expand Down
6 changes: 4 additions & 2 deletions pytorch_toolbelt/losses/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def focal_loss_with_logits(
ignore_index=None,
activation: str = "sigmoid",
softmax_dim: Optional[int] = None,
class_weights: Optional[torch.Tensor] = None
class_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute binary focal loss between target and output logits.
Expand Down Expand Up @@ -70,7 +70,9 @@ def focal_loss_with_logits(
if reduced_threshold is None:
focal_term = (1.0 - pt).pow(gamma)
else:
focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow(gamma) #the focal term continuity breaks when reduced_threshold not equal to 0.5. At pt equal to reduced_threshold, the value of piecewise function of focal term should be 1 from both sides .
focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow(
gamma
) # the focal term continuity breaks when reduced_threshold not equal to 0.5. At pt equal to reduced_threshold, the value of piecewise function of focal term should be 1 from both sides .
focal_term = torch.masked_fill(focal_term, pt < reduced_threshold, 1)

loss = focal_term * ce_loss
Expand Down
2 changes: 1 addition & 1 deletion pytorch_toolbelt/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
from .initialization import *
from .normalization import *

from .heads import *
from .heads import *
2 changes: 1 addition & 1 deletion pytorch_toolbelt/modules/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def get_activation_block(activation_name: str):
ACT_SWISH: Swish,
ACT_SWISH_NAIVE: SwishNaive,
ACT_SIGMOID: nn.Sigmoid,
ACT_SOFTMAX: nn.Softmax
ACT_SOFTMAX: nn.Softmax,
}

return ACTIVATIONS[activation_name.lower()]
Expand Down
1 change: 1 addition & 0 deletions pytorch_toolbelt/modules/backbone/senet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
ResNet code gently borrowed from
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
"""

from __future__ import print_function, division, absolute_import

import math
Expand Down
1 change: 1 addition & 0 deletions pytorch_toolbelt/modules/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model.
"""

from .common import *
from .densenet import *
from .hrnet import *
Expand Down
1 change: 1 addition & 0 deletions pytorch_toolbelt/modules/encoders/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model.
"""

import math
import warnings
from typing import List, Union, Tuple, Iterable, Any
Expand Down
1 change: 1 addition & 0 deletions pytorch_toolbelt/modules/encoders/seresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model.
"""

from typing import List

import torch
Expand Down
5 changes: 3 additions & 2 deletions pytorch_toolbelt/modules/encoders/timm/maxvit.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .common import GenericTimmEncoder


class MaxVitEncoder(GenericTimmEncoder):
def __init__(self, model_name:str, pretrained=True, **kwargs):
def __init__(self, model_name: str, pretrained=True, **kwargs):
super().__init__(model_name, pretrained=pretrained, **kwargs)

def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
from pytorch_toolbelt.modules import make_n_channel_input

self.encoder.stem.conv1 = make_n_channel_input(self.encoder.stem.conv1, input_channels)
return self
return self
1 change: 1 addition & 0 deletions pytorch_toolbelt/modules/encoders/timm/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def change_input_channels(self, input_channels: int, mode="auto", **kwargs):
self.encoder.conv1[0] = make_n_channel_input(self.encoder.conv1[0], input_channels, mode=mode, **kwargs)
return self


class TimmResnet50D(GenericTimmEncoder):
def __init__(
self, pretrained=True, layers=None, activation=ACT_RELU, first_conv_stride_one: bool = False, **kwargs
Expand Down
2 changes: 1 addition & 1 deletion pytorch_toolbelt/modules/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .deep_supervision import *
from .hypercolumn import *
from .segformer_head import *
from .classification_heads import *
from .classification_heads import *
9 changes: 3 additions & 6 deletions pytorch_toolbelt/modules/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ class HasInputFeaturesSpecification(Protocol):
"""

@torch.jit.unused
def get_input_spec(self) -> FeatureMapsSpecification:
...
def get_input_spec(self) -> FeatureMapsSpecification: ...


class HasOutputFeaturesSpecification(Protocol):
Expand All @@ -71,8 +70,7 @@ class HasOutputFeaturesSpecification(Protocol):
"""

@torch.jit.unused
def get_output_spec(self) -> FeatureMapsSpecification:
...
def get_output_spec(self) -> FeatureMapsSpecification: ...


class AbstractEncoder(nn.Module, HasOutputFeaturesSpecification):
Expand Down Expand Up @@ -108,8 +106,7 @@ def __init__(self, input_spec: FeatureMapsSpecification):
@abstractmethod
def forward(
self, feature_maps: List[Tensor], output_size: Union[Tuple[int, int], torch.Size, None] = None
) -> Union[Tensor, Tuple[Tensor, ...], List[Tensor], Mapping[str, Tensor]]:
...
) -> Union[Tensor, Tuple[Tensor, ...], List[Tensor], Mapping[str, Tensor]]: ...

@torch.jit.unused
def apply_to_final_layer(self, func: Callable[[nn.Module], None]):
Expand Down
11 changes: 1 addition & 10 deletions pytorch_toolbelt/modules/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,7 @@ def forward(self, x: Tensor) -> Tensor:

def __repr__(self):
p = torch.softplus(self.p) + 1
return (
self.__class__.__name__
+ "("
+ "p="
+ "{:.4f}".format(p.item())
+ ", "
+ "eps="
+ str(self.eps)
+ ")"
)
return self.__class__.__name__ + "(" + "p=" + "{:.4f}".format(p.item()) + ", " + "eps=" + str(self.eps) + ")"


class GlobalMaxAvgPooling2d(nn.Module):
Expand Down
1 change: 1 addition & 0 deletions pytorch_toolbelt/utils/random_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utility functions to make your experiments reproducible
"""

import random
import warnings

Expand Down
2 changes: 1 addition & 1 deletion pytorch_toolbelt/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def container_to_tensor(value: Union[np.ndarray, List, Tuple, Mapping, Any]):
cls = type(value)
return cls((k, container_to_tensor(v)) for k, v in value.items())

raise ValueError(f"Unsupported container type")
raise ValueError(f"Unsupported container type {type(value)}")


def image_to_tensor(image: np.ndarray, dummy_channels_dim=True) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def get_test_requirements():
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Software Development :: Libraries :: Application Frameworks"
"Topic :: Software Development :: Libraries :: Application Frameworks",
# "Private :: Do Not Upload"
],
)

0 comments on commit 66e48d7

Please sign in to comment.