Skip to content

fix: Enable users to opt-in to allowing extra fields in Tesseract schemas by setting extra="allow" #117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 8, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions tesseract_core/runtime/schema_generation.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
Any,
ClassVar,
Literal,
Optional,
TypeVar,
Union,
get_args,
@@ -65,7 +66,7 @@ def apply_function_to_model_tree(
Schema: type[BaseModel],
func: Callable[[type, tuple], type],
model_prefix: str = "",
model_kwargs: Any = None,
default_model_config: Optional[dict[str, Any]] = None,
) -> type[BaseModel]:
"""Apply a function to all leaves of a Pydantic model, recursing into containers + nested models.

@@ -87,8 +88,8 @@ class MyModel(BaseModel):
# Annotation types that should be treated as leaves and not recursed into
annotated_types_as_leaves = (PydanticArrayAnnotation,)

if model_kwargs is None:
model_kwargs = {}
if default_model_config is None:
default_model_config = {}

seen_models = set()

@@ -134,10 +135,15 @@ def _recurse_over_model_tree(treeobj: Any, path: list[str]) -> Any:
# We only forbid encountering the same model twice if it is within the same subtree
seen_models.remove(id(treeobj))

# Only override model_config if it is not already present
# in the pydantic model definition
model_config = ConfigDict(**default_model_config)
model_config.update(treeobj.model_config)

return create_model(
f"{model_prefix}{treeobj.__name__}",
**new_fields,
**model_kwargs,
model_config=(ConfigDict, model_config),
__base__=treeobj,
)

@@ -271,13 +277,13 @@ def create_apply_schema(
InputSchema,
lambda x, _: x,
model_prefix="Apply_",
model_kwargs={"model_config": (ConfigDict, ConfigDict(extra="forbid"))},
default_model_config=dict(extra="forbid"),
)
OutputSchema = apply_function_to_model_tree(
OutputSchema,
lambda x, _: x,
model_prefix="Apply_",
model_kwargs={"model_config": (ConfigDict, ConfigDict(extra="forbid"))},
default_model_config=dict(extra="forbid"),
)

class ApplyInputSchema(BaseModel):
@@ -325,14 +331,14 @@ def replace_array_with_shapedtype(obj: T, _: Any) -> Union[T, type[ShapeDType]]:
InputSchema,
replace_array_with_shapedtype,
model_prefix="AbstractEval_",
model_kwargs={"model_config": (ConfigDict, ConfigDict(extra="forbid"))},
default_model_config=dict(extra="forbid"),
)

GeneratedOutputSchema = apply_function_to_model_tree(
OutputSchema,
replace_array_with_shapedtype,
model_prefix="AbstractEval_",
model_kwargs={"model_config": (ConfigDict, ConfigDict(extra="forbid"))},
default_model_config=dict(extra="forbid"),
)

class AbstractInputSchema(BaseModel):
@@ -463,7 +469,7 @@ def _find_shape_from_path(path_patterns: dict, concrete_path: str) -> tuple:
InputSchema,
lambda x, _: x,
model_prefix=f"{ad_flavor.title()}_",
model_kwargs={"model_config": (ConfigDict, ConfigDict(extra="forbid"))},
default_model_config=dict(extra="forbid"),
)

def result_validator(
27 changes: 26 additions & 1 deletion tests/runtime_tests/test_schema_generation.py
Original file line number Diff line number Diff line change
@@ -8,11 +8,12 @@

import numpy as np
import pytest
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, ConfigDict, ValidationError

from tesseract_core.runtime import Array, Differentiable, Float32, Float64, Int64, UInt8
from tesseract_core.runtime.experimental import LazySequence
from tesseract_core.runtime.schema_generation import (
apply_function_to_model_tree,
create_abstract_eval_schema,
create_apply_schema,
create_autodiff_schema,
@@ -688,3 +689,27 @@ def test_json_schema(endpoint):
# Test that the JSON schema is valid JSON
json.dumps(schema_inputs)
json.dumps(schema_outputs)


def test_model_config_extra_forbid():
class Child(BaseModel):
x: str
model_config: ConfigDict = ConfigDict(extra="allow")

class Parent(BaseModel):
child: Child

ApplyParent = apply_function_to_model_tree(
Parent,
lambda x, y: x,
default_model_config=dict(extra="forbid"),
)
ApplyChild = ApplyParent.model_fields["child"].annotation
assert ApplyChild.model_config["extra"] == "allow"
assert ApplyParent.model_config["extra"] == "forbid"

ApplyParent.model_validate({"child": {"x": "foo"}})
ApplyParent.model_validate({"child": {"x": "foo", "extra": 1}})

with pytest.raises(ValidationError):
ApplyParent.model_validate({"child": {"x": "foo"}, "extra": 1})
Loading