Skip to content

Commit 4270233

Browse files
authored
Fix schema inference (#28)
* Fix incorrect dynamic reply queue routing
* Update lockfile
* Fix pyright
* Update resolver logic
* Move service_name to Application from wire
* Drop service name param from amqp wire
* Move service name to endpoint params
* Add default timeout
* Update version
* Add endpoint_params to application
* Add test for recursive spec
* Fix recursion capability for schema parsing
* Increment rc version
* Run lints
1 parent 8769a84 commit 4270233

File tree

8 files changed

+251
-31
lines changed

8 files changed

+251
-31
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "asyncapi-python"
3-
version = "0.3.0rc4"
3+
version = "0.3.0rc5"
44
license = { text = "Apache-2.0" }
55
description = "Easily generate type-safe and async Python applications from AsyncAPI 3 specifications."
66
authors = [{ name = "Yaroslav Petrov", email = "[email protected]" }]

src/asyncapi_python_codegen/generators/messages.py

Lines changed: 112 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from datamodel_code_generator.__main__ import main as datamodel_codegen
1111

1212
from asyncapi_python.kernel.document import Operation
13+
from asyncapi_python_codegen.parser.types import ParseContext, navigate_json_pointer
1314

1415

1516
class MessageGenerator:
@@ -67,35 +68,110 @@ def _collect_message_schemas(
6768
return schemas # type: ignore[return-value]
6869

6970
def _load_component_schemas(self, spec_path: Path) -> dict[str, Any]:
70-
"""Load component schemas from the AsyncAPI specification file."""
71-
try:
72-
with spec_path.open("r") as f:
73-
spec = yaml.safe_load(f)
71+
"""Load component schemas from the AsyncAPI specification file and all referenced files."""
72+
all_schemas: dict[str, Any] = {}
73+
visited_files: set[Path] = set()
7474

75-
components = spec.get("components", {})
76-
schemas = components.get("schemas", {})
77-
messages = components.get("messages", {})
75+
def load_schemas_from_file(file_path: Path) -> None:
76+
"""Recursively load schemas from a file and its references."""
77+
abs_path = file_path.absolute()
7878

79-
# Combine schemas and message payloads
80-
all_schemas = {}
79+
# Avoid infinite loops
80+
if abs_path in visited_files:
81+
return
82+
visited_files.add(abs_path)
8183

82-
# Add component schemas directly
83-
for schema_name, schema_def in schemas.items():
84-
all_schemas[schema_name] = schema_def
84+
try:
85+
with abs_path.open("r") as f:
86+
spec = yaml.safe_load(f)
8587

86-
# Add message payloads from components (only if not already present from schemas)
87-
for msg_name, msg_def in messages.items():
88-
if isinstance(msg_def, dict) and "payload" in msg_def:
89-
schema_name = self._to_pascal_case(msg_name)
90-
# Only add if we don't already have this schema from the schemas section
88+
components = spec.get("components", {})
89+
schemas = components.get("schemas", {})
90+
messages = components.get("messages", {})
91+
92+
# Add component schemas directly
93+
for schema_name, schema_def in schemas.items():
9194
if schema_name not in all_schemas:
92-
all_schemas[schema_name] = msg_def["payload"]
95+
# Check if this schema is itself a reference
96+
if isinstance(schema_def, dict) and "$ref" in schema_def:
97+
ref_value: Any = schema_def["$ref"] # type: ignore[misc]
98+
# Resolve the reference using ParseContext utilities
99+
if isinstance(ref_value, str):
100+
try:
101+
context = ParseContext(abs_path)
102+
target_context = context.resolve_reference(
103+
ref_value
104+
)
105+
106+
# Load and navigate to the referenced schema
107+
with target_context.filepath.open("r") as ref_file:
108+
ref_spec = yaml.safe_load(ref_file)
109+
110+
if target_context.json_pointer:
111+
resolved_schema = navigate_json_pointer(
112+
ref_spec, target_context.json_pointer
113+
)
114+
else:
115+
resolved_schema = ref_spec
116+
117+
all_schemas[schema_name] = resolved_schema
118+
except Exception as e:
119+
print(
120+
f"Warning: Could not resolve reference {ref_value} in {abs_path}: {e}"
121+
)
122+
all_schemas[schema_name] = schema_def
123+
else:
124+
all_schemas[schema_name] = schema_def
125+
126+
# Add message payloads from components
127+
for msg_name, msg_def in messages.items():
128+
if isinstance(msg_def, dict) and "payload" in msg_def:
129+
schema_name = self._to_pascal_case(msg_name)
130+
if schema_name not in all_schemas:
131+
all_schemas[schema_name] = msg_def["payload"]
132+
133+
# Find and process all external file references
134+
self._find_and_process_refs(
135+
spec, abs_path.parent, load_schemas_from_file
136+
)
137+
138+
except Exception as e:
139+
print(f"Warning: Could not load component schemas from {abs_path}: {e}")
140+
141+
# Start loading from the main spec file
142+
load_schemas_from_file(spec_path)
143+
144+
return all_schemas # type: ignore[return-value]
145+
146+
def _find_and_process_refs(
147+
self, data: Any, base_dir: Path, process_file: Any
148+
) -> None:
149+
"""Recursively find all $ref entries pointing to external files."""
150+
if isinstance(data, dict):
151+
# Check if this is a reference
152+
if "$ref" in data:
153+
ref_value: Any = data["$ref"] # type: ignore[misc]
154+
if isinstance(ref_value, str) and not ref_value.startswith("#"):
155+
# External reference - extract file path
156+
file_part: str
157+
if "#" in ref_value:
158+
file_part = ref_value.split("#")[0]
159+
else:
160+
file_part = ref_value
161+
162+
if file_part:
163+
# Resolve relative path
164+
ref_path = (base_dir / file_part).resolve()
165+
process_file(ref_path)
93166

94-
return all_schemas # type: ignore[return-value]
167+
# Recurse into all dict values
168+
for value in data.values(): # type: ignore[misc]
169+
self._find_and_process_refs(value, base_dir, process_file)
95170

96-
except Exception as e:
97-
print(f"Warning: Could not load component schemas from {spec_path}: {e}")
98-
return {}
171+
elif isinstance(data, list):
172+
# Recurse into all list items
173+
for item in data: # type: ignore[misc]
174+
self._find_and_process_refs(item, base_dir, process_file)
99175

100176
def _resolve_references(self, schemas: dict[str, Any]) -> dict[str, Any]:
101177
"""Recursively resolve $ref references to use #/$defs/... instead of #/components/schemas/..."""
@@ -105,17 +181,24 @@ def resolve_in_object(obj: Any) -> Any:
105181
resolved_obj: dict[str, Any] = {}
106182
for key, value in obj.items(): # type: ignore[misc]
107183
if key == "$ref" and isinstance(value, str):
108-
# Transform references from #/components/schemas/... to #/$defs/...
109-
if value.startswith("#/components/schemas/"):
110-
schema_name = value.split("/")[-1]
184+
# Extract schema name from the reference
185+
schema_name = value.split("/")[-1]
186+
187+
# Transform all component references to #/$defs/...
188+
if "#/components/schemas/" in value:
189+
# Internal or external schema reference
111190
resolved_obj[key] = f"#/$defs/{schema_name}"
112-
elif value.startswith("#/components/messages/"):
191+
elif "#/components/messages/" in value:
113192
# Handle message references - convert message name to PascalCase
114-
msg_name = value.split("/")[-1]
115-
schema_name = self._to_pascal_case(msg_name)
193+
schema_name = self._to_pascal_case(schema_name)
116194
resolved_obj[key] = f"#/$defs/{schema_name}"
117-
else:
195+
elif value.startswith("#"):
196+
# Other internal references, keep as-is
118197
resolved_obj[key] = value
198+
else:
199+
# External file reference (e.g., "./commons2.yaml#/components/schemas/Foo")
200+
# Extract just the schema name and point to #/$defs
201+
resolved_obj[key] = f"#/$defs/{schema_name}"
119202
else:
120203
resolved_obj[key] = resolve_in_object(value)
121204
return resolved_obj
tests/codegen/specs/deep_recursion/level1.yaml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Level 1: Main entry point
2+
asyncapi: "3.0.0"
3+
info:
4+
title: Deep Recursion Test - Level 1
5+
version: 1.0.0
6+
description: Main file that starts the 4-level reference chain
7+
8+
operations:
9+
process.data:
10+
action: send
11+
channel:
12+
$ref: "level2.yaml#/channels/data_channel"
13+
14+
components:
15+
schemas:
16+
Level1Schema:
17+
type: object
18+
properties:
19+
level:
20+
type: integer
21+
const: 1
22+
message:
23+
type: string
24+
const: "from_level_1"
tests/codegen/specs/deep_recursion/level2.yaml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Level 2: References Level 3
2+
channels:
3+
data_channel:
4+
address: data.queue
5+
title: Data Channel from Level 2
6+
messages:
7+
data_message:
8+
$ref: "level3.yaml#/components/messages/DataMessage"
9+
10+
components:
11+
schemas:
12+
Level2Schema:
13+
type: object
14+
properties:
15+
level:
16+
type: integer
17+
const: 2
18+
message:
19+
type: string
20+
const: "from_level_2"
21+
level1_ref:
22+
$ref: "level1.yaml#/components/schemas/Level1Schema"
tests/codegen/specs/deep_recursion/level3.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Level 3: References Level 4
2+
components:
3+
messages:
4+
DataMessage:
5+
title: Data Message from Level 3
6+
payload:
7+
type: object
8+
properties:
9+
id:
10+
type: string
11+
level3_data:
12+
type: string
13+
const: "from_level_3"
14+
deep_schema:
15+
$ref: "level4.yaml#/components/schemas/Level4Schema"
16+
17+
schemas:
18+
Level3Schema:
19+
type: object
20+
properties:
21+
level:
22+
type: integer
23+
const: 3
24+
message:
25+
type: string
26+
const: "from_level_3"
27+
level2_ref:
28+
$ref: "level2.yaml#/components/schemas/Level2Schema"
tests/codegen/specs/deep_recursion/level4.yaml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Level 4: Deepest level - no more references
2+
components:
3+
schemas:
4+
Level4Schema:
5+
type: object
6+
properties:
7+
level:
8+
type: integer
9+
const: 4
10+
message:
11+
type: string
12+
const: "from_level_4_deepest"
13+
metadata:
14+
type: object
15+
properties:
16+
depth:
17+
type: integer
18+
const: 4
19+
status:
20+
type: string
21+
enum: ["deep", "deeper", "deepest"]

tests/codegen/test_parser.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,45 @@ def test_invalid_yaml_structure():
264264
extract_all_operations(invalid_yaml)
265265
finally:
266266
invalid_yaml.unlink(missing_ok=True)
267+
268+
269+
def test_four_level_deep_recursion():
270+
"""Test 4-level deep file reference chain: Level1->Level2->Level3->Level4.
271+
272+
This test verifies that the MessageGenerator recursively collects component schemas
273+
from all referenced files, not just the main spec file.
274+
"""
275+
from src.asyncapi_python_codegen.generators.messages import MessageGenerator
276+
277+
spec_path = Path("tests/codegen/specs/deep_recursion/level1.yaml")
278+
279+
# Test that MessageGenerator collects schemas from all 4 levels
280+
generator = MessageGenerator()
281+
schemas = generator._load_component_schemas(spec_path)
282+
283+
# Without recursive file loading, we would only get Level1Schema
284+
# With recursive loading, we should get schemas from all 4 files
285+
assert "Level1Schema" in schemas, "Level1Schema from main file not found"
286+
assert (
287+
"Level2Schema" in schemas
288+
), "Level2Schema from level2.yaml not found (recursive loading failed)"
289+
assert (
290+
"Level3Schema" in schemas
291+
), "Level3Schema from level3.yaml not found (recursive loading failed)"
292+
assert (
293+
"Level4Schema" in schemas
294+
), "Level4Schema from level4.yaml not found (recursive loading failed)"
295+
assert "DataMessage" in schemas, "DataMessage from level3.yaml not found"
296+
297+
# Verify the deepest level schema has correct structure
298+
level4_schema = schemas["Level4Schema"]
299+
assert level4_schema["properties"]["level"]["const"] == 4
300+
assert level4_schema["properties"]["message"]["const"] == "from_level_4_deepest"
301+
302+
# Also verify operations can be extracted (tests parser, not generator)
303+
operations = extract_all_operations(spec_path)
304+
assert len(operations) == 1
305+
306+
process_data = operations["process.data"]
307+
assert process_data.channel.address == "data.queue"
308+
assert process_data.channel.title == "Data Channel from Level 2"

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)