2121from  typing  import  Any , Dict , List , Tuple 
2222from  urllib  import  request 
2323
24- import  matplotlib . pyplot   as   plt 
24+ import  cv2 
2525import  numpy  as  np 
2626import  pandas  as  pd 
2727from  Granny .Analyses .Analysis  import  Analysis 
3333from  Granny .Models .IO .RGBImageFile  import  RGBImageFile 
3434from  Granny .Models .Values .FileNameValue  import  FileNameValue 
3535from  Granny .Models .Values .ImageListValue  import  ImageListValue 
36- from  matplotlib  import  patches 
3736from  numpy .typing  import  NDArray 
3837
3938
@@ -42,7 +41,7 @@ class SegmentationConfig:
4241    MODELS : Dict [str , Dict [str , str ]] =  {
4342        "pome_fruit-v1_0" : {
4443            "full_name" : "granny-v1_0-pome_fruit-v1_0.pt" ,
45-             "url" : "https://osf.io/dqzyn /download/" ,
44+             "url" : "https://osf.io/vyfhm /download/" ,
4645        }
4746    }
4847
@@ -102,12 +101,12 @@ def __init__(self):
102101                "tray_infos" ,
103102            )
104103        )
105-         self .full_images  =  ImageListValue (
104+         self .masked_images  =  ImageListValue (
106105            "f_img" ,
107106            "full_masked_image" ,
108107            "The output directory where the full-masked images are written." ,
109108        )
110-         self .full_images .setValue (
109+         self .masked_images .setValue (
111110            os .path .join (
112111                os .curdir ,
113112                "results" ,
@@ -126,7 +125,6 @@ def _getModelUrl(self, model_name: str):
126125        model_url  =  "" 
127126        try :
128127            model_url  =  self .models [model_name ]["url" ]
129-             print (f"Model URL: { model_url }  " )
130128        except  KeyError :
131129            print (f"Key '{ model_name }  ' not found in configuration." )
132130        return  model_url 
@@ -163,7 +161,7 @@ def _segmentInstances(self, image: NDArray[np.uint8]) -> List[Any]:
163161
164162        return  results 
165163
166-     def  _writeMaskedImage (self , tray_image : Image ) ->  None :
164+     def  _extractMaskedImage (self , tray_image : Image ) ->  Image :
167165        """""" 
168166        [result ] =  tray_image .getSegmentationResults ()
169167        masks  =  result .masks .cpu ()
@@ -179,9 +177,9 @@ def _writeMaskedImage(self, tray_image: Image) -> None:
179177        hsv  =  [(i  /  num_instances , 1 , brightness ) for  i  in  range (num_instances )]
180178        colors  =  list (map (lambda  c : colorsys .hsv_to_rgb (* c ), hsv ))
181179        random .shuffle (colors )
182-         _ , ax  =  plt .subplots ()
183180        for  i  in  range (num_instances ):
184181            mask  =  masks .data [i ].numpy ()
182+             (r , g , b ) =  colors [i ]
185183            for  c  in  range (3 ):
186184                result [:, :, c ] =  np .where (
187185                    mask  ==  1 ,
@@ -191,31 +189,21 @@ def _writeMaskedImage(self, tray_image: Image) -> None:
191189
192190            x1 , y1 , x2 , y2  =  coords [i ]
193191            x1 , y1 , x2 , y2  =  int (x1 ), int (y1 ), int (x2 ), int (y2 )
194-             ax . text ( 
195-                  x1 ,  y1   +   10 ,  "{:.3f}" . format ( confs [ i ]),  color = "w" ,  size = 7 ,  backgroundcolor = "none" 
196-             ) 
197-             p   =   patches . Rectangle ( 
192+             cv2 . rectangle ( result , ( x1 ,  y1 ), ( x2 ,  y2 ), ( r   *   255 ,  g   *   255 ,  b   *   255 ),  5 ) 
193+             cv2 . putText ( 
194+                  result , 
195+                  "{:.3f}" . format ( confs [ i ]), 
198196                (x1 , y1 ),
199-                 x2  -  x1 ,
200-                 y2  -  y1 ,
201-                 linewidth = 1 ,
202-                 edgecolor = colors [i ],
203-                 facecolor = "none" ,
204-                 linestyle = "dashed" ,
197+                 fontFace = cv2 .FONT_HERSHEY_SIMPLEX ,
198+                 fontScale = 2 ,
199+                 color = (255 , 255 , 255 ),
200+                 thickness = 3 ,
205201            )
206-             ax .add_patch (p )
207-         plt .axis ("off" )
208-         plt .imshow (result )
209-         plt .tight_layout ()
210-         plt .savefig (
211-             os .path .join (
212-                 self .full_images .getValue (),
213-                 tray_image .getImageName (),
214-             ),
215-             bbox_inches = "tight" ,
216-             pad_inches = 0 ,
217-             dpi = 300 ,
202+         image_instance : Image  =  RGBImage (
203+             pathlib .Path (tray_image .getImageName ()).stem  +  f"_masked_image"  +  ".png" 
218204        )
205+         image_instance .setImage (result )
206+         return  image_instance 
219207
220208    def  _sortInstances (self , boxes : NDArray [np .float32 ], img_shape : Tuple [int , int ]):
221209        """ 
@@ -370,6 +358,7 @@ def performAnalysis(self) -> List[Image]:
370358        # performs segmentation on each image one-by-one 
371359        segmented_images : List [Image ] =  []
372360        tray_images : List [Image ] =  []
361+         masked_images : List [Image ] =  []
373362        for  image_instance  in  self .images :
374363            # set ImageIO with specific file path 
375364            self .image_io .setFilePath (image_instance .getFilePath ())
@@ -389,22 +378,30 @@ def performAnalysis(self) -> List[Image]:
389378            image_instance .setSegmentationResults (results = result )
390379
391380            try :
392-                 # extracts individual instances and tray information  
381+                 # extracts individual instances 
393382                image_instances  =  self ._extractImage (image_instance )
394-                 tray_info  =  self ._extractTrayInfo (image_instance )
383+                 # and tray information 
384+                 tray_infos  =  self ._extractTrayInfo (image_instance )
385+                 # and masked image 
386+                 masked_image  =  self ._extractMaskedImage (image_instance )
387+ 
388+                 # save to list for output 
395389                segmented_images .extend (image_instances )
396-                 tray_images .extend (tray_info )
397-                 # writes masked image 
398-                 self ._writeMaskedImage (image_instance )
390+                 tray_images .extend (tray_infos )
391+                 masked_images .append (masked_image )
399392            except :
400393                AttributeError ("Error with the results." )
401394
402395        # 1. sets the output ImageListValue with the list of segmented images 
403396        # 2. writes the segmented images to "segmented_images" folder 
404397        # 3. writes the tray information to "tray_info" folder 
398+         # 4. writes the full masked images to "full_masked_images" folder 
405399        self .seg_images .setImageList (segmented_images )
406400        self .seg_images .writeValue ()
407401
408402        self .tray_infos .setImageList (tray_images )
409403        self .tray_infos .writeValue ()
404+ 
405+         self .masked_images .setImageList (masked_images )
406+         self .masked_images .writeValue ()
410407        return  segmented_images 
0 commit comments