Skip to content

Commit a4b09e4

Browse files
authored
Add safe root model handling for path searching in json codec (#34)
1 parent 3372b3c commit a4b09e4

File tree

5 files changed

+397
-12
lines changed

5 files changed

+397
-12
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "asyncapi-python"
3-
version = "0.3.0rc8"
3+
version = "0.3.0rc9"
44
license = { text = "Apache-2.0" }
55
description = "Easily generate type-safe and async Python applications from AsyncAPI 3 specifications."
66
authors = [{ name = "Yaroslav Petrov", email = "[email protected]" }]

src/asyncapi_python/contrib/codec/json.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import json
22
from enum import Enum
33
from types import ModuleType
4-
from typing import ClassVar, Type
4+
from typing import Any, ClassVar, Type
55

6-
from pydantic import BaseModel, ValidationError
6+
from pydantic import BaseModel, RootModel, ValidationError
77

88
from asyncapi_python.kernel.codec import Codec, CodecFactory
99
from asyncapi_python.kernel.document.message import Message
@@ -31,13 +31,16 @@ def decode(self, payload: bytes) -> BaseModel:
3131
def extract_field(self, payload: BaseModel, location: str) -> str:
3232
"""Extract field from Pydantic model using location path.
3333
34+
Handles both regular BaseModel and RootModel wrappers. RootModel instances
35+
are automatically unwrapped (recursively) to access the underlying data.
36+
3437
Examples:
3538
"$message.payload#/userId" → payload.userId → "123"
3639
"$message.payload#/user/id" → payload.user.id → "456"
3740
"$message.payload#/items" → payload.items → "[1, 2, 3]"
3841
3942
Args:
40-
payload: Pydantic BaseModel instance
43+
payload: Pydantic BaseModel instance (may be RootModel wrapper)
4144
location: Location expression like "$message.payload#/userId"
4245
4346
Returns:
@@ -56,18 +59,29 @@ def extract_field(self, payload: BaseModel, location: str) -> str:
5659
parts = [p for p in path.split("/") if p]
5760

5861
try:
59-
value = payload
62+
value: Any = payload
6063
for part in parts:
61-
value = getattr(value, part)
64+
# Recursively unwrap any RootModel wrappers before accessing attributes
65+
while isinstance(value, RootModel):
66+
value = value.root # type: ignore[assignment, misc]
67+
value = getattr(value, part) # type: ignore[arg-type]
68+
69+
# Unwrap final value if it's a RootModel
70+
while isinstance(value, RootModel):
71+
value = value.root # type: ignore[assignment, misc]
6272

6373
# Convert to string
64-
if isinstance(value, (str, int, float, bool)):
65-
return str(value)
66-
elif isinstance(value, Enum):
74+
# Check Enum FIRST (before str/int/etc) because str/int Enums are also instances of str/int
75+
if isinstance(value, Enum):
6776
# Handle Enum types - extract the value attribute
6877
return str(value.value)
78+
elif isinstance(value, (str, int, float, bool)):
79+
return str(value)
80+
elif isinstance(value, BaseModel):
81+
# Pydantic models: dump to dict then JSON serialize
82+
return json.dumps(value.model_dump())
6983
else:
70-
# Complex types: JSON serialize
84+
# Other complex types: JSON serialize directly
7185
return json.dumps(value)
7286

7387
except AttributeError as e:
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
"""Tests for JSON codec extract_field() method with RootModel support"""
2+
3+
from enum import Enum
4+
5+
import pytest
6+
from pydantic import BaseModel, RootModel
7+
8+
from asyncapi_python.contrib.codec.json import JsonCodec
9+
10+
11+
# Test models
12+
class SimpleMessage(BaseModel):
    """Flat payload fixture with one integer field and one string field."""

    chat_id: int
    message: str
17+
18+
19+
class NestedUser(BaseModel):
    """Leaf model embedded in other fixtures to exercise multi-segment paths."""

    id: str
    name: str
24+
25+
26+
class MessageWithNested(BaseModel):
    """Payload fixture whose ``user`` field is itself a model."""

    user: NestedUser
    content: str
31+
32+
33+
class Severity(str, Enum):
    """String-valued enum fixture used to check enum extraction."""

    # Each member's value is the lowercase form of its name.
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
39+
40+
41+
class MessageWithEnum(BaseModel):
    """Payload fixture carrying a ``Severity`` enum next to a plain string."""

    severity: Severity
    description: str
46+
47+
48+
class ComplexData(BaseModel):
    """Container fixture holding a list of strings and a string-to-string dict."""

    items: list[str]
    metadata: dict[str, str]
53+
54+
55+
class MessageWithComplex(BaseModel):
    """Payload fixture whose only field is a ``ComplexData`` model."""

    data: ComplexData
59+
60+
61+
# RootModel wrappers
62+
class SimpleRootModel(RootModel[SimpleMessage]):
    """RootModel fixture wrapping a ``SimpleMessage`` one level deep."""

    root: SimpleMessage
66+
67+
68+
class InnerRootModel(RootModel[NestedUser]):
    """RootModel fixture wrapping a ``NestedUser``, used as a field type."""

    root: NestedUser
72+
73+
74+
class OuterMessageWithRootModel(BaseModel):
    """Payload fixture whose ``user`` field is a RootModel wrapper."""

    user: InnerRootModel
    content: str
79+
80+
81+
class DoubleRootModel(RootModel[SimpleRootModel]):
    """RootModel fixture whose root is itself a RootModel (two wrapper layers)."""

    root: SimpleRootModel
85+
86+
87+
# Tests
88+
def test_extract_field_from_base_model():
    """Top-level fields of a plain BaseModel are extracted as strings."""
    codec = JsonCodec(SimpleMessage)
    payload = SimpleMessage(chat_id=123, message="hello")

    cases = [
        ("$message.payload#/chat_id", "123"),
        ("$message.payload#/message", "hello"),
    ]
    for location, expected in cases:
        assert codec.extract_field(payload, location) == expected
98+
99+
100+
def test_extract_field_from_root_model():
    """Fields are reachable through a single RootModel wrapper."""
    codec = JsonCodec(SimpleRootModel)
    raw = {"chat_id": 456, "message": "world"}
    wrapped = SimpleRootModel.model_validate(raw)

    # The codec is expected to look through the wrapper to the root's fields.
    chat_id = codec.extract_field(wrapped, "$message.payload#/chat_id")
    text = codec.extract_field(wrapped, "$message.payload#/message")

    assert chat_id == "456"
    assert text == "world"
111+
112+
113+
def test_extract_field_from_nested_root_model():
    """Fields are reachable through two stacked RootModel wrappers."""
    codec = JsonCodec(DoubleRootModel)

    # Layering under test: DoubleRootModel -> SimpleRootModel -> SimpleMessage
    inner = SimpleRootModel.model_validate({"chat_id": 789, "message": "nested"})
    outer = DoubleRootModel.model_validate(inner.model_dump())

    # Both wrapper layers must be unwrapped before the field lookup.
    chat_id = codec.extract_field(outer, "$message.payload#/chat_id")
    text = codec.extract_field(outer, "$message.payload#/message")

    assert chat_id == "789"
    assert text == "nested"
127+
128+
129+
def test_extract_field_nested_path():
    """Multi-segment paths such as ``#/user/id`` descend into nested models."""
    codec = JsonCodec(MessageWithNested)
    user = NestedUser(id="user123", name="Alice")
    payload = MessageWithNested(user=user, content="test")

    cases = [
        ("$message.payload#/user/id", "user123"),
        ("$message.payload#/user/name", "Alice"),
    ]
    for location, expected in cases:
        assert codec.extract_field(payload, location) == expected
141+
142+
143+
def test_extract_field_nested_path_with_root_model():
    """A RootModel met partway along the path is unwrapped before descending."""
    codec = JsonCodec(OuterMessageWithRootModel)

    # ``user`` holds a RootModel wrapper rather than the NestedUser itself.
    wrapped_user = InnerRootModel.model_validate({"id": "user456", "name": "Bob"})
    payload = OuterMessageWithRootModel(user=wrapped_user, content="test")

    user_id = codec.extract_field(payload, "$message.payload#/user/id")
    user_name = codec.extract_field(payload, "$message.payload#/user/name")

    assert user_id == "user456"
    assert user_name == "Bob"
157+
158+
159+
def test_extract_field_enum_value():
    """An Enum field is extracted as its underlying value, not the member repr."""
    codec = JsonCodec(MessageWithEnum)
    payload = MessageWithEnum(severity=Severity.HIGH, description="critical alert")

    extracted = codec.extract_field(payload, "$message.payload#/severity")

    # The member's value is expected here, not the string "Severity.HIGH".
    assert extracted == "high"
166+
167+
168+
def test_extract_field_complex_type() -> None:
    """Extracting a nested model yields a JSON document describing that model.

    The result is parsed and compared structurally instead of by substring, so
    the assertion covers the whole document and does not depend on the
    separators or key order used by the codec's serializer.
    """
    import json  # local import: only this test needs to parse the output

    codec = JsonCodec(MessageWithComplex)
    message = MessageWithComplex(
        data=ComplexData(items=["a", "b", "c"], metadata={"key": "value"})
    )

    result = codec.extract_field(message, "$message.payload#/data")

    # Must be valid JSON that round-trips to exactly the model's contents.
    assert json.loads(result) == {
        "items": ["a", "b", "c"],
        "metadata": {"key": "value"},
    }
179+
180+
181+
def test_extract_field_invalid_location():
    """Malformed location expressions are rejected with ValueError."""
    codec = JsonCodec(SimpleMessage)
    payload = SimpleMessage(chat_id=123, message="hello")

    for bad_location in ("invalid/location", "#/chat_id"):
        with pytest.raises(ValueError, match="Invalid location format"):
            codec.extract_field(payload, bad_location)
191+
192+
193+
def test_extract_field_missing_path():
    """A path naming no attribute of the payload is reported as ValueError."""
    codec = JsonCodec(SimpleMessage)
    payload = SimpleMessage(chat_id=123, message="hello")
    expected_error = "Path 'nonexistent' not found in payload"

    with pytest.raises(ValueError, match=expected_error):
        codec.extract_field(payload, "$message.payload#/nonexistent")
200+
201+
202+
def test_extract_field_missing_nested_path():
    """A missing final segment of a nested path is reported with the full path."""
    codec = JsonCodec(MessageWithNested)
    user = NestedUser(id="user123", name="Alice")
    payload = MessageWithNested(user=user, content="test")
    expected_error = "Path 'user/nonexistent' not found in payload"

    with pytest.raises(ValueError, match=expected_error):
        codec.extract_field(payload, "$message.payload#/user/nonexistent")
213+
214+
215+
def test_extract_field_primitive_types() -> None:
    """Each primitive field comes back as the ``str()`` of its Python value."""

    class PrimitiveMessage(BaseModel):
        str_field: str
        int_field: int
        float_field: float
        bool_field: bool

    codec = JsonCodec(PrimitiveMessage)
    payload = PrimitiveMessage(
        str_field="test", int_field=42, float_field=3.14, bool_field=True
    )

    # Location expression -> expected extracted string.
    expected_by_location = {
        "$message.payload#/str_field": "test",
        "$message.payload#/int_field": "42",
        "$message.payload#/float_field": "3.14",
        "$message.payload#/bool_field": "True",
    }
    for location, expected in expected_by_location.items():
        assert codec.extract_field(payload, location) == expected

0 commit comments

Comments
 (0)