5
5
from steps .preprocessing .misc import XYSplit
6
6
from utils import squeeze_inputs
7
7
from models import PyTorchUNet , PyTorchUNetStream
8
- from postprocessing import Resizer , CategoryMapper , MulticlassLabeler , MaskDilator \
9
- ResizerStream , CategoryMapperStream , MulticlassLabelerStream
8
+ from postprocessing import Resizer , CategoryMapper , MulticlassLabeler , MaskDilator , \
9
+ ResizerStream , CategoryMapperStream , MulticlassLabelerStream , MaskDilatorStream
10
10
11
11
12
12
def unet (config , train_mode ):
@@ -27,7 +27,7 @@ def unet(config, train_mode):
27
27
mask_postprocessed = mask_postprocessing (unet , config , save_output = save_output )
28
28
if config .postprocessor ["dilate_selem_size" ] > 0 :
29
29
mask_postprocessed = Step (name = 'mask_dilation' ,
30
- transformer = MaskDilator (** config .postprocessor ),
30
+ transformer = MaskDilatorStream (** config .postprocessor ) if config . execution . stream_mode else MaskDilator ( ** config . unet ),
31
31
input_steps = [mask_postprocessed ],
32
32
adapter = {'images' : ([(mask_postprocessed .name , 'categorized_images' )]),
33
33
},
0 commit comments