|
| 1 | +import numpy as np |
| 2 | +import cv2 |
| 3 | +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry |
| 4 | +import pandas as pd |
| 5 | +import matplotlib.pyplot as plt |
| 6 | +import torch |
| 7 | +import os |
| 8 | +import urllib.request |
| 9 | + |
| 10 | + |
| 11 | +class ParticleAnalyzer: |
| 12 | + """ |
| 13 | + A class to encapsulate an end-to-end particle segmentation and analysis |
| 14 | + workflow using the Segment Anything Model (SAM). |
| 15 | +
|
| 16 | + This class handles: |
| 17 | + - Automatic downloading of SAM model checkpoints. |
| 18 | + - Image pre-processing, including normalization and optional contrast enhancement. |
| 19 | + - Running SAM with preset or custom parameters. |
| 20 | + - Advanced post-processing to filter masks by area and shape, and to remove duplicates. |
| 21 | + - Extraction of detailed properties for each detected particle. |
| 22 | + - Conversion of results to a pandas DataFrame and visualization of results. |
| 23 | +
|
| 24 | + Example: |
| 25 | + >>> # 1. Initialize the analyzer (downloads model if needed) |
| 26 | + >>> analyzer = ParticleAnalyzer(model_type="vit_h") |
| 27 | + >>> |
| 28 | + >>> # 2. Load image and run the analysis |
| 29 | + >>> image = np.load(path_to_your_image) |
| 30 | + >>> result = analyzer.analyze(image) |
| 31 | + >>> |
| 32 | + >>> # 3. Print summary and visualize results |
| 33 | + >>> print(f"Found {result['total_count']} particles.") |
| 34 | + >>> df = ParticleAnalyzer.particles_to_dataframe(result) |
| 35 | + >>> print(df.head()) |
| 36 | + >>> |
| 37 | + >>> # This will generate and show a side-by-side plot |
| 38 | + >>> ParticleAnalyzer.visualize_particles( |
| 39 | + ... result, |
| 40 | + ... original_image_for_plot=image, |
| 41 | + ... show_plot=True |
| 42 | + ... ) |
| 43 | + """ |
| 44 | + def __init__(self, checkpoint_path=None, model_type="vit_h", device="auto"): |
| 45 | + """ |
| 46 | + Initializes the ParticleAnalyzer by loading the SAM model. |
| 47 | + If the model checkpoint is not found, it will be downloaded automatically. |
| 48 | + |
| 49 | + Args: |
| 50 | + checkpoint_path (str, optional): Path to the SAM model checkpoint file. |
| 51 | + If None, a default path will be used. |
| 52 | + model_type (str): The type of SAM model (e.g., "vit_h", "vit_l", "vit_b"). |
| 53 | + device (str): The device to run the model on ("auto", "cuda", "cpu"). |
| 54 | + """ |
| 55 | + print("Initializing Particle Analyzer...") |
| 56 | + self.device = self._get_device(device) |
| 57 | + |
| 58 | + # Determine the final checkpoint path and download if necessary |
| 59 | + final_checkpoint_path = self._download_model_if_needed(checkpoint_path, model_type) |
| 60 | + |
| 61 | + self.sam_model = self._load_model(final_checkpoint_path, model_type) |
| 62 | + print(f"SAM model loaded successfully on device: {self.device}") |
| 63 | + |
| 64 | + def _get_device(self, device): |
| 65 | + """Determines the appropriate device for PyTorch.""" |
| 66 | + if device == "auto": |
| 67 | + return "cuda" if torch.cuda.is_available() else "cpu" |
| 68 | + return device |
| 69 | + |
| 70 | + def _download_model_if_needed(self, checkpoint_path, model_type): |
| 71 | + """Checks for the model checkpoint and downloads it if it doesn't exist.""" |
| 72 | + model_urls = { |
| 73 | + "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", |
| 74 | + "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", |
| 75 | + "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" |
| 76 | + } |
| 77 | + |
| 78 | + if checkpoint_path is None: |
| 79 | + # Create a default path if none is provided |
| 80 | + checkpoint_dir = "./checkpoints" |
| 81 | + os.makedirs(checkpoint_dir, exist_ok=True) |
| 82 | + checkpoint_path = os.path.join(checkpoint_dir, f"sam_{model_type}.pth") |
| 83 | + |
| 84 | + if not os.path.exists(checkpoint_path): |
| 85 | + url = model_urls.get(model_type) |
| 86 | + if url is None: |
| 87 | + raise ValueError(f"Unknown model type: '{model_type}'. Cannot download.") |
| 88 | + |
| 89 | + print(f"SAM model checkpoint not found at '{checkpoint_path}'.") |
| 90 | + print(f"Downloading model for '{model_type}' from {url}...") |
| 91 | + |
| 92 | + urllib.request.urlretrieve(url, checkpoint_path) |
| 93 | + print(f"Download complete. Model saved to '{checkpoint_path}'.") |
| 94 | + |
| 95 | + return checkpoint_path |
| 96 | + |
| 97 | + def _load_model(self, checkpoint_path, model_type): |
| 98 | + """Loads the SAM model from a checkpoint and moves it to the device.""" |
| 99 | + try: |
| 100 | + sam = sam_model_registry[model_type](checkpoint=checkpoint_path) |
| 101 | + sam.to(device=self.device) |
| 102 | + return sam |
| 103 | + except Exception as e: |
| 104 | + print(f"Error loading SAM model from '{checkpoint_path}': {e}") |
| 105 | + raise |
| 106 | + |
| 107 | + def analyze(self, image_array, params=None): |
| 108 | + """ |
| 109 | + Runs the full analysis pipeline on a given image using a set of parameters. |
| 110 | +
|
| 111 | + Args: |
| 112 | + image_array (np.array): The input 2D grayscale image. |
| 113 | + params (dict, optional): A dictionary of parameters controlling the analysis. |
| 114 | + If None, a set of default parameters will be used. |
| 115 | + """ |
| 116 | + # If no parameters are provided, use a default set for baseline analysis. |
| 117 | + if params is None: |
| 118 | + print("No parameters provided. Using default analysis settings.") |
| 119 | + params = { |
| 120 | + "use_clahe": False, |
| 121 | + "sam_parameters": "default", |
| 122 | + "min_area": 500, |
| 123 | + "max_area": 50000, |
| 124 | + "use_pruning": False, |
| 125 | + "pruning_iou_threshold": 0.5 |
| 126 | + } |
| 127 | + |
| 128 | + # 1. Pre-process the image |
| 129 | + processed_image = self._preprocess_image(image_array, params.get("use_clahe", False)) |
| 130 | + image_rgb = cv2.cvtColor(processed_image, cv2.COLOR_GRAY2RGB) |
| 131 | + |
| 132 | + # 2. Generate masks with SAM using specified parameters |
| 133 | + all_masks = self._run_sam(image_rgb, params.get("sam_parameters", "default")) |
| 134 | + print(f"Generated {len(all_masks)} raw masks.") |
| 135 | + |
| 136 | + # 3. Filter and prune masks |
| 137 | + final_masks_info = self._filter_and_prune(all_masks, params) |
| 138 | + print(f"Kept {len(final_masks_info)} masks after filtering and pruning.") |
| 139 | + |
| 140 | + # 4. Extract properties from final masks |
| 141 | + particles = [] |
| 142 | + for i, mask in enumerate(final_masks_info): |
| 143 | + particle_info = self._extract_particle_properties(mask, processed_image, i + 1) |
| 144 | + particles.append(particle_info) |
| 145 | + |
| 146 | + # Sort by area for consistent ordering |
| 147 | + particles = sorted(particles, key=lambda x: x['area'], reverse=True) |
| 148 | + # Reassign IDs after sorting |
| 149 | + for i, particle in enumerate(particles): |
| 150 | + particle['id'] = i + 1 |
| 151 | + |
| 152 | + return { |
| 153 | + 'particles': particles, |
| 154 | + 'original_image': processed_image, |
| 155 | + 'rgb_image': image_rgb, |
| 156 | + 'total_count': len(particles) |
| 157 | + } |
| 158 | + |
| 159 | + def _preprocess_image(self, image_array, use_clahe): |
| 160 | + """Normalizes image to uint8 and optionally applies CLAHE.""" |
| 161 | + # Normalize to uint8 |
| 162 | + if image_array.dtype != np.uint8: |
| 163 | + if image_array.max() <= 1.0 and image_array.min() >= 0.0: |
| 164 | + image_array = (image_array * 255).astype(np.uint8) |
| 165 | + else: |
| 166 | + min_val, max_val = image_array.min(), image_array.max() |
| 167 | + if max_val > min_val: |
| 168 | + image_array = ((image_array - min_val) / (max_val - min_val) * 255).astype(np.uint8) |
| 169 | + else: |
| 170 | + image_array = np.zeros_like(image_array, dtype=np.uint8) |
| 171 | + |
| 172 | + # Apply CLAHE if requested |
| 173 | + if use_clahe: |
| 174 | + print("Applying CLAHE...") |
| 175 | + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) |
| 176 | + image_array = clahe.apply(image_array) |
| 177 | + |
| 178 | + return image_array |
| 179 | + |
| 180 | + def _run_sam(self, image_rgb, preset_name): |
| 181 | + """Initializes and runs the SAM mask generator based on a preset.""" |
| 182 | + sam_param_presets = { |
| 183 | + "default": {}, |
| 184 | + "sensitive": { |
| 185 | + "points_per_side": 96, |
| 186 | + "pred_iou_thresh": 0.80, |
| 187 | + "stability_score_thresh": 0.85, |
| 188 | + }, |
| 189 | + "ultra-permissive": { |
| 190 | + "points_per_side": 96, |
| 191 | + "pred_iou_thresh": 0.60, |
| 192 | + "stability_score_thresh": 0.70, |
| 193 | + } |
| 194 | + } |
| 195 | + |
| 196 | + sam_params = sam_param_presets.get(preset_name, {}) |
| 197 | + print(f"Running SAM with preset: '{preset_name}'") |
| 198 | + |
| 199 | + mask_generator = SamAutomaticMaskGenerator(self.sam_model, **sam_params) |
| 200 | + return mask_generator.generate(image_rgb) |
| 201 | + |
| 202 | + def _filter_and_prune(self, masks, params): |
| 203 | + """Applies area filtering and optional shape-based pruning.""" |
| 204 | + min_area = params.get("min_area", 0) |
| 205 | + max_area = params.get("max_area", float('inf')) |
| 206 | + |
| 207 | + # Area filtering |
| 208 | + area_filtered_masks = [m for m in masks if min_area <= m['area'] <= max_area] |
| 209 | + |
| 210 | + if params.get("use_pruning", False): |
| 211 | + print("Applying shape-based pruning...") |
| 212 | + iou_threshold = params.get("pruning_iou_threshold", 0.5) |
| 213 | + return self._prune_by_shape_and_iou(area_filtered_masks, iou_threshold) |
| 214 | + else: |
| 215 | + return area_filtered_masks |
| 216 | + |
| 217 | + def _extract_particle_properties(self, mask, image, particle_id): |
| 218 | + """Extracts detailed properties for a single particle mask.""" |
| 219 | + binary_mask = mask['segmentation'] |
| 220 | + area = mask['area'] |
| 221 | + |
| 222 | + y_coords, x_coords = np.where(binary_mask) |
| 223 | + centroid = (np.mean(x_coords), np.mean(y_coords)) |
| 224 | + |
| 225 | + particle_pixels = image[binary_mask] |
| 226 | + |
| 227 | + perimeter = self._calculate_perimeter(binary_mask) |
| 228 | + |
| 229 | + return { |
| 230 | + 'id': particle_id, |
| 231 | + 'area': area, |
| 232 | + 'centroid': centroid, |
| 233 | + 'bbox': mask['bbox'], |
| 234 | + 'mean_intensity': np.mean(particle_pixels), |
| 235 | + 'std_intensity': np.std(particle_pixels), |
| 236 | + 'min_intensity': np.min(particle_pixels), |
| 237 | + 'max_intensity': np.max(particle_pixels), |
| 238 | + 'perimeter': perimeter, |
| 239 | + 'circularity': 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0, |
| 240 | + 'equiv_diameter': 2 * np.sqrt(area / np.pi), |
| 241 | + 'aspect_ratio': mask['bbox'][3] / mask['bbox'][2] if mask['bbox'][2] > 0 else 1, |
| 242 | + 'solidity': mask.get('solidity', self._calculate_solidity(mask)), # Use pre-calculated solidity if available |
| 243 | + 'mask': binary_mask |
| 244 | + } |
| 245 | + |
| 246 | + def _calculate_perimeter(self, binary_mask): |
| 247 | + contours, _ = cv2.findContours(binary_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |
| 248 | + return cv2.arcLength(contours[0], True) if contours else 0 |
| 249 | + |
| 250 | + def _calculate_solidity(self, mask): |
| 251 | + binary_mask = mask['segmentation'].astype(np.uint8) |
| 252 | + contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |
| 253 | + if not contours: return 0 |
| 254 | + cnt = contours[0] |
| 255 | + area = cv2.contourArea(cnt) |
| 256 | + hull = cv2.convexHull(cnt) |
| 257 | + hull_area = cv2.contourArea(hull) |
| 258 | + return area / hull_area if hull_area > 0 else 0 |
| 259 | + |
| 260 | + def _calculate_iou(self, mask1, mask2): |
| 261 | + bbox1, bbox2 = mask1['bbox'], mask2['bbox'] |
| 262 | + x_left = max(bbox1[0], bbox2[0]) |
| 263 | + y_top = max(bbox1[1], bbox2[1]) |
| 264 | + x_right = min(bbox1[0] + bbox1[2], bbox2[0] + bbox2[2]) |
| 265 | + y_bottom = min(bbox1[1] + bbox1[3], bbox2[1] + bbox2[3]) |
| 266 | + if x_right < x_left or y_bottom < y_top: return 0.0 |
| 267 | + intersection_area = (x_right - x_left) * (y_bottom - y_top) |
| 268 | + area1, area2 = bbox1[2] * bbox1[3], bbox2[2] * bbox2[3] |
| 269 | + union_area = area1 + area2 - intersection_area |
| 270 | + return intersection_area / union_area if union_area > 0 else 0.0 |
| 271 | + |
| 272 | + def _prune_by_shape_and_iou(self, masks, iou_threshold): |
| 273 | + """Prunes masks based on a goodness score and IoU.""" |
| 274 | + if not masks: return [] |
| 275 | + |
| 276 | + for m in masks: |
| 277 | + m['solidity'] = self._calculate_solidity(m) |
| 278 | + m['score'] = m['area'] * (m['solidity'] ** 2) |
| 279 | + |
| 280 | + sorted_masks = sorted(masks, key=lambda x: x['score'], reverse=True) |
| 281 | + |
| 282 | + pruned_masks = [] |
| 283 | + for mask in sorted_masks: |
| 284 | + is_duplicate = any(self._calculate_iou(mask, kept_mask) > iou_threshold for kept_mask in pruned_masks) |
| 285 | + if not is_duplicate: |
| 286 | + pruned_masks.append(mask) |
| 287 | + return pruned_masks |
| 288 | + |
| 289 | + @staticmethod |
| 290 | + def particles_to_dataframe(result): |
| 291 | + """Converts the 'particles' list from the result into a pandas DataFrame.""" |
| 292 | + particles = result.get('particles', []) |
| 293 | + if not particles: return pd.DataFrame() |
| 294 | + |
| 295 | + data = [] |
| 296 | + for p in particles: |
| 297 | + row = {k: v for k, v in p.items() if k != 'mask'} |
| 298 | + row['centroid_x'], row['centroid_y'] = p['centroid'] |
| 299 | + row['bbox_x'], row['bbox_y'], row['bbox_width'], row['bbox_height'] = p['bbox'] |
| 300 | + del row['centroid'], row['bbox'] |
| 301 | + data.append(row) |
| 302 | + return pd.DataFrame(data) |
| 303 | + |
| 304 | + @staticmethod |
| 305 | + def visualize_particles(result, original_image_for_plot=None, show_plot=False, show_labels=True, show_centroids=True): |
| 306 | + """ |
| 307 | + Creates an RGB image visualizing the detected particles and optionally displays a plot. |
| 308 | + |
| 309 | + Args: |
| 310 | + result (dict): The output dictionary from the analyze method. |
| 311 | + original_image_for_plot (np.array, optional): The raw, unprocessed image for side-by-side comparison. |
| 312 | + If None, the processed image from the result is used. |
| 313 | + show_plot (bool): If True, displays a matplotlib plot comparing original and segmented images. |
| 314 | + show_labels (bool): If True, shows particle ID labels on the overlay. |
| 315 | + show_centroids (bool): If True, shows particle centroids on the overlay. |
| 316 | + |
| 317 | + Returns: |
| 318 | + np.array: The RGB overlay image with particles drawn on it. |
| 319 | + """ |
| 320 | + overlay = result['rgb_image'].copy() |
| 321 | + for particle in result.get('particles', []): |
| 322 | + contours, _ = cv2.findContours(particle['mask'].astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |
| 323 | + cv2.drawContours(overlay, contours, -1, (255, 0, 0), 2) |
| 324 | + |
| 325 | + cx, cy = int(particle['centroid'][0]), int(particle['centroid'][1]) |
| 326 | + if show_centroids: |
| 327 | + cv2.circle(overlay, (cx, cy), 5, (0, 255, 0), -1) |
| 328 | + if show_labels: |
| 329 | + cv2.putText(overlay, str(particle['id']), (cx + 5, cy + 5), |
| 330 | + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 0), 2) |
| 331 | + |
| 332 | + if show_plot: |
| 333 | + fig, axes = plt.subplots(1, 2, figsize=(16, 8)) |
| 334 | + |
| 335 | + # Use the provided original image for the 'before' plot, otherwise use the processed one from results |
| 336 | + display_image = original_image_for_plot if original_image_for_plot is not None else result['original_image'] |
| 337 | + |
| 338 | + axes[0].imshow(display_image, cmap='gray') |
| 339 | + axes[0].set_title('Original Input') |
| 340 | + axes[1].imshow(overlay) |
| 341 | + axes[1].set_title(f"Detected Particles (n={result['total_count']})") |
| 342 | + for ax in axes: |
| 343 | + ax.set_axis_off() |
| 344 | + plt.tight_layout() |
| 345 | + plt.show() |
| 346 | + |
| 347 | + return overlay |
0 commit comments