
Commit

Merge pull request #74 from BloodAxe/develop
Release 0.5.1
BloodAxe authored Jun 27, 2022
2 parents 8bc1cd1 + 8faed01 commit ee72463
Showing 23 changed files with 581 additions and 75 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -25,11 +25,11 @@ jobs:
matrix.operating-system == 'ubuntu-latest' ||
matrix.operating-system == 'windows-latest'
run: >
pip install torch==1.8.1+cpu torchvision==0.9.1+cpu
pip install torch==1.10.1+cpu torchvision==0.11.2+cpu
-f https://download.pytorch.org/whl/torch_stable.html
- name: Install PyTorch on MacOS
if: matrix.operating-system == 'macos-latest'
run: pip install torch==1.8.1 torchvision==0.9.1
run: pip install torch==1.10.1 torchvision==0.11.2
- name: Install dependencies
run: pip install .[${{ matrix.pytorch-toolbelt-version }}]
- name: Install linters
@@ -59,6 +59,6 @@ jobs:
- name: Update pip
run: python -m pip install --upgrade pip
- name: Install Black
run: pip install black==20.8b1
run: pip install black==22.3.0
- name: Run Black
run: black --config=black.toml --check .
1 change: 1 addition & 0 deletions .gitignore
@@ -20,3 +20,4 @@ var/
.pytest_cache/
/tests/tta_eval.csv
/tests/tmp.onnx
/tests/test_plot_confusion_matrix.png
368 changes: 368 additions & 0 deletions notebooks/tiled_inference.ipynb

Large diffs are not rendered by default.
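The new notebooks/tiled_inference.ipynb is not rendered in this diff. As a rough orientation only, a minimal sketch of tiled inference with the ImageSlicer / TileMerger pair touched later in this commit might look like the following; the helper names (image_to_tensor, to_numpy) and the exact TileMerger constructor arguments follow the project README and are assumptions, not something shown in this diff.

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from pytorch_toolbelt.inference.tiles import ImageSlicer, TileMerger
from pytorch_toolbelt.utils import image_to_tensor, to_numpy

# Synthetic "large" image standing in for a real photo
image = np.random.randint(0, 255, (1024, 1280, 3), dtype=np.uint8)

# Cut overlapping 512x512 tiles with a pyramidal blending weight
tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight="pyramid")
tiles = [image_to_tensor(tile) for tile in tiler.split(image)]

# Accumulator that merges per-tile predictions back into one full-size map
# (positional arguments image_shape, channels, weight -- assumed from the README)
merger = TileMerger(tiler.target_shape, 1, tiler.weight)

model = nn.Conv2d(3, 1, kernel_size=1)  # stand-in for a real segmentation model

with torch.no_grad():
    for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=4):
        pred_batch = model(tiles_batch.float())
        merger.integrate_batch(pred_batch, coords_batch)

merged = np.moveaxis(to_numpy(merger.merge()), 0, -1)  # (H, W, 1)
merged = tiler.crop_to_orignal_size(merged)  # the library's method name is spelled this way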

2 changes: 1 addition & 1 deletion pytorch_toolbelt/__init__.py
@@ -1,3 +1,3 @@
from __future__ import absolute_import

__version__ = "0.5.0"
__version__ = "0.5.1"
2 changes: 1 addition & 1 deletion pytorch_toolbelt/datasets/segmentation.py
@@ -139,7 +139,7 @@ def __getitem__(self, index):

if self.need_supervision_masks:
for i in range(1, 6):
stride = 2 ** i
stride = 2**i
mask = block_reduce(mask, (2, 2), partial(_block_reduce_dominant_label))
sample[name_for_stride(TARGET_MASK_KEY, stride)] = self.make_target(mask)

41 changes: 31 additions & 10 deletions pytorch_toolbelt/inference/functional.py
@@ -7,26 +7,27 @@
from ..utils.support import pytorch_toolbelt_deprecated

__all__ = [
"geometric_mean",
"harmonic_mean",
"logodd_mean",
"pad_image_tensor",
"torch_fliplr",
"torch_flipud",
"torch_none",
"torch_rot180",
"torch_rot270",
"torch_rot90",
"torch_rot90_cw",
"torch_rot90_ccw",
"torch_transpose_rot90_cw",
"torch_transpose_rot90_ccw",
"torch_rot90_ccw_transpose",
"torch_rot90_cw",
"torch_rot90_cw_transpose",
"torch_rot180",
"torch_rot270",
"torch_fliplr",
"torch_flipud",
"torch_transpose",
"torch_transpose2",
"torch_transpose_",
"pad_image_tensor",
"torch_transpose_rot90_ccw",
"torch_transpose_rot90_cw",
"unpad_image_tensor",
"unpad_xyxy_bboxes",
"geometric_mean",
"harmonic_mean",
]


@@ -240,3 +241,23 @@ def harmonic_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
x = torch.mean(x, dim=dim)
x = torch.reciprocal(x.clamp_min(eps))
return x


def logodd_mean(x: Tensor, dim: int, eps: float = 1e-6) -> Tensor:
"""
Compute log-odd mean along given dimension.
logodd = log(p / (1 - p))
This implementation assumes values are in the range [0, 1] (probabilities).
Args:
x: Input tensor of arbitrary shape
dim: Dimension to reduce
Returns:
Tensor
"""
x = x.clamp(min=eps, max=1.0 - eps)
x = torch.log(x / (1 - x))
x = torch.mean(x, dim=dim)
x = torch.exp(x) / (1 + torch.exp(x))
return x
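For illustration, a small usage sketch of the new reduction: probabilities are mapped to log-odds, averaged, and mapped back through the sigmoid.

import torch
from pytorch_toolbelt.inference import functional as F

# Per-pixel probabilities from three different predictions of the same image
probs = torch.tensor([[0.9, 0.2],
                      [0.8, 0.1],
                      [0.7, 0.3]])

print(F.logodd_mean(probs, dim=0))  # mean taken in log-odds space
print(probs.mean(dim=0))            # plain arithmetic mean, for comparison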
12 changes: 5 additions & 7 deletions pytorch_toolbelt/inference/tiles.py
@@ -60,7 +60,7 @@ class ImageSlicer:
Helper class to slice image into tiles and merge them back
"""

def __init__(self, image_shape, tile_size, tile_step=0, image_margin=0, weight="mean"):
def __init__(self, image_shape: Tuple[int, int], tile_size, tile_step=0, image_margin=0, weight="mean"):
"""
:param image_shape: Shape of the source image (H, W)
@@ -122,12 +122,6 @@ def __init__(self, image_shape, tile_size, tile_step=0, image_margin=0, weight="
else:
margin_left = margin_right = margin_top = margin_bottom = image_margin

if (self.image_width + margin_left + margin_right) % self.tile_size[1] != 0:
raise ValueError()

if (self.image_height + margin_top + margin_bottom) % self.tile_size[0] != 0:
raise ValueError()

self.margin_left = margin_left
self.margin_right = margin_right
self.margin_top = margin_top
@@ -337,6 +331,10 @@ def integrate_batch(self, batch: torch.Tensor, crop_coords):
if batch.device != self.image.device:
batch = batch.to(device=self.image.device)

# Ensure that the input batch dtype matches the dtype of the accumulator
if batch.dtype != self.image.dtype:
batch = batch.type_as(self.image)

for tile, (x, y, tile_width, tile_height) in zip(batch, crop_coords):
self.image[:, y : y + tile_height, x : x + tile_width] += tile * self.weight
self.norm_mask[:, y : y + tile_height, x : x + tile_width] += self.weight
2 changes: 2 additions & 0 deletions pytorch_toolbelt/inference/tta.py
@@ -79,6 +79,8 @@ def _deaugment_averaging(x: Tensor, reduction: MaybeStrOrCallable) -> Tensor:
x = F.geometric_mean(x, dim=0)
elif reduction in {"hmean", "harmonic_mean"}:
x = F.harmonic_mean(x, dim=0)
elif reduction == "logodd":
x = F.logodd_mean(x, dim=0)
elif callable(reduction):
x = reduction(x, dim=0)
elif reduction in {None, "None", "none"}:
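A hedged sketch of selecting the new "logodd" reduction, using the private helper shown above; the public de-augmentation functions that accept a reduction argument presumably forward the same value.

import torch
from pytorch_toolbelt.inference import tta

# Probabilities from 4 augmented views, stacked along dim 0: [num_aug, B, C, H, W]
stacked = torch.rand(4, 2, 1, 16, 16)

averaged = tta._deaugment_averaging(stacked, reduction="logodd")  # new option in 0.5.1
print(averaged.shape)  # torch.Size([2, 1, 16, 16])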
10 changes: 9 additions & 1 deletion pytorch_toolbelt/modules/activations.py
@@ -264,7 +264,15 @@ def instantiate_activation_block(activation_name: str, **kwargs) -> nn.Module:

act_params = {}

if "inplace" in kwargs and activation_name in {ACT_RELU, ACT_RELU6, ACT_LEAKY_RELU, ACT_SELU, ACT_CELU, ACT_ELU}:
if "inplace" in kwargs and activation_name in {
ACT_RELU,
ACT_RELU6,
ACT_LEAKY_RELU,
ACT_SELU,
ACT_SILU,
ACT_CELU,
ACT_ELU,
}:
act_params["inplace"] = kwargs["inplace"]

if "slope" in kwargs and activation_name in {ACT_LEAKY_RELU}:
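The inplace flag is now also forwarded for SiLU blocks. A short sketch of the effect; it assumes ACT_SILU maps to nn.SiLU, which is not shown in this diff.

from pytorch_toolbelt.modules.activations import ACT_SILU, instantiate_activation_block

# With this change the inplace flag is forwarded to SiLU blocks as well
act = instantiate_activation_block(ACT_SILU, inplace=True)
print(type(act), getattr(act, "inplace", None))  # expected: nn.SiLU with inplace=True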
2 changes: 1 addition & 1 deletion pytorch_toolbelt/modules/decoders/unet_v2.py
@@ -100,7 +100,7 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels
super().__init__()

if not isinstance(decoder_features, list):
decoder_features = [decoder_features * (2 ** i) for i in range(len(feature_maps))]
decoder_features = [decoder_features * (2**i) for i in range(len(feature_maps))]

blocks = []
for block_index, in_enc_features in enumerate(feature_maps[:-1]):
4 changes: 2 additions & 2 deletions pytorch_toolbelt/modules/dropblock.py
@@ -68,7 +68,7 @@ def _compute_block_mask(self, mask):
return block_mask, keeped

def _compute_gamma(self, x):
return self.drop_prob / (self.block_size ** 2)
return self.drop_prob / (self.block_size**2)


class DropBlock3D(DropBlock2D):
@@ -131,7 +131,7 @@ def _compute_block_mask(self, mask):
return block_mask

def _compute_gamma(self, x):
return self.drop_prob / (self.block_size ** 3)
return self.drop_prob / (self.block_size**3)


class DropBlockScheduled(nn.Module):
6 changes: 3 additions & 3 deletions pytorch_toolbelt/modules/encoders/swin.py
@@ -94,7 +94,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.scale = qk_scale or head_dim**-0.5

# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
@@ -587,7 +587,7 @@ def __init__(
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
dim=int(embed_dim * 2**i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
@@ -604,7 +604,7 @@ def __init__(
)
self.layers.append(layer)

num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
self.num_features = num_features

# add a norm layer for each output
4 changes: 2 additions & 2 deletions pytorch_toolbelt/modules/encoders/unet.py
@@ -27,8 +27,8 @@ def __init__(
if pool_block is None:
pool_block = partial(nn.MaxPool2d, kernel_size=2, stride=2)

feature_maps = [out_channels * (growth_factor ** i) for i in range(num_layers)]
strides = [2 ** i for i in range(num_layers)]
feature_maps = [out_channels * (growth_factor**i) for i in range(num_layers)]
strides = [2**i for i in range(num_layers)]
super().__init__(feature_maps, strides, layers=list(range(num_layers)))

input_filters = in_channels
4 changes: 2 additions & 2 deletions pytorch_toolbelt/modules/ocnet.py
@@ -62,7 +62,7 @@ def forward(self, x): # skipcq: PYL-W0221
key = self.f_key(x).view(batch_size, self.key_channels, -1)

sim_map = torch.matmul(query, key)
sim_map = (self.key_channels ** -0.5) * sim_map
sim_map = (self.key_channels**-0.5) * sim_map
sim_map = F.softmax(sim_map, dim=-1)

context = torch.matmul(sim_map, value)
@@ -300,7 +300,7 @@ def forward(self, x):
key_local = key_local.contiguous().view(batch_size, self.key_channels, -1)

sim_map = torch.matmul(query_local, key_local)
sim_map = (self.key_channels ** -0.5) * sim_map
sim_map = (self.key_channels**-0.5) * sim_map
sim_map = F.softmax(sim_map, dim=-1)

context_local = torch.matmul(sim_map, value_local)
6 changes: 3 additions & 3 deletions pytorch_toolbelt/modules/upsample.py
@@ -50,14 +50,14 @@ def icnr_init(tensor: torch.Tensor, upscale_factor=2, initializer=nn.init.kaimin
.. _Checkerboard artifact free sub-pixel convolution:
https://arxiv.org/abs/1707.02937
"""
new_shape = [int(tensor.shape[0] / (upscale_factor ** 2))] + list(tensor.shape[1:])
new_shape = [int(tensor.shape[0] / (upscale_factor**2))] + list(tensor.shape[1:])
subkernel = torch.zeros(new_shape)
subkernel = initializer(subkernel)
subkernel = subkernel.transpose(0, 1)

subkernel = subkernel.contiguous().view(subkernel.shape[0], subkernel.shape[1], -1)

kernel = subkernel.repeat(1, 1, upscale_factor ** 2)
kernel = subkernel.repeat(1, 1, upscale_factor**2)

transposed_shape = [tensor.shape[1]] + [tensor.shape[0]] + list(tensor.shape[2:])
kernel = kernel.contiguous().view(transposed_shape)
@@ -77,7 +77,7 @@ class DepthToSpaceUpsample2d(nn.Module):

def __init__(self, in_channels: int, out_channels: int, scale_factor: int = 2):
super().__init__()
n = 2 ** scale_factor
n = 2**scale_factor
self.conv = nn.Conv2d(in_channels, out_channels * n, kernel_size=3, padding=1, bias=False)
self.out_channels = out_channels
self.shuffle = nn.PixelShuffle(upscale_factor=scale_factor)
6 changes: 5 additions & 1 deletion pytorch_toolbelt/optimization/functional.py
@@ -46,7 +46,9 @@ def get_optimizable_parameters(model: nn.Module) -> Iterator[nn.Parameter]:
return filter(lambda x: x.requires_grad, model.parameters())


def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, freeze_bn: Optional[bool] = True):
def freeze_model(
module: nn.Module, freeze_parameters: Optional[bool] = True, freeze_bn: Optional[bool] = True
) -> nn.Module:
"""
Change 'requires_grad' value for module and its child modules and
optionally freeze batchnorm modules.
@@ -70,3 +72,5 @@ def freeze_model(module: nn.Module, freeze_parameters: Optional[bool] = True, fr
for m in module.modules():
if isinstance(m, bn_types):
module.track_running_stats = not freeze_bn

return module
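Because freeze_model now returns the module, freezing can be written inline. A minimal sketch, assuming the default freeze_parameters=True behavior sets requires_grad to False; the backbone below is purely illustrative.

from torch import nn
from pytorch_toolbelt.optimization.functional import freeze_model

backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Freeze the backbone and keep composing the model in one expression
model = nn.Sequential(freeze_model(backbone), nn.Conv2d(8, 1, 1))

print(any(p.requires_grad for p in model[0].parameters()))  # False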
4 changes: 2 additions & 2 deletions pytorch_toolbelt/optimization/lr_schedules.py
@@ -80,7 +80,7 @@ def get_lr(self):
def compute_lr(base_lr):
return (
self.eta_min
+ (base_lr * self.gamma ** self.last_epoch - self.eta_min)
+ (base_lr * self.gamma**self.last_epoch - self.eta_min)
* (1 + math.cos(math.pi * self.last_epoch / self.T_max))
/ 2
)
@@ -110,7 +110,7 @@ def get_lr(self):

return [
self.eta_min
+ (base_lr * self.gamma ** self.last_epoch - self.eta_min)
+ (base_lr * self.gamma**self.last_epoch - self.eta_min)
* (1 + math.cos(math.pi * self.T_cur / self.T_i))
/ 2
for base_lr in self.base_lrs
2 changes: 1 addition & 1 deletion pytorch_toolbelt/utils/catalyst/metrics.py
@@ -234,7 +234,7 @@ def _f1_from_confusion_matrix(
true_sum = np.array([true_sum.sum()])

# Finally, we have all our sufficient statistics. Divide! #
beta2 = beta ** 2
beta2 = beta**2

# Divide, and on zero-division, set scores and/or warn according to
# zero_division:
16 changes: 15 additions & 1 deletion pytorch_toolbelt/utils/fs.py
@@ -16,6 +16,7 @@
"find_images_in_dir",
"find_in_dir",
"find_in_dir_glob",
"find_subdirectories_in_dir",
"has_ext",
"has_image_ext",
"id_from_fname",
@@ -45,6 +46,19 @@ def find_in_dir(dirname: str) -> List[str]:
return [os.path.join(dirname, fname) for fname in sorted(os.listdir(dirname))]


def find_subdirectories_in_dir(dirname: str) -> List[str]:
"""
Retrieve a list of subdirectories (non-recursive) in the given directory.
Args:
dirname: Target directory name
Returns:
Sorted list of absolute paths to directories
"""
all_entries = find_in_dir(dirname)
return [entry for entry in all_entries if os.path.isdir(entry)]


def find_in_dir_with_ext(dirname: str, extensions: Union[str, List[str]]) -> List[str]:
return [os.path.join(dirname, fname) for fname in sorted(os.listdir(dirname)) if has_ext(fname, extensions)]

@@ -107,7 +121,7 @@ def read_rgb_image(fname: Union[str, Path]) -> np.ndarray:
if type(fname) != str:
fname = str(fname)

image = cv2.imread(fname, cv2.IMREAD_UNCHANGED)
image = cv2.imread(fname, cv2.IMREAD_COLOR)
if image is None:
raise IOError(f'Cannot read image "{fname}"')

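A small, self-contained sketch of the two fs.py changes above; the temporary directory and file names are purely illustrative.

import os
import tempfile

import cv2
import numpy as np

from pytorch_toolbelt.utils import fs

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "fold0"))
cv2.imwrite(os.path.join(root, "mask.png"), np.zeros((8, 8), dtype=np.uint16))

# New helper: immediate subdirectories only, sorted
print(fs.find_subdirectories_in_dir(root))  # ['<root>/fold0']

# read_rgb_image now decodes with cv2.IMREAD_COLOR, so even a 16-bit
# single-channel PNG comes back as an 8-bit, 3-channel array
print(fs.read_rgb_image(os.path.join(root, "mask.png")).shape)  # (8, 8, 3)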