Skip to content

Commit a910c64

Browse files
committed
fix conflicts and add MaskDilatorStream
1 parent 3fb4ae5 commit a910c64

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

pipelines.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from steps.preprocessing.misc import XYSplit
66
from utils import squeeze_inputs
77
from models import PyTorchUNet, PyTorchUNetStream
8-
from postprocessing import Resizer, CategoryMapper, MulticlassLabeler, MaskDilator\
9-
ResizerStream, CategoryMapperStream, MulticlassLabelerStream
8+
from postprocessing import Resizer, CategoryMapper, MulticlassLabeler, MaskDilator, \
9+
ResizerStream, CategoryMapperStream, MulticlassLabelerStream, MaskDilatorStream
1010

1111

1212
def unet(config, train_mode):
@@ -27,7 +27,7 @@ def unet(config, train_mode):
2727
mask_postprocessed = mask_postprocessing(unet, config, save_output=save_output)
2828
if config.postprocessor["dilate_selem_size"] > 0:
2929
mask_postprocessed = Step(name='mask_dilation',
30-
transformer=MaskDilator(**config.postprocessor),
30+
transformer=MaskDilatorStream(**config.postprocessor) if config.execution.stream_mode else MaskDilator(**config.unet),
3131
input_steps=[mask_postprocessed],
3232
adapter={'images': ([(mask_postprocessed.name, 'categorized_images')]),
3333
},

postprocessing.py

+9
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ def _transform(self, images):
7676
yield categorize_image(image)
7777

7878

79+
class MaskDilatorStream(BaseTransformer):
80+
def transform(self, images):
81+
return {'categorized_images': self._transform(images)}
82+
83+
def _transform(self, images):
84+
for image in tqdm(images):
85+
yield dilate_image(image)
86+
87+
7988
def label(mask):
8089
labeled, nr_true = ndi.label(mask)
8190
return labeled

0 commit comments

Comments
 (0)