Skip to content

Commit 71c124d

Browse files
mattt and rohan-mehta authored
Add prediction field to ModelError (#326)
This PR extends #325 to add the prediction object itself to `ModelError`, as opposed to just its ID. This makes it convenient to introspect logs and other information to determine how to handle the failure.

```python
import replicate
from replicate.exceptions import ModelError

try:
    output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "..." })
except ModelError as e:
    if "(some known issue)" in e.prediction.logs:
        pass

    print("Failed prediction: " + e.prediction.id)
```

---------

Signed-off-by: Rohan Mehta <[email protected]>
Signed-off-by: Mattt Zmuda <[email protected]>
Co-authored-by: Rohan Mehta <[email protected]>
1 parent ecfedfb commit 71c124d

File tree

5 files changed

+102
-6
lines changed

5 files changed

+102
-6
lines changed

README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,24 @@ or a handle to a file on your local device.
8282
"an astronaut riding a horse"
8383
```
8484
85+
`replicate.run` raises `ModelError` if the prediction fails.
86+
You can access the exception's `prediction` property
87+
to get more information about the failure.
88+
89+
```python
90+
import replicate
91+
from replicate.exceptions import ModelError
92+
93+
try:
94+
output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "An astronaut riding a rainbow unicorn" })
95+
except ModelError as e:
96+
if "(some known issue)" in e.prediction.logs:
97+
pass
98+
99+
print("Failed prediction: " + e.prediction.id)
100+
```
101+
102+
85103
## Run a model and stream its output
86104

87105
Replicate’s API supports server-sent event streams (SSEs) for language models.

replicate/exceptions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
from typing import Optional
1+
from typing import TYPE_CHECKING, Optional
22

33
import httpx
44

5+
if TYPE_CHECKING:
6+
from replicate.prediction import Prediction
7+
58

69
class ReplicateException(Exception):
710
"""A base class for all Replicate exceptions."""
@@ -10,6 +13,12 @@ class ReplicateException(Exception):
1013
class ModelError(ReplicateException):
1114
"""An error from user's code in a model."""
1215

16+
prediction: "Prediction"
17+
18+
def __init__(self, prediction: "Prediction") -> None:
19+
self.prediction = prediction
20+
super().__init__(prediction.error)
21+
1322

1423
class ReplicateError(ReplicateException):
1524
"""

replicate/prediction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def output_iterator(self) -> Iterator[Any]:
249249
self.reload()
250250

251251
if self.status == "failed":
252-
raise ModelError(self.error)
252+
raise ModelError(self)
253253

254254
output = self.output or []
255255
new_output = output[len(previous_output) :]
@@ -272,7 +272,7 @@ async def async_output_iterator(self) -> AsyncIterator[Any]:
272272
await self.async_reload()
273273

274274
if self.status == "failed":
275-
raise ModelError(self.error)
275+
raise ModelError(self)
276276

277277
output = self.output or []
278278
new_output = output[len(previous_output) :]

replicate/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def run(
5858
prediction.wait()
5959

6060
if prediction.status == "failed":
61-
raise ModelError(prediction.error)
61+
raise ModelError(prediction)
6262

6363
return prediction.output
6464

@@ -97,7 +97,7 @@ async def async_run(
9797
await prediction.async_wait()
9898

9999
if prediction.status == "failed":
100-
raise ModelError(prediction.error)
100+
raise ModelError(prediction)
101101

102102
return prediction.output
103103

tests/test_run.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import replicate
99
from replicate.client import Client
10-
from replicate.exceptions import ReplicateError
10+
from replicate.exceptions import ModelError, ReplicateError
1111

1212

1313
@pytest.mark.vcr("run.yaml")
@@ -184,3 +184,72 @@ def prediction_with_status(status: str) -> dict:
184184
)
185185

186186
assert output == "Hello, world!"
187+
188+
189+
@pytest.mark.asyncio
190+
async def test_run_with_model_error(mock_replicate_api_token):
191+
def prediction_with_status(status: str) -> dict:
192+
return {
193+
"id": "p1",
194+
"model": "test/example",
195+
"version": "v1",
196+
"urls": {
197+
"get": "https://api.replicate.com/v1/predictions/p1",
198+
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
199+
},
200+
"created_at": "2023-10-05T12:00:00.000000Z",
201+
"source": "api",
202+
"status": status,
203+
"input": {"text": "world"},
204+
"output": None,
205+
"error": "OOM" if status == "failed" else None,
206+
"logs": "",
207+
}
208+
209+
router = respx.Router(base_url="https://api.replicate.com/v1")
210+
router.route(method="POST", path="/predictions").mock(
211+
return_value=httpx.Response(
212+
201,
213+
json=prediction_with_status("processing"),
214+
)
215+
)
216+
router.route(method="GET", path="/predictions/p1").mock(
217+
return_value=httpx.Response(
218+
200,
219+
json=prediction_with_status("failed"),
220+
)
221+
)
222+
router.route(
223+
method="GET",
224+
path="/models/test/example/versions/v1",
225+
).mock(
226+
return_value=httpx.Response(
227+
201,
228+
json={
229+
"id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
230+
"created_at": "2024-07-18T00:35:56.210272Z",
231+
"cog_version": "0.9.10",
232+
"openapi_schema": {
233+
"openapi": "3.0.2",
234+
},
235+
},
236+
)
237+
)
238+
router.route(host="api.replicate.com").pass_through()
239+
240+
client = Client(
241+
api_token="test-token", transport=httpx.MockTransport(router.handler)
242+
)
243+
client.poll_interval = 0.001
244+
245+
with pytest.raises(ModelError) as excinfo:
246+
client.run(
247+
"test/example:v1",
248+
input={
249+
"text": "Hello, world!",
250+
},
251+
)
252+
253+
assert str(excinfo.value) == "OOM"
254+
assert excinfo.value.prediction.error == "OOM"
255+
assert excinfo.value.prediction.status == "failed"

0 commit comments

Comments
 (0)