Skip to content

Commit 9dd907c

Browse files
authored
Merge pull request #96 from ziatdinovmax/master
Add Particle Analyzer
2 parents b23374e + 1449418 commit 9dd907c

File tree

6 files changed

+396
-1
lines changed

6 files changed

+396
-1
lines changed

atomai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .classifier import Classifier
55
from .denoiser import DenoisingAutoencoder, denoise_images
66
from .dgm import BaseVAE, VAE, rVAE, jVAE, jrVAE
7+
from .sam import ParticleAnalyzer
78
from .dklgp import dklGPR, Reconstructor
89
from .loaders import load_model, load_ensemble, load_pretrained_model
910

atomai/models/sam.py

Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
import numpy as np
2+
import cv2
3+
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
4+
import pandas as pd
5+
import matplotlib.pyplot as plt
6+
import torch
7+
import os
8+
import urllib.request
9+
10+
11+
class ParticleAnalyzer:
12+
"""
13+
A class to encapsulate an end-to-end particle segmentation and analysis
14+
workflow using the Segment Anything Model (SAM).
15+
16+
This class handles:
17+
- Automatic downloading of SAM model checkpoints.
18+
- Image pre-processing, including normalization and optional contrast enhancement.
19+
- Running SAM with preset or custom parameters.
20+
- Advanced post-processing to filter masks by area and shape, and to remove duplicates.
21+
- Extraction of detailed properties for each detected particle.
22+
- Conversion of results to a pandas DataFrame and visualization of results.
23+
24+
Example:
25+
>>> # 1. Initialize the analyzer (downloads model if needed)
26+
>>> analyzer = ParticleAnalyzer(model_type="vit_h")
27+
>>>
28+
>>> # 2. Load image and run the analysis
29+
>>> image = np.load(path_to_your_image)
30+
>>> result = analyzer.analyze(image)
31+
>>>
32+
>>> # 3. Print summary and visualize results
33+
>>> print(f"Found {result['total_count']} particles.")
34+
>>> df = ParticleAnalyzer.particles_to_dataframe(result)
35+
>>> print(df.head())
36+
>>>
37+
>>> # This will generate and show a side-by-side plot
38+
>>> ParticleAnalyzer.visualize_particles(
39+
... result,
40+
... original_image_for_plot=image,
41+
... show_plot=True
42+
... )
43+
"""
44+
def __init__(self, checkpoint_path=None, model_type="vit_h", device="auto"):
45+
"""
46+
Initializes the ParticleAnalyzer by loading the SAM model.
47+
If the model checkpoint is not found, it will be downloaded automatically.
48+
49+
Args:
50+
checkpoint_path (str, optional): Path to the SAM model checkpoint file.
51+
If None, a default path will be used.
52+
model_type (str): The type of SAM model (e.g., "vit_h", "vit_l", "vit_b").
53+
device (str): The device to run the model on ("auto", "cuda", "cpu").
54+
"""
55+
print("Initializing Particle Analyzer...")
56+
self.device = self._get_device(device)
57+
58+
# Determine the final checkpoint path and download if necessary
59+
final_checkpoint_path = self._download_model_if_needed(checkpoint_path, model_type)
60+
61+
self.sam_model = self._load_model(final_checkpoint_path, model_type)
62+
print(f"SAM model loaded successfully on device: {self.device}")
63+
64+
def _get_device(self, device):
65+
"""Determines the appropriate device for PyTorch."""
66+
if device == "auto":
67+
return "cuda" if torch.cuda.is_available() else "cpu"
68+
return device
69+
70+
def _download_model_if_needed(self, checkpoint_path, model_type):
71+
"""Checks for the model checkpoint and downloads it if it doesn't exist."""
72+
model_urls = {
73+
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
74+
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
75+
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
76+
}
77+
78+
if checkpoint_path is None:
79+
# Create a default path if none is provided
80+
checkpoint_dir = "./checkpoints"
81+
os.makedirs(checkpoint_dir, exist_ok=True)
82+
checkpoint_path = os.path.join(checkpoint_dir, f"sam_{model_type}.pth")
83+
84+
if not os.path.exists(checkpoint_path):
85+
url = model_urls.get(model_type)
86+
if url is None:
87+
raise ValueError(f"Unknown model type: '{model_type}'. Cannot download.")
88+
89+
print(f"SAM model checkpoint not found at '{checkpoint_path}'.")
90+
print(f"Downloading model for '{model_type}' from {url}...")
91+
92+
urllib.request.urlretrieve(url, checkpoint_path)
93+
print(f"Download complete. Model saved to '{checkpoint_path}'.")
94+
95+
return checkpoint_path
96+
97+
def _load_model(self, checkpoint_path, model_type):
98+
"""Loads the SAM model from a checkpoint and moves it to the device."""
99+
try:
100+
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
101+
sam.to(device=self.device)
102+
return sam
103+
except Exception as e:
104+
print(f"Error loading SAM model from '{checkpoint_path}': {e}")
105+
raise
106+
107+
def analyze(self, image_array, params=None):
108+
"""
109+
Runs the full analysis pipeline on a given image using a set of parameters.
110+
111+
Args:
112+
image_array (np.array): The input 2D grayscale image.
113+
params (dict, optional): A dictionary of parameters controlling the analysis.
114+
If None, a set of default parameters will be used.
115+
"""
116+
# If no parameters are provided, use a default set for baseline analysis.
117+
if params is None:
118+
print("No parameters provided. Using default analysis settings.")
119+
params = {
120+
"use_clahe": False,
121+
"sam_parameters": "default",
122+
"min_area": 500,
123+
"max_area": 50000,
124+
"use_pruning": False,
125+
"pruning_iou_threshold": 0.5
126+
}
127+
128+
# 1. Pre-process the image
129+
processed_image = self._preprocess_image(image_array, params.get("use_clahe", False))
130+
image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_GRAY2RGB)
131+
132+
# 2. Generate masks with SAM using specified parameters
133+
all_masks = self._run_sam(image_rgb, params.get("sam_parameters", "default"))
134+
print(f"Generated {len(all_masks)} raw masks.")
135+
136+
# 3. Filter and prune masks
137+
final_masks_info = self._filter_and_prune(all_masks, params)
138+
print(f"Kept {len(final_masks_info)} masks after filtering and pruning.")
139+
140+
# 4. Extract properties from final masks
141+
particles = []
142+
for i, mask in enumerate(final_masks_info):
143+
particle_info = self._extract_particle_properties(mask, processed_image, i + 1)
144+
particles.append(particle_info)
145+
146+
# Sort by area for consistent ordering
147+
particles = sorted(particles, key=lambda x: x['area'], reverse=True)
148+
# Reassign IDs after sorting
149+
for i, particle in enumerate(particles):
150+
particle['id'] = i + 1
151+
152+
return {
153+
'particles': particles,
154+
'original_image': processed_image,
155+
'rgb_image': image_rgb,
156+
'total_count': len(particles)
157+
}
158+
159+
def _preprocess_image(self, image_array, use_clahe):
160+
"""Normalizes image to uint8 and optionally applies CLAHE."""
161+
# Normalize to uint8
162+
if image_array.dtype != np.uint8:
163+
if image_array.max() <= 1.0 and image_array.min() >= 0.0:
164+
image_array = (image_array * 255).astype(np.uint8)
165+
else:
166+
min_val, max_val = image_array.min(), image_array.max()
167+
if max_val > min_val:
168+
image_array = ((image_array - min_val) / (max_val - min_val) * 255).astype(np.uint8)
169+
else:
170+
image_array = np.zeros_like(image_array, dtype=np.uint8)
171+
172+
# Apply CLAHE if requested
173+
if use_clahe:
174+
print("Applying CLAHE...")
175+
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
176+
image_array = clahe.apply(image_array)
177+
178+
return image_array
179+
180+
def _run_sam(self, image_rgb, preset_name):
181+
"""Initializes and runs the SAM mask generator based on a preset."""
182+
sam_param_presets = {
183+
"default": {},
184+
"sensitive": {
185+
"points_per_side": 96,
186+
"pred_iou_thresh": 0.80,
187+
"stability_score_thresh": 0.85,
188+
},
189+
"ultra-permissive": {
190+
"points_per_side": 96,
191+
"pred_iou_thresh": 0.60,
192+
"stability_score_thresh": 0.70,
193+
}
194+
}
195+
196+
sam_params = sam_param_presets.get(preset_name, {})
197+
print(f"Running SAM with preset: '{preset_name}'")
198+
199+
mask_generator = SamAutomaticMaskGenerator(self.sam_model, **sam_params)
200+
return mask_generator.generate(image_rgb)
201+
202+
def _filter_and_prune(self, masks, params):
203+
"""Applies area filtering and optional shape-based pruning."""
204+
min_area = params.get("min_area", 0)
205+
max_area = params.get("max_area", float('inf'))
206+
207+
# Area filtering
208+
area_filtered_masks = [m for m in masks if min_area <= m['area'] <= max_area]
209+
210+
if params.get("use_pruning", False):
211+
print("Applying shape-based pruning...")
212+
iou_threshold = params.get("pruning_iou_threshold", 0.5)
213+
return self._prune_by_shape_and_iou(area_filtered_masks, iou_threshold)
214+
else:
215+
return area_filtered_masks
216+
217+
def _extract_particle_properties(self, mask, image, particle_id):
218+
"""Extracts detailed properties for a single particle mask."""
219+
binary_mask = mask['segmentation']
220+
area = mask['area']
221+
222+
y_coords, x_coords = np.where(binary_mask)
223+
centroid = (np.mean(x_coords), np.mean(y_coords))
224+
225+
particle_pixels = image[binary_mask]
226+
227+
perimeter = self._calculate_perimeter(binary_mask)
228+
229+
return {
230+
'id': particle_id,
231+
'area': area,
232+
'centroid': centroid,
233+
'bbox': mask['bbox'],
234+
'mean_intensity': np.mean(particle_pixels),
235+
'std_intensity': np.std(particle_pixels),
236+
'min_intensity': np.min(particle_pixels),
237+
'max_intensity': np.max(particle_pixels),
238+
'perimeter': perimeter,
239+
'circularity': 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0,
240+
'equiv_diameter': 2 * np.sqrt(area / np.pi),
241+
'aspect_ratio': mask['bbox'][3] / mask['bbox'][2] if mask['bbox'][2] > 0 else 1,
242+
'solidity': mask.get('solidity', self._calculate_solidity(mask)), # Use pre-calculated solidity if available
243+
'mask': binary_mask
244+
}
245+
246+
def _calculate_perimeter(self, binary_mask):
247+
contours, _ = cv2.findContours(binary_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
248+
return cv2.arcLength(contours[0], True) if contours else 0
249+
250+
def _calculate_solidity(self, mask):
251+
binary_mask = mask['segmentation'].astype(np.uint8)
252+
contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
253+
if not contours: return 0
254+
cnt = contours[0]
255+
area = cv2.contourArea(cnt)
256+
hull = cv2.convexHull(cnt)
257+
hull_area = cv2.contourArea(hull)
258+
return area / hull_area if hull_area > 0 else 0
259+
260+
def _calculate_iou(self, mask1, mask2):
261+
bbox1, bbox2 = mask1['bbox'], mask2['bbox']
262+
x_left = max(bbox1[0], bbox2[0])
263+
y_top = max(bbox1[1], bbox2[1])
264+
x_right = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2])
265+
y_bottom = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3])
266+
if x_right < x_left or y_bottom < y_top: return 0.0
267+
intersection_area = (x_right - x_left) * (y_bottom - y_top)
268+
area1, area2 = bbox1[2] * bbox1[3], bbox2[2] * bbox2[3]
269+
union_area = area1 + area2 - intersection_area
270+
return intersection_area / union_area if union_area > 0 else 0.0
271+
272+
def _prune_by_shape_and_iou(self, masks, iou_threshold):
273+
"""Prunes masks based on a goodness score and IoU."""
274+
if not masks: return []
275+
276+
for m in masks:
277+
m['solidity'] = self._calculate_solidity(m)
278+
m['score'] = m['area'] * (m['solidity'] ** 2)
279+
280+
sorted_masks = sorted(masks, key=lambda x: x['score'], reverse=True)
281+
282+
pruned_masks = []
283+
for mask in sorted_masks:
284+
is_duplicate = any(self._calculate_iou(mask, kept_mask) > iou_threshold for kept_mask in pruned_masks)
285+
if not is_duplicate:
286+
pruned_masks.append(mask)
287+
return pruned_masks
288+
289+
@staticmethod
290+
def particles_to_dataframe(result):
291+
"""Converts the 'particles' list from the result into a pandas DataFrame."""
292+
particles = result.get('particles', [])
293+
if not particles: return pd.DataFrame()
294+
295+
data = []
296+
for p in particles:
297+
row = {k: v for k, v in p.items() if k != 'mask'}
298+
row['centroid_x'], row['centroid_y'] = p['centroid']
299+
row['bbox_x'], row['bbox_y'], row['bbox_width'], row['bbox_height'] = p['bbox']
300+
del row['centroid'], row['bbox']
301+
data.append(row)
302+
return pd.DataFrame(data)
303+
304+
@staticmethod
305+
def visualize_particles(result, original_image_for_plot=None, show_plot=False, show_labels=True, show_centroids=True):
306+
"""
307+
Creates an RGB image visualizing the detected particles and optionally displays a plot.
308+
309+
Args:
310+
result (dict): The output dictionary from the analyze method.
311+
original_image_for_plot (np.array, optional): The raw, unprocessed image for side-by-side comparison.
312+
If None, the processed image from the result is used.
313+
show_plot (bool): If True, displays a matplotlib plot comparing original and segmented images.
314+
show_labels (bool): If True, shows particle ID labels on the overlay.
315+
show_centroids (bool): If True, shows particle centroids on the overlay.
316+
317+
Returns:
318+
np.array: The RGB overlay image with particles drawn on it.
319+
"""
320+
overlay = result['rgb_image'].copy()
321+
for particle in result.get('particles', []):
322+
contours, _ = cv2.findContours(particle['mask'].astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
323+
cv2.drawContours(overlay, contours, -1, (255, 0, 0), 2)
324+
325+
cx, cy = int(particle['centroid'][0]), int(particle['centroid'][1])
326+
if show_centroids:
327+
cv2.circle(overlay, (cx, cy), 5, (0, 255, 0), -1)
328+
if show_labels:
329+
cv2.putText(overlay, str(particle['id']), (cx + 5, cy + 5),
330+
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2)
331+
332+
if show_plot:
333+
fig, axes = plt.subplots(1, 2, figsize=(16, 8))
334+
335+
# Use the provided original image for the 'before' plot, otherwise use the processed one from results
336+
display_image = original_image_for_plot if original_image_for_plot is not None else result['original_image']
337+
338+
axes[0].imshow(display_image, cmap='gray')
339+
axes[0].set_title('Original Input')
340+
axes[1].imshow(overlay)
341+
axes[1].set_title(f"Detected Particles (n={result['total_count']})")
342+
for ax in axes:
343+
ax.set_axis_off()
344+
plt.tight_layout()
345+
plt.show()
346+
347+
return overlay

docs/source/atomai_models.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,22 @@ ImSpec
1717
:member-order: bysource
1818
:show-inheritance:
1919

20+
ParticleAnalyzer
21+
----------------
22+
.. autoclass:: atomai.models.ParticleAnalyzer
23+
:members:
24+
:undoc-members:
25+
:member-order: bysource
26+
:show-inheritance:
27+
28+
Denoiser
29+
--------
30+
.. autoclass:: atomai.models.DenoisingAutoencoder
31+
:members:
32+
:undoc-members:
33+
:member-order: bysource
34+
:show-inheritance:
35+
2036
Variational Autoencoder (VAE)
2137
-----------------------------
2238
.. autoclass:: atomai.models.VAE

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Welcome to AtomAI's documentation!
2121
atomai_models
2222
trainers_predictors
2323
nets
24+
statistics
2425
losses_metrics
2526
other_utilities
2627

0 commit comments

Comments
 (0)