Skip to content

Commit e1c38ff

Browse files
authored
Merge pull request #22 from lmignon/main-fix-model-as-query-args
Fix the use of extendable model as query params into fastapi
2 parents eb352a2 + 80df42d commit e1c38ff

File tree

4 files changed

+139
-1
lines changed

4 files changed

+139
-1
lines changed

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ requires-python = ">=3.7"
2424
test = [
2525
"pytest",
2626
"coverage[toml]",
27+
"fastapi>=0.111",
28+
"httpx",
2729
]
2830
mypy = [
2931
"mypy>=1.4.1",

src/extendable_pydantic/_patch.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
def _resolve_model_fields_annotation(model_fields):
2626
registry = context.extendable_registry.get()
27-
if registry:
27+
if registry and registry.ready:
2828
for field in model_fields:
2929
field_info = field.field_info
3030
new_type = resolve_annotation(field_info.annotation)
@@ -77,3 +77,19 @@ def _create_response_field_wrapper(wrapped, instance, args, kwargs):
7777
wrapt.wrap_function_wrapper(
7878
utils, "create_model_field", _create_response_field_wrapper
7979
)
80+
81+
82+
@wrapt.when_imported("fastapi.dependencies.utils")
83+
def hook_fastapi_dependencies_utils(utils):
84+
def _analyze_param_wrapper(wrapped, instance, args, kwargs):
85+
registry = context.extendable_registry.get()
86+
if registry and registry.ready:
87+
annotation = kwargs.get("annotation")
88+
if annotation:
89+
new_type = resolve_annotation(annotation)
90+
if not all_identical(annotation, new_type):
91+
kwargs["annotation"] = new_type
92+
return wrapped(*args, **kwargs)
93+
94+
if hasattr(utils, "analyze_param"):
95+
wrapt.wrap_function_wrapper(utils, "analyze_param", _analyze_param_wrapper)

tests/conftest.py

+63
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
1+
from extendable_pydantic import _patch # noqa: F401
12
import pytest
23
import sys
34
from extendable import context, main, registry
5+
from fastapi import FastAPI, APIRouter
6+
from fastapi.testclient import TestClient
7+
8+
if sys.version_info >= (3, 9):
9+
from typing import Annotated
10+
else:
11+
from typing_extensions import Annotated
12+
13+
from fastapi import Depends
14+
from extendable_pydantic import ExtendableBaseModel
415

516

617
skip_not_supported_version_for_generics = pytest.mark.skipif(
@@ -21,3 +32,55 @@ def test_registry() -> registry.ExtendableClassesRegistry:
2132
finally:
2233
main._extendable_class_defs_by_module = initial_class_defs
2334
context.extendable_registry.reset(token)
35+
36+
37+
@pytest.fixture
38+
def test_fastapi(test_registry) -> TestClient:
39+
app = FastAPI()
40+
my_router = APIRouter()
41+
42+
class TestRequest(ExtendableBaseModel):
43+
name: str = "rqst"
44+
45+
def get_type(self) -> str:
46+
return "request"
47+
48+
class TestResponse(ExtendableBaseModel):
49+
name: str = "resp"
50+
51+
def get_type(self) -> str:
52+
return "response"
53+
54+
@my_router.get("/")
55+
def get() -> TestResponse:
56+
"""Get method."""
57+
resp = TestResponse(name="World")
58+
assert hasattr(resp, "id")
59+
return resp
60+
61+
@my_router.post("/")
62+
def post(rqst: TestRequest) -> TestResponse:
63+
"""Post method."""
64+
resp = TestResponse(**rqst.model_dump())
65+
assert hasattr(resp, "id")
66+
return resp
67+
68+
@my_router.get("/extended")
69+
def get_with_params(rqst: Annotated[TestRequest, Depends()]) -> TestResponse:
70+
"""Get method with parameters."""
71+
resp = TestResponse(**rqst.model_dump())
72+
assert hasattr(resp, "id")
73+
return resp
74+
75+
class ExtendedTestRequest(TestRequest, extends=TestRequest):
76+
id: int = 1
77+
78+
class ExtendedTestResponse(TestResponse, extends=TestResponse):
79+
id: int = 2
80+
81+
test_registry.init_registry()
82+
83+
app.include_router(my_router)
84+
85+
with TestClient(app) as client:
86+
yield client

tests/test_fastapi.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""Test fastapi integration."""
2+
3+
4+
def test_open_api_schema(test_fastapi):
5+
client = test_fastapi
6+
response = client.get("/openapi.json")
7+
assert response.status_code == 200, response.text
8+
schema = response.json()
9+
rqst_schema = schema["components"]["schemas"]["TestRequest"]
10+
assert rqst_schema["properties"] == {
11+
"name": {"title": "Name", "type": "string", "default": "rqst"},
12+
"id": {"title": "Id", "type": "integer", "default": 1},
13+
}
14+
resp_schema = schema["components"]["schemas"]["TestResponse"]
15+
assert resp_schema["properties"] == {
16+
"name": {"title": "Name", "type": "string", "default": "resp"},
17+
"id": {"title": "Id", "type": "integer", "default": 2},
18+
}
19+
20+
extended_get_params = schema["paths"]["/extended"]["get"]["parameters"]
21+
assert len(extended_get_params) == 2
22+
assert extended_get_params[0] == {
23+
"in": "query",
24+
"name": "name",
25+
"required": False,
26+
"schema": {"title": "Name", "type": "string", "default": "rqst"},
27+
}
28+
assert extended_get_params[1] == {
29+
"in": "query",
30+
"name": "id",
31+
"required": False,
32+
"schema": {"title": "Id", "type": "integer", "default": 1},
33+
}
34+
35+
36+
def test_extended_response(test_fastapi):
37+
"""Test extended pydantic model as response."""
38+
client = test_fastapi
39+
response = client.get("/")
40+
assert response.status_code == 200
41+
assert response.json() == {"name": "World", "id": 2}
42+
43+
44+
def test_extended_request(test_fastapi):
45+
"""Test extended pydantic model as json request."""
46+
client = test_fastapi
47+
response = client.post("/", json={"name": "Hello", "id": 3})
48+
assert response.status_code == 200
49+
assert response.json() == {"name": "Hello", "id": 3}
50+
51+
52+
def test_extended_request_with_params(test_fastapi):
53+
"""Test extended pydantic model as request with parameters."""
54+
client = test_fastapi
55+
response = client.get("/extended", params={"name": "echo", "id": 3})
56+
assert response.status_code == 200
57+
assert response.json() == {"name": "echo", "id": 3}

0 commit comments

Comments
 (0)