Skip to content

Commit c129fea

Browse files
authored
Merge pull request #350 from Anirudh2112/fix/predict-numpy-images
fix: allow model.predict to handle numpy array inputs
2 parents c570d22 + 0d51bd8 commit c129fea

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

roboflow/models/inference.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,26 @@ def __get_image_params(self, image_path):
6262
Get parameters about an image (i.e. dimensions) for use in an inference request.
6363
6464
Args:
65-
image_path (str): path to the image you'd like to perform prediction on
65+
image_path (Union[str, np.ndarray]): path to image or numpy array
6666
6767
Returns:
6868
Tuple containing a dict of querystring params and a dict of requests kwargs
6969
7070
Raises:
7171
Exception: Image path is not valid
7272
"""
73+
import numpy as np
74+
75+
if isinstance(image_path, np.ndarray):
76+
# Convert numpy array to PIL Image
77+
image = Image.fromarray(image_path)
78+
dimensions = image.size
79+
image_dims = {"width": str(dimensions[0]), "height": str(dimensions[1])}
80+
buffered = io.BytesIO()
81+
image.save(buffered, quality=90, format="JPEG")
82+
data = MultipartEncoder(fields={"file": ("imageToUpload", buffered.getvalue(), "image/jpeg")})
83+
return {}, {"data": data, "headers": {"Content-Type": data.content_type}}, image_dims
84+
7385
validate_image_path(image_path)
7486

7587
hosted_image = urllib.parse.urlparse(image_path).scheme in ("http", "https")

tests/models/test_instance_segmentation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,23 @@ def test_predict_with_non_200_response_raises_http_error(self):
142142

143143
with self.assertRaises(HTTPError):
144144
instance.predict(image_path)
145+
146+
@responses.activate
147+
def test_predict_with_numpy_array(self):
148+
# Create a simple numpy array image
149+
import numpy as np
150+
151+
image_array = np.zeros((100, 100, 3), dtype=np.uint8) # Create a black image
152+
image_array[30:70, 30:70] = 255 # Add a white square
153+
154+
instance = InstanceSegmentationModel(self.api_key, self.version_id)
155+
156+
responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)
157+
group = instance.predict(image_array)
158+
self.assertIsInstance(group, PredictionGroup)
159+
160+
request = responses.calls[0].request
161+
self.assertEqual(request.method, "POST")
162+
self.assertRegex(request.url, rf"^{self.api_url}")
163+
self.assertDictEqual(request.params, self._default_params)
164+
self.assertIsNotNone(request.body)

0 commit comments

Comments
 (0)