Skip to content

Commit

Permalink
Fixed failing unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Nov 21, 2024
1 parent fcae22a commit 4ed934e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_tiles.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import torch
from pytorch_toolbelt.inference.tiles import ImageSlicer, CudaTileMerger
from pytorch_toolbelt.inference.tiles import ImageSlicer, TileMerger
from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image, rgb_image_from_tensor, to_numpy
from torch import nn
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -32,7 +32,7 @@ def test_tiles_split_merge_non_dividable_cuda():
tiler = ImageSlicer(image.shape, tile_size=(1280, 1280), tile_step=(1280, 1280), weight="mean")
tiles = tiler.split(image)

merger = CudaTileMerger(tiler.target_shape, channels=image.shape[2], weight=tiler.weight)
merger = TileMerger(tiler.target_shape, channels=image.shape[2], weight=tiler.weight)
for tile, coordinates in zip(tiles, tiler.crops):
# Integrate as batch of size 1
merger.integrate_batch(tensor_from_rgb_image(tile).unsqueeze(0).float().cuda(), [coordinates])
Expand Down Expand Up @@ -72,7 +72,7 @@ def forward(self, input):

model = MaxChannelIntensity().eval().cuda()

merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)
merger = TileMerger(tiler.target_shape, 1, tiler.weight)
for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=8, pin_memory=True):
tiles_batch = tiles_batch.float().cuda()
pred_batch = model(tiles_batch)
Expand Down

0 comments on commit 4ed934e

Please sign in to comment.