Skip to content

Commit d26b839

Browse files
committed
Merge branch 'dev-MVC' of github.com:SystemsGenetics/granny into dev-MVC
2 parents 21cbfb1 + cbbdf5f commit d26b839

File tree

1 file changed

+32
-35
lines changed

1 file changed

+32
-35
lines changed

Granny/Analyses/Segmentation.py

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from typing import Any, Dict, List, Tuple
2222
from urllib import request
2323

24-
import matplotlib.pyplot as plt
24+
import cv2
2525
import numpy as np
2626
import pandas as pd
2727
from Granny.Analyses.Analysis import Analysis
@@ -33,7 +33,6 @@
3333
from Granny.Models.IO.RGBImageFile import RGBImageFile
3434
from Granny.Models.Values.FileNameValue import FileNameValue
3535
from Granny.Models.Values.ImageListValue import ImageListValue
36-
from matplotlib import patches
3736
from numpy.typing import NDArray
3837

3938

@@ -42,7 +41,7 @@ class SegmentationConfig:
4241
MODELS: Dict[str, Dict[str, str]] = {
4342
"pome_fruit-v1_0": {
4443
"full_name": "granny-v1_0-pome_fruit-v1_0.pt",
45-
"url": "https://osf.io/dqzyn/download/",
44+
"url": "https://osf.io/vyfhm/download/",
4645
}
4746
}
4847

@@ -102,12 +101,12 @@ def __init__(self):
102101
"tray_infos",
103102
)
104103
)
105-
self.full_images = ImageListValue(
104+
self.masked_images = ImageListValue(
106105
"f_img",
107106
"full_masked_image",
108107
"The output directory where the full-masked images are written.",
109108
)
110-
self.full_images.setValue(
109+
self.masked_images.setValue(
111110
os.path.join(
112111
os.curdir,
113112
"results",
@@ -126,7 +125,6 @@ def _getModelUrl(self, model_name: str):
126125
model_url = ""
127126
try:
128127
model_url = self.models[model_name]["url"]
129-
print(f"Model URL: {model_url}")
130128
except KeyError:
131129
print(f"Key '{model_name}' not found in configuration.")
132130
return model_url
@@ -163,7 +161,7 @@ def _segmentInstances(self, image: NDArray[np.uint8]) -> List[Any]:
163161

164162
return results
165163

166-
def _writeMaskedImage(self, tray_image: Image) -> None:
164+
def _extractMaskedImage(self, tray_image: Image) -> Image:
167165
""""""
168166
[result] = tray_image.getSegmentationResults()
169167
masks = result.masks.cpu()
@@ -179,9 +177,9 @@ def _writeMaskedImage(self, tray_image: Image) -> None:
179177
hsv = [(i / num_instances, 1, brightness) for i in range(num_instances)]
180178
colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
181179
random.shuffle(colors)
182-
_, ax = plt.subplots()
183180
for i in range(num_instances):
184181
mask = masks.data[i].numpy()
182+
(r, g, b) = colors[i]
185183
for c in range(3):
186184
result[:, :, c] = np.where(
187185
mask == 1,
@@ -191,31 +189,21 @@ def _writeMaskedImage(self, tray_image: Image) -> None:
191189

192190
x1, y1, x2, y2 = coords[i]
193191
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
194-
ax.text(
195-
x1, y1 + 10, "{:.3f}".format(confs[i]), color="w", size=7, backgroundcolor="none"
196-
)
197-
p = patches.Rectangle(
192+
cv2.rectangle(result, (x1, y1), (x2, y2), (r * 255, g * 255, b * 255), 5)
193+
cv2.putText(
194+
result,
195+
"{:.3f}".format(confs[i]),
198196
(x1, y1),
199-
x2 - x1,
200-
y2 - y1,
201-
linewidth=1,
202-
edgecolor=colors[i],
203-
facecolor="none",
204-
linestyle="dashed",
197+
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
198+
fontScale=2,
199+
color=(255, 255, 255),
200+
thickness=3,
205201
)
206-
ax.add_patch(p)
207-
plt.axis("off")
208-
plt.imshow(result)
209-
plt.tight_layout()
210-
plt.savefig(
211-
os.path.join(
212-
self.full_images.getValue(),
213-
tray_image.getImageName(),
214-
),
215-
bbox_inches="tight",
216-
pad_inches=0,
217-
dpi=300,
202+
image_instance: Image = RGBImage(
203+
pathlib.Path(tray_image.getImageName()).stem + f"_masked_image" + ".png"
218204
)
205+
image_instance.setImage(result)
206+
return image_instance
219207

220208
def _sortInstances(self, boxes: NDArray[np.float32], img_shape: Tuple[int, int]):
221209
"""
@@ -370,6 +358,7 @@ def performAnalysis(self) -> List[Image]:
370358
# performs segmentation on each image one-by-one
371359
segmented_images: List[Image] = []
372360
tray_images: List[Image] = []
361+
masked_images: List[Image] = []
373362
for image_instance in self.images:
374363
# set ImageIO with specific file path
375364
self.image_io.setFilePath(image_instance.getFilePath())
@@ -389,22 +378,30 @@ def performAnalysis(self) -> List[Image]:
389378
image_instance.setSegmentationResults(results=result)
390379

391380
try:
392-
# extracts individual instances and tray information
381+
# extracts individual instances
393382
image_instances = self._extractImage(image_instance)
394-
tray_info = self._extractTrayInfo(image_instance)
383+
# and tray information
384+
tray_infos = self._extractTrayInfo(image_instance)
385+
# and masked image
386+
masked_image = self._extractMaskedImage(image_instance)
387+
388+
# save to list for output
395389
segmented_images.extend(image_instances)
396-
tray_images.extend(tray_info)
397-
# writes masked image
398-
self._writeMaskedImage(image_instance)
390+
tray_images.extend(tray_infos)
391+
masked_images.append(masked_image)
399392
except:
400393
AttributeError("Error with the results.")
401394

402395
# 1. sets the output ImageListValue with the list of segmented images
403396
# 2. writes the segmented images to "segmented_images" folder
404397
# 3. writes the tray information to "tray_info" folder
398+
# 4. writes the full masked images to "full_masked_images" folder
405399
self.seg_images.setImageList(segmented_images)
406400
self.seg_images.writeValue()
407401

408402
self.tray_infos.setImageList(tray_images)
409403
self.tray_infos.writeValue()
404+
405+
self.masked_images.setImageList(masked_images)
406+
self.masked_images.writeValue()
410407
return segmented_images

0 commit comments

Comments
 (0)