Skip to content

Commit 38d34ec

Browse files
authored
Support Optional in Input (#2216)
* Add integration tests for pydantic 2 output
* Fix test case
* Add mode logging to invalid output
* Add support for Optionals in Input
* Remove Optional import from typing_extensions
* Handle python 3.8/3.9
* Fix > 3.9 check
* Fix >= 3.10
* Use None.__class__ to replace NoneType and type(None)
* Check if optional is Union
1 parent f943b68 commit 38d34ec

File tree

6 files changed

+156
-6
lines changed

6 files changed

+156
-6
lines changed

python/cog/predictor.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import types
88
import uuid
99
from collections.abc import Iterable, Iterator
10+
11+
if sys.version_info >= (3, 10):
12+
from types import NoneType
1013
from typing import (
1114
Any,
1215
Callable,
@@ -190,8 +193,20 @@ def validate_input_type(
190193
elif get_origin(type) in (Union, List, list) or (
191194
hasattr(types, "UnionType") and get_origin(type) is types.UnionType
192195
): # noqa: E721
193-
for t in get_args(type):
194-
validate_input_type(t, name)
196+
args = get_args(type)
197+
198+
def is_optional() -> bool:
    """Report whether the enclosing annotation is spelled like ``Optional[X]``.

    Reads ``type`` and ``args`` from the enclosing scope. The annotation
    qualifies only when its origin is ``typing.Union``, it has exactly two
    arguments, and the *second* argument is the ``None`` type. Only that
    position is inspected.
    """
    if get_origin(type) is not Union or len(args) != 2:
        return False
    # ``NoneType`` is only imported by this file on Python 3.10+; on older
    # interpreters ``None.__class__`` is used as the comparison target.
    none_type = NoneType if sys.version_info >= (3, 10) else None.__class__
    return args[1] is none_type
204+
205+
if is_optional():
206+
validate_input_type(args[0], name)
207+
else:
208+
for t in args:
209+
validate_input_type(t, name)
195210
else:
196211
if PYDANTIC_V2:
197212
# Cog types are exported as `Annotated[Type, ...]`, but `type` is the inner type

python/cog/server/http.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ async def _predict(
469469
try:
470470
_ = response_type(**response_object)
471471
except ValidationError as e:
472-
_log_invalid_output(e)
472+
_log_invalid_output(e, mode)
473473
raise HTTPException(status_code=500, detail=str(e)) from e
474474

475475
response_object["output"] = upload_files(
@@ -520,17 +520,20 @@ def _maybe_shutdown(exc: BaseException, *, status: Health = Health.DEFUNCT) -> N
520520
return app
521521

522522

523-
def _log_invalid_output(error: Any) -> None:
523+
def _log_invalid_output(error: Any, mode: Mode) -> None:
524+
function_name = "predict()"
525+
if mode == Mode.TRAIN:
526+
function_name = "train()"
524527
log.error(
525528
textwrap.dedent(
526529
f"""\
527-
The return value of predict() was not valid:
530+
The return value of {function_name} was not valid:
528531
529532
{error}
530533
531534
Check that your predict function is in this form, where `output_type` is the same as the type you are returning (e.g. `str`):
532535
533-
def predict(...) -> output_type:
536+
def {function_name} -> output_type:
534537
...
535538
"""
536539
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Configuration for Cog ⚙️
2+
# Reference: https://cog.run/yaml
3+
4+
build:
5+
# set to true if your model requires a GPU
6+
gpu: false
7+
8+
# python version in the form '3.11' or '3.11.4'
9+
python_version: "3.12"
10+
11+
# a list of packages in the format <package-name>==<version>
12+
python_packages:
13+
- "pydantic==2.10.6" # The problematic Pydantic version
14+
15+
# predict.py defines how predictions are run on your model
16+
predict: "predict.py:Predictor"
17+
18+
# train.py defines how training runs on your model
19+
train: "train.py:train"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Prediction interface for Cog ⚙️
2+
# https://cog.run/python
3+
4+
from cog import BasePredictor, Input, Path
5+
6+
7+
class Predictor(BasePredictor):
    # Template predictor used as a fixture: it loads no model and always
    # returns the current directory as its output path.

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Nothing is loaded here. A real predictor would do something like:
        #   self.model = torch.load("./weights.pth")

    def predict(
        self,
        image: Path = Input(description="Grayscale input image"),
        scale: float = Input(
            description="Factor to scale image by",
            ge=0,
            le=10,
            default=1.5,
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        # Sketch of a real implementation (intentionally not executed):
        #   processed_input = preprocess(image)
        #   output = self.model(processed_input, scale)
        #   return postprocess(output)
        return Path(".")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
from typing import Optional
3+
from cog import BaseModel, Input, Path as CogPath, Secret
4+
5+
# We return a path to our trained adapter weights
6+
class TrainingOutput(BaseModel):
    # Return type of `train` below; `weights` carries the path of the file
    # that the training run wrote.
    weights: CogPath
8+
9+
def train(
    # Basic input
    some_input: str = Input(
        description="A basic string input to satisfy minimum requirements.",
        default="default value",
    ),
    # String input with None default (problematic)
    hf_repo_id: Optional[str] = Input(
        description="String with None default - this causes issues.",
        default=None,
    ),
    # Secret with None default (problematic)
    hf_token: Optional[Secret] = Input(
        description="Secret with None default - this also causes issues.",
        default=None,
    ),
    # String input with empty string default (works)
    working_repo_id: str = Input(
        description="String with empty string default - this works.",
        default="",
    ),
    # Secret with empty string default (works)
    working_token: Secret = Input(
        description="Secret with empty string default - this works.",
        default="",
    ),
) -> TrainingOutput:
    """
    Minimal example to demonstrate issues with Secret inputs.

    Prints how each input arrived (without ever printing a secret's value),
    writes a small placeholder file named ``dummy_output.txt`` in the current
    working directory, and returns it wrapped in a ``TrainingOutput``.
    """
    print("\n=== Minimal Cog Secret Test ===")
    print(f"cog version: {os.environ.get('COG_VERSION', 'unknown')}")

    # Inputs with None defaults
    print("\n-- Inputs with None defaults (problematic) --")
    print(f"hf_repo_id: {hf_repo_id}")
    if hf_token:
        print("hf_token: [PROVIDED]")
        try:
            # Only whether the secret can be read matters; the value itself
            # is deliberately discarded so it never reaches the logs.
            hf_token.get_secret_value()
        except Exception as e:
            # Broad on purpose: this fixture reports any failure mode
            # instead of aborting the training run.
            print(f"Error accessing secret: {e}")
        else:
            print("Secret access successful")
    else:
        print("hf_token: None")

    # Inputs with empty string defaults
    print("\n-- Inputs with empty string defaults (works) --")
    print(f"working_repo_id: {working_repo_id if working_repo_id else '(empty)'}")
    if working_token and working_token.get_secret_value():
        print("working_token: [PROVIDED]")
        try:
            working_token.get_secret_value()
        except Exception as e:
            print(f"Error accessing secret: {e}")
        else:
            print("Secret access successful")
    else:
        print("working_token: (empty)")

    # Create a dummy output file
    output_path = "dummy_output.txt"
    # Explicit encoding keeps the written bytes independent of the locale.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("This is a dummy output file.")

    print("\n=== Test Complete ===")

    # Return the dummy output path
    return TrainingOutput(weights=CogPath(output_path))

test-integration/test_integration/test_train.py

+13
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,16 @@ def test_train_takes_input_and_produces_weights(tmpdir_factory):
1919
with open(out_dir / "weights.bin", "rb") as f:
2020
assert len(f.read()) == 42
2121
assert "falling back to slow loader" not in str(result.stderr)
22+
23+
24+
def test_train_pydantic2(tmpdir_factory):
    """Run `cog train` against the pydantic2-output fixture and expect success.

    The fixture project is copied into a fresh temporary directory so the
    run cannot leave artefacts inside the source tree.
    """
    project_dir = Path(__file__).parent / "fixtures/pydantic2-output"
    out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
    shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
    result = subprocess.run(
        ["cog", "train", "--debug", "-i", 'some_input="hello"'],
        cwd=out_dir,
        check=False,
        capture_output=True,
    )
    # stderr is captured as bytes; surface it on failure so a broken run
    # explains itself instead of reporting only a non-zero exit code.
    assert result.returncode == 0, result.stderr.decode(errors="replace")

0 commit comments

Comments
 (0)