Skip to content

Commit 93c8881

Browse files
authored
Merge pull request #97 from ziatdinovmax/master
Sync up
2 parents 0ebe187 + bc1297e commit 93c8881

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

atomai/models/sam.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22
import cv2
3-
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
43
import pandas as pd
54
import matplotlib.pyplot as plt
65
import torch
@@ -26,7 +25,7 @@ class ParticleAnalyzer:
2625
>>> analyzer = ParticleAnalyzer(model_type="vit_h")
2726
>>>
2827
>>> # 2. Load image and run the analysis
29-
>>> image = np.load(path_to_your_image)
28+
>>> image = np.load(IMAGE_PATH)
3029
>>> result = analyzer.analyze(image)
3130
>>>
3231
>>> # 3. Print summary and visualize results
@@ -96,6 +95,15 @@ def _download_model_if_needed(self, checkpoint_path, model_type):
9695

9796
def _load_model(self, checkpoint_path, model_type):
9897
"""Loads the SAM model from a checkpoint and moves it to the device."""
98+
try:
99+
from segment_anything import sam_model_registry
100+
except ImportError:
101+
raise ImportError(
102+
"The 'segment-anything' package is required to use this feature.\n"
103+
"Please install it directly from the official repository:\n\n"
104+
"pip install git+https://github.com/facebookresearch/segment-anything.git"
105+
)
106+
99107
try:
100108
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
101109
sam.to(device=self.device)
@@ -179,6 +187,15 @@ def _preprocess_image(self, image_array, use_clahe):
179187

180188
def _run_sam(self, image_rgb, preset_name):
181189
"""Initializes and runs the SAM mask generator based on a preset."""
190+
try:
191+
from segment_anything import SamAutomaticMaskGenerator
192+
except ImportError:
193+
raise ImportError(
194+
"The 'segment-anything' package is required to use this feature.\n"
195+
"Please install it directly from the official repository:\n\n"
196+
"pip install git+https://github.com/facebookresearch/segment-anything.git"
197+
)
198+
182199
sam_param_presets = {
183200
"default": {},
184201
"sensitive": {

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@
3838
'torchvision>=0.13.0',
3939
'progressbar2>=3.38.0',
4040
'gpytorch>=1.9.1',
41-
'pandas>=1.1.5',
42-
'segment-anything @ git+https://github.com/facebookresearch/segment-anything.git'
41+
'pandas>=1.1.5'
4342
],
4443
classifiers=['Programming Language :: Python',
4544
'Development Status :: 3 - Alpha',

0 commit comments

Comments
 (0)