Skip to content

Commit 4dbb4ec

Browse files
committed
Switch to using file outputs and blocking api by default
1 parent 5fe2d54 commit 4dbb4ec

File tree

3 files changed

+22
-11
lines changed

3 files changed

+22
-11
lines changed

replicate/prediction.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,13 @@ class CreatePredictionParams(TypedDict):
395395

396396
wait: NotRequired[Union[int, bool]]
397397
"""
398-
Wait until the prediction is completed before returning.
398+
Block until the prediction is completed before returning.
399399
400-
If `True`, wait a predetermined number of seconds until the prediction
401-
is completed before returning.
402-
If an `int`, wait for the specified number of seconds.
400+
If `True`, keep the request open for up to 60 seconds, falling back to
401+
polling until the prediction is completed.
402+
If an `int`, same as `True` but hold the request for a specified number of
403+
seconds (between 1 and 60).
404+
If `False`, poll for the prediction status until completed.
403405
"""
404406

405407
file_encoding_strategy: NotRequired[FileEncodingStrategy]

replicate/run.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,18 @@ def run(
2929
client: "Client",
3030
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
3131
input: Optional[Dict[str, Any]] = None,
32-
use_file_output: Optional[bool] = None,
32+
*,
33+
use_file_output: Optional[bool] = True,
3334
**params: Unpack["Predictions.CreatePredictionParams"],
3435
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
3536
"""
3637
Run a model and wait for its output.
3738
"""
3839

39-
is_blocking = "wait" in params
40+
if "wait" not in params:
41+
params["wait"] = True
42+
is_blocking = params["wait"] != False # noqa: E712
43+
4044
version, owner, name, version_id = identifier._resolve(ref)
4145

4246
if version_id is not None:
@@ -74,13 +78,18 @@ async def async_run(
7478
client: "Client",
7579
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
7680
input: Optional[Dict[str, Any]] = None,
77-
use_file_output: Optional[bool] = None,
81+
*,
82+
use_file_output: Optional[bool] = True,
7883
**params: Unpack["Predictions.CreatePredictionParams"],
7984
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
8085
"""
8186
Run a model and wait for its output asynchronously.
8287
"""
8388

89+
if "wait" not in params:
90+
params["wait"] = True
91+
is_blocking = params["wait"] != False # noqa: E712
92+
8493
version, owner, name, version_id = identifier._resolve(ref)
8594

8695
if version or version_id:

tests/test_run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def prediction_with_status(status: str) -> dict:
123123
router.route(method="POST", path="/predictions").mock(
124124
return_value=httpx.Response(
125125
201,
126-
json=prediction_with_status("processing"),
126+
json=prediction_with_status("starting"),
127127
)
128128
)
129129
router.route(method="GET", path="/predictions/p1").mock(
@@ -212,7 +212,7 @@ def prediction_with_status(status: str) -> dict:
212212
router.route(method="POST", path="/predictions").mock(
213213
return_value=httpx.Response(
214214
201,
215-
json=prediction_with_status("processing"),
215+
json=prediction_with_status("starting"),
216216
)
217217
)
218218
router.route(method="GET", path="/predictions/p1").mock(
@@ -454,7 +454,7 @@ def prediction_with_status(
454454
router.route(method="POST", path="/predictions").mock(
455455
return_value=httpx.Response(
456456
201,
457-
json=prediction_with_status("processing"),
457+
json=prediction_with_status("starting"),
458458
)
459459
)
460460
router.route(method="GET", path="/predictions/p1").mock(
@@ -541,7 +541,7 @@ def prediction_with_status(
541541
router.route(method="POST", path="/predictions").mock(
542542
return_value=httpx.Response(
543543
201,
544-
json=prediction_with_status("processing"),
544+
json=prediction_with_status("starting"),
545545
)
546546
)
547547
router.route(method="GET", path="/predictions/p1").mock(

0 commit comments

Comments
 (0)