Skip to content

Commit ea06e01

Browse files
Fix codegen for fields with dashes (#163)
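Why
===

Some new services require `-`, `.`, and `:` in field names.

What changed
============

When rendering a `TypeName`, its value is now normalized (special characters are replaced with `_`). A new `LiteralType` dataclass was added for type names that must be rendered as-is, bypassing that normalization.

Known limitation: field-name normalization only works on output types (`BaseModel`); for any other base model, a field name that needs normalization raises a `ValueError`, because alias support for `TypedDict` is still a TODO.

Test plan
=========

Run codegen on a schema with a field that has a `-` in its name.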
1 parent 7e9519f commit ea06e01

File tree

9 files changed

+177
-20
lines changed

9 files changed

+177
-20
lines changed

src/replit_river/codegen/client.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import re
33
import subprocess
4+
from collections import defaultdict
45
from pathlib import Path
56
from textwrap import dedent
67
from typing import (
@@ -24,6 +25,7 @@
2425
FileContents,
2526
HandshakeType,
2627
ListTypeExpr,
28+
LiteralType,
2729
LiteralTypeExpr,
2830
ModuleName,
2931
NoneTypeExpr,
@@ -33,6 +35,7 @@
3335
TypeName,
3436
UnionTypeExpr,
3537
extract_inner_type,
38+
normalize_special_chars,
3639
render_literal_type,
3740
render_type_expr,
3841
)
@@ -396,9 +399,12 @@ def {_field_name}(
396399
case NoneTypeExpr():
397400
typeddict_encoder.append("None")
398401
case other:
399-
_o2: DictTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = (
400-
other
401-
)
402+
_o2: (
403+
DictTypeExpr
404+
| OpenUnionTypeExpr
405+
| UnionTypeExpr
406+
| LiteralType
407+
) = other
402408
raise ValueError(f"What does it mean to have {_o2} here?")
403409
if permit_unknown_members:
404410
union = _make_open_union_type_expr(any_of)
@@ -491,7 +497,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
491497
return (NoneTypeExpr(), [], [], set())
492498
elif type.type == "Date":
493499
typeddict_encoder.append("TODO: dstewart")
494-
return (TypeName("datetime.datetime"), [], [], set())
500+
return (LiteralType("datetime.datetime"), [], [], set())
495501
elif type.type == "array" and type.items:
496502
type_name, module_info, type_chunks, encoder_names = encode_type(
497503
type.items,
@@ -524,6 +530,9 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
524530
# lambda x: ... vs lambda _: {}
525531
needs_binding = False
526532
encoder_names = set()
533+
# Track effective field names to detect collisions after normalization
534+
# Maps effective name -> list of original field names
535+
effective_field_names: defaultdict[str, list[str]] = defaultdict(list)
527536
if type.properties:
528537
needs_binding = True
529538
typeddict_encoder.append("{")
@@ -653,19 +662,37 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
653662
value = ""
654663
if base_model != "TypedDict":
655664
value = f"= {field_value}"
665+
# Track $kind -> "kind" mapping for collision detection
666+
effective_field_names["kind"].append(name)
667+
656668
current_chunks.append(
657669
f" kind: Annotated[{render_type_expr(type_name)}, Field(alias={
658670
repr(name)
659671
})]{value}"
660672
)
661673
else:
674+
specialized_name = normalize_special_chars(name)
675+
effective_name = name
676+
extras = []
677+
if name != specialized_name:
678+
if base_model != "BaseModel":
679+
# TODO: alias support for TypedDict
680+
raise ValueError(
681+
f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501
682+
)
683+
# Pydantic doesn't allow leading underscores in field names
684+
effective_name = specialized_name.lstrip("_")
685+
extras.append(f"alias={repr(name)}")
686+
687+
effective_field_names[effective_name].append(name)
688+
662689
if name not in type.required:
663690
if base_model == "TypedDict":
664691
current_chunks.append(
665692
reindent(
666693
" ",
667694
f"""\
668-
{name}: NotRequired[{
695+
{effective_name}: NotRequired[{
669696
render_type_expr(
670697
UnionTypeExpr([type_name, NoneTypeExpr()])
671698
)
@@ -674,11 +701,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
674701
)
675702
)
676703
else:
704+
extras.append("default=None")
705+
677706
current_chunks.append(
678707
reindent(
679708
" ",
680709
f"""\
681-
{name}: {
710+
{effective_name}: {
682711
render_type_expr(
683712
UnionTypeExpr(
684713
[
@@ -687,15 +716,30 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
687716
]
688717
)
689718
)
690-
} = None
719+
} = Field({", ".join(extras)})
691720
""",
692721
)
693722
)
694723
else:
724+
extras_str = ""
725+
if len(extras) != 0:
726+
extras_str = f" = Field({', '.join(extras)})"
727+
695728
current_chunks.append(
696-
f" {name}: {render_type_expr(type_name)}"
729+
f" {effective_name}: {render_type_expr(type_name)}{extras_str}" # noqa: E501
697730
)
698731
typeddict_encoder.append(",")
732+
733+
# Check for field name collisions after processing all fields
734+
for effective_name, original_names in effective_field_names.items():
735+
if len(original_names) > 1:
736+
error_msg = (
737+
f"Field name collision: fields {original_names} all normalize "
738+
f"to the same effective name '{effective_name}'"
739+
)
740+
741+
raise ValueError(error_msg)
742+
699743
typeddict_encoder.append("}")
700744
# exclude_none
701745
typeddict_encoder = (

src/replit_river/codegen/typing.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
22
from typing import NewType, assert_never, cast
33

4+
SPECIAL_CHARS = [".", "-", ":", "/", "@", " ", "$", "!", "?", "=", "&", "|", "~", "`"]
5+
46
ModuleName = NewType("ModuleName", str)
57
ClassName = NewType("ClassName", str)
68
FileContents = NewType("FileContents", str)
@@ -23,6 +25,20 @@ def __lt__(self, other: object) -> bool:
2325
return hash(self) < hash(other)
2426

2527

28+
@dataclass(frozen=True)
class LiteralType:
    """A type expression whose text is emitted exactly as stored.

    ``render_type_expr`` returns ``value`` unchanged for this type, whereas a
    ``TypeName`` is passed through ``normalize_special_chars`` first. Use it
    for names such as ``datetime.datetime`` whose ``.`` must survive rendering.
    """

    # The exact source text to emit for this type.
    value: str

    def __str__(self) -> str:
        # Deliberately unusable: callers must render via render_type_expr
        # instead of relying on implicit string conversion.
        raise Exception("Complex type must be put through render_type_expr!")

    def __eq__(self, other: object) -> bool:
        # Only another LiteralType carrying the same text compares equal.
        if not isinstance(other, LiteralType):
            return False
        return other.value == self.value

    def __lt__(self, other: object) -> bool:
        # Arbitrary hash-based ordering; it carries no semantic meaning.
        own_hash = hash(self)
        their_hash = hash(other)
        return own_hash < their_hash
40+
41+
2642
@dataclass(frozen=True)
2743
class NoneTypeExpr:
2844
def __str__(self) -> str:
@@ -111,6 +127,7 @@ def __lt__(self, other: object) -> bool:
111127

112128
TypeExpression = (
113129
TypeName
130+
| LiteralType
114131
| NoneTypeExpr
115132
| DictTypeExpr
116133
| ListTypeExpr
@@ -145,6 +162,12 @@ def work(
145162
raise ValueError("Incoherent state when trying to flatten unions")
146163

147164

165+
def normalize_special_chars(value: str) -> str:
    """Return ``value`` with every character in ``SPECIAL_CHARS`` replaced by ``_``.

    All other characters are left untouched, so the result has the same
    length as the input.
    """
    # Every entry in SPECIAL_CHARS is a single character, so one translation
    # table covers them all in a single pass over the string.
    underscore_table = str.maketrans(dict.fromkeys(SPECIAL_CHARS, "_"))
    return value.translate(underscore_table)
169+
170+
148171
def render_type_expr(value: TypeExpression) -> str:
149172
match _flatten_nested_unions(value):
150173
case DictTypeExpr(nested):
@@ -192,7 +215,9 @@ def render_type_expr(value: TypeExpression) -> str:
192215
"]"
193216
)
194217
case TypeName(name):
195-
return name
218+
return normalize_special_chars(name)
219+
case LiteralType(literal_value):
220+
return literal_value
196221
case NoneTypeExpr():
197222
return "None"
198223
case other:
@@ -223,6 +248,10 @@ def extract_inner_type(value: TypeExpression) -> TypeName:
223248
)
224249
case TypeName(name):
225250
return TypeName(name)
251+
case LiteralType(name):
252+
raise ValueError(
253+
f"Attempting to extract from a literal type: {repr(value)}"
254+
)
226255
case NoneTypeExpr():
227256
raise ValueError(
228257
f"Attempting to extract from a literal 'None': {repr(value)}",

tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import datetime, timezone
12
from typing import Any, Literal, Mapping
23

34
import nanoid
@@ -55,7 +56,11 @@ def deserialize_request(request: dict) -> str:
5556

5657

5758
def serialize_response(response: str) -> dict:
58-
return {"data": response}
59+
return {
60+
"data": response,
61+
"data2": datetime.now(timezone.utc),
62+
"data-3": {"data-test": "test"},
63+
}
5964

6065

6166
def deserialize_response(response: dict) -> str:

tests/v1/codegen/rpc/generated/test_service/rpc_method.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def encode_Rpc_MethodInput(
3030
for (k, v) in (
3131
{
3232
"data": x.get("data"),
33+
"data2": x.get("data2"),
3334
}
3435
).items()
3536
if v is not None
@@ -38,10 +39,17 @@ def encode_Rpc_MethodInput(
3839

3940
class Rpc_MethodInput(TypedDict):
4041
data: str
42+
data2: datetime.datetime
43+
44+
45+
class Rpc_MethodOutputData_3(BaseModel):
46+
data_test: str | None = Field(alias="data-test", default=None)
4147

4248

4349
class Rpc_MethodOutput(BaseModel):
4450
data: str
51+
data_3: Rpc_MethodOutputData_3 = Field(alias="data-3")
52+
data2: datetime.datetime
4553

4654

4755
Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter(
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"services": {
3+
"test_service": {
4+
"procedures": {
5+
"rpc_method": {
6+
"input": {
7+
"type": "boolean"
8+
},
9+
"output": {
10+
"type": "object",
11+
"properties": {
12+
"data:3": {
13+
"type": "Date"
14+
},
15+
"data-3": {
16+
"type": "boolean"
17+
}
18+
},
19+
"required": ["data:3"]
20+
},
21+
"errors": {
22+
"not": {}
23+
},
24+
"type": "rpc"
25+
}
26+
}
27+
}
28+
}
29+
}
30+

tests/v1/codegen/rpc/schema.json

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,33 @@
88
"properties": {
99
"data": {
1010
"type": "string"
11+
},
12+
"data2": {
13+
"type": "Date"
1114
}
1215
},
13-
"required": ["data"]
16+
"required": ["data", "data2"]
1417
},
1518
"output": {
1619
"type": "object",
1720
"properties": {
1821
"data": {
1922
"type": "string"
23+
},
24+
"data2": {
25+
"type": "Date"
26+
},
27+
"data-3": {
28+
"type": "object",
29+
"properties": {
30+
"data-test": {
31+
"type": "string"
32+
}
33+
},
34+
"required": []
2035
}
2136
},
22-
"required": ["data"]
37+
"required": ["data", "data2", "data-3"]
2338
},
2439
"errors": {
2540
"not": {}

tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class NeedsenumobjectOutputFooOneOf_out_second(BaseModel):
9696

9797

9898
class NeedsenumobjectOutput(BaseModel):
99-
foo: NeedsenumobjectOutputFoo | None = None
99+
foo: NeedsenumobjectOutputFoo | None = Field(default=None)
100100

101101

102102
NeedsenumobjectOutputTypeAdapter: TypeAdapter[NeedsenumobjectOutput] = TypeAdapter(
@@ -105,11 +105,11 @@ class NeedsenumobjectOutput(BaseModel):
105105

106106

107107
class NeedsenumobjectErrorsFooAnyOf_0(BaseModel):
108-
beep: Literal["err_first"] | None = None
108+
beep: Literal["err_first"] | None = Field(default=None)
109109

110110

111111
class NeedsenumobjectErrorsFooAnyOf_1(BaseModel):
112-
borp: Literal["err_second"] | None = None
112+
borp: Literal["err_second"] | None = Field(default=None)
113113

114114

115115
NeedsenumobjectErrorsFoo = Annotated[
@@ -121,7 +121,7 @@ class NeedsenumobjectErrorsFooAnyOf_1(BaseModel):
121121

122122

123123
class NeedsenumobjectErrors(RiverError):
124-
foo: NeedsenumobjectErrorsFoo | None = None
124+
foo: NeedsenumobjectErrorsFoo | None = Field(default=None)
125125

126126

127127
NeedsenumobjectErrorsTypeAdapter: TypeAdapter[NeedsenumobjectErrors] = TypeAdapter(
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from io import StringIO
2+
3+
import pytest
4+
5+
from replit_river.codegen.client import schema_to_river_client_codegen
6+
7+
8+
def test_field_name_collision_error() -> None:
9+
"""Test that codegen raises ValueError for field name collisions."""
10+
11+
with pytest.raises(ValueError) as exc_info:
12+
schema_to_river_client_codegen(
13+
read_schema=lambda: open("tests/v1/codegen/rpc/invalid-schema.json"),
14+
target_path="tests/v1/codegen/rpc/generated",
15+
client_name="InvalidClient",
16+
typed_dict_inputs=True,
17+
file_opener=lambda _: StringIO(),
18+
method_filter=None,
19+
protocol_version="v1.1",
20+
)
21+
22+
# Check that the error message matches the expected format for field name collision
23+
error_message = str(exc_info.value)
24+
assert "Field name collision" in error_message
25+
assert "data:3" in error_message
26+
assert "data-3" in error_message
27+
assert "all normalize to the same effective name 'data_3'" in error_message

0 commit comments

Comments
 (0)