Skip to content

Commit fd5fc46

Browse files
nmheimdionhaefner
andauthored
fix: Enable users to opt-in to allowing extra fields in Tesseract schemas by setting extra="allow" (#117)
#### Relevant issue or PR Pydantic models inside e.g. `InputSchema` that have `extra="allow"` failed with extra fields. #### Description of changes Only override with `extra="forbid"` if no `model_config` is given. #### Testing done Added a test. #### License - [x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract/LICENSE). - [x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line. <details> <summary><b>Developer Certificate of Origin</b></summary> ```text Developer Certificate of Origin Version 1.1 Copyright (C) 2004, 2006 The Linux Foundation and its contributors. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Developer's Certificate of Origin 1.1 By making a contribution to this project, I certify that: (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. ``` </details> Signed-off-by: [Niklas Heim] <[[email protected]]> --------- Co-authored-by: Dion Häfner <[email protected]>
1 parent ceb6deb commit fd5fc46

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

tesseract_core/runtime/schema_generation.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Any,
1212
ClassVar,
1313
Literal,
14+
Optional,
1415
TypeVar,
1516
Union,
1617
get_args,
@@ -65,7 +66,7 @@ def apply_function_to_model_tree(
6566
Schema: type[BaseModel],
6667
func: Callable[[type, tuple], type],
6768
model_prefix: str = "",
68-
model_kwargs: Any = None,
69+
default_model_config: Optional[dict[str, Any]] = None,
6970
) -> type[BaseModel]:
7071
"""Apply a function to all leaves of a Pydantic model, recursing into containers + nested models.
7172
@@ -87,8 +88,8 @@ class MyModel(BaseModel):
8788
# Annotation types that should be treated as leaves and not recursed into
8889
annotated_types_as_leaves = (PydanticArrayAnnotation,)
8990

90-
if model_kwargs is None:
91-
model_kwargs = {}
91+
if default_model_config is None:
92+
default_model_config = {}
9293

9394
seen_models = set()
9495

@@ -134,10 +135,15 @@ def _recurse_over_model_tree(treeobj: Any, path: list[str]) -> Any:
134135
# We only forbid encountering the same model twice if it is within the same subtree
135136
seen_models.remove(id(treeobj))
136137

138+
# Only override model_config if it is not already present
139+
# in the pydantic model definition
140+
model_config = ConfigDict(**default_model_config)
141+
model_config.update(treeobj.model_config)
142+
137143
return create_model(
138144
f"{model_prefix}{treeobj.__name__}",
139145
**new_fields,
140-
**model_kwargs,
146+
model_config=(ConfigDict, model_config),
141147
__base__=treeobj,
142148
)
143149

@@ -271,13 +277,13 @@ def create_apply_schema(
271277
InputSchema,
272278
lambda x, _: x,
273279
model_prefix="Apply_",
274-
model_kwargs={"model_config": (ConfigDict, ConfigDict(extra="forbid"))},
280+
default_model_config=dict(extra="forbid"),
275281
)
276282
OutputSchema = apply_function_to_model_tree(
277283
OutputSchema,
278284
lambda x, _: x,
279285
model_prefix="Apply_",
280-
model_kwargs={"model_config": (ConfigDict, ConfigDict(extra="forbid"))},
286+
default_model_config=dict(extra="forbid"),
281287
)
282288

283289
class ApplyInputSchema(BaseModel):
@@ -325,14 +331,14 @@ def replace_array_with_shapedtype(obj: T, _: Any) -> Union[T, type[ShapeDType]]:
325331
InputSchema,
326332
replace_array_with_shapedtype,
327333
model_prefix="AbstractEval_",
328-
model_kwargs={"model_config": (ConfigDict, ConfigDict(extra="forbid"))},
334+
default_model_config=dict(extra="forbid"),
329335
)
330336

331337
GeneratedOutputSchema = apply_function_to_model_tree(
332338
OutputSchema,
333339
replace_array_with_shapedtype,
334340
model_prefix="AbstractEval_",
335-
model_kwargs={"model_config": (ConfigDict, ConfigDict(extra="forbid"))},
341+
default_model_config=dict(extra="forbid"),
336342
)
337343

338344
class AbstractInputSchema(BaseModel):
@@ -463,7 +469,7 @@ def _find_shape_from_path(path_patterns: dict, concrete_path: str) -> tuple:
463469
InputSchema,
464470
lambda x, _: x,
465471
model_prefix=f"{ad_flavor.title()}_",
466-
model_kwargs={"model_config": (ConfigDict, ConfigDict(extra="forbid"))},
472+
default_model_config=dict(extra="forbid"),
467473
)
468474

469475
def result_validator(

tests/runtime_tests/test_schema_generation.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88

99
import numpy as np
1010
import pytest
11-
from pydantic import BaseModel, ValidationError
11+
from pydantic import BaseModel, ConfigDict, ValidationError
1212

1313
from tesseract_core.runtime import Array, Differentiable, Float32, Float64, Int64, UInt8
1414
from tesseract_core.runtime.experimental import LazySequence
1515
from tesseract_core.runtime.schema_generation import (
16+
apply_function_to_model_tree,
1617
create_abstract_eval_schema,
1718
create_apply_schema,
1819
create_autodiff_schema,
@@ -688,3 +689,27 @@ def test_json_schema(endpoint):
688689
# Test that the JSON schema is valid JSON
689690
json.dumps(schema_inputs)
690691
json.dumps(schema_outputs)
692+
693+
694+
def test_model_config_extra_forbid():
695+
class Child(BaseModel):
696+
x: str
697+
model_config: ConfigDict = ConfigDict(extra="allow")
698+
699+
class Parent(BaseModel):
700+
child: Child
701+
702+
ApplyParent = apply_function_to_model_tree(
703+
Parent,
704+
lambda x, y: x,
705+
default_model_config=dict(extra="forbid"),
706+
)
707+
ApplyChild = ApplyParent.model_fields["child"].annotation
708+
assert ApplyChild.model_config["extra"] == "allow"
709+
assert ApplyParent.model_config["extra"] == "forbid"
710+
711+
ApplyParent.model_validate({"child": {"x": "foo"}})
712+
ApplyParent.model_validate({"child": {"x": "foo", "extra": 1}})
713+
714+
with pytest.raises(ValidationError):
715+
ApplyParent.model_validate({"child": {"x": "foo"}, "extra": 1})

0 commit comments

Comments
 (0)