Skip to content

Commit 0647cf7

Browse files
committed
Fix codegen for fields with dashes
1 parent 7e9519f commit 0647cf7

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

src/replit_river/codegen/client.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
FileContents,
2525
HandshakeType,
2626
ListTypeExpr,
27+
LiteralType,
2728
LiteralTypeExpr,
2829
ModuleName,
2930
NoneTypeExpr,
@@ -33,6 +34,7 @@
3334
TypeName,
3435
UnionTypeExpr,
3536
extract_inner_type,
37+
normalize_special_chars,
3638
render_literal_type,
3739
render_type_expr,
3840
)
@@ -396,9 +398,12 @@ def {_field_name}(
396398
case NoneTypeExpr():
397399
typeddict_encoder.append("None")
398400
case other:
399-
_o2: DictTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = (
400-
other
401-
)
401+
_o2: (
402+
DictTypeExpr
403+
| OpenUnionTypeExpr
404+
| UnionTypeExpr
405+
| LiteralType
406+
) = other
402407
raise ValueError(f"What does it mean to have {_o2} here?")
403408
if permit_unknown_members:
404409
union = _make_open_union_type_expr(any_of)
@@ -491,7 +496,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
491496
return (NoneTypeExpr(), [], [], set())
492497
elif type.type == "Date":
493498
typeddict_encoder.append("TODO: dstewart")
494-
return (TypeName("datetime.datetime"), [], [], set())
499+
return (LiteralType("datetime.datetime"), [], [], set())
495500
elif type.type == "array" and type.items:
496501
type_name, module_info, type_chunks, encoder_names = encode_type(
497502
type.items,
@@ -692,8 +697,21 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
692697
)
693698
)
694699
else:
700+
specialized_name = normalize_special_chars(name)
701+
effective_name = name
702+
extras = ""
703+
if name != specialized_name:
704+
if base_model != "BaseModel":
705+
# TODO: alias support for TypedDict
706+
raise ValueError(
707+
f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501
708+
)
709+
# Pydantic doesn't allow leading underscores in field names
710+
effective_name = specialized_name.lstrip("_")
711+
extras = f" = Field(serialization_alias={repr(name)})"
712+
695713
current_chunks.append(
696-
f" {name}: {render_type_expr(type_name)}"
714+
f" {effective_name}: {render_type_expr(type_name)}{extras}"
697715
)
698716
typeddict_encoder.append(",")
699717
typeddict_encoder.append("}")

src/replit_river/codegen/typing.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
22
from typing import NewType, assert_never, cast
33

4+
SPECIAL_CHARS = [".", "-", ":", "/", "@", " ", "$", "!", "?", "=", "&", "|", "~", "`"]
5+
46
ModuleName = NewType("ModuleName", str)
57
ClassName = NewType("ClassName", str)
68
FileContents = NewType("FileContents", str)
@@ -23,6 +25,20 @@ def __lt__(self, other: object) -> bool:
2325
return hash(self) < hash(other)
2426

2527

28+
@dataclass(frozen=True)
29+
class LiteralType:
30+
value: str
31+
32+
def __str__(self) -> str:
33+
raise Exception("Complex type must be put through render_type_expr!")
34+
35+
def __eq__(self, other: object) -> bool:
36+
return isinstance(other, LiteralType) and other.value == self.value
37+
38+
def __lt__(self, other: object) -> bool:
39+
return hash(self) < hash(other)
40+
41+
2642
@dataclass(frozen=True)
2743
class NoneTypeExpr:
2844
def __str__(self) -> str:
@@ -111,6 +127,7 @@ def __lt__(self, other: object) -> bool:
111127

112128
TypeExpression = (
113129
TypeName
130+
| LiteralType
114131
| NoneTypeExpr
115132
| DictTypeExpr
116133
| ListTypeExpr
@@ -145,6 +162,12 @@ def work(
145162
raise ValueError("Incoherent state when trying to flatten unions")
146163

147164

165+
def normalize_special_chars(value: str) -> str:
166+
for char in SPECIAL_CHARS:
167+
value = value.replace(char, "_")
168+
return value
169+
170+
148171
def render_type_expr(value: TypeExpression) -> str:
149172
match _flatten_nested_unions(value):
150173
case DictTypeExpr(nested):
@@ -192,7 +215,9 @@ def render_type_expr(value: TypeExpression) -> str:
192215
"]"
193216
)
194217
case TypeName(name):
195-
return name
218+
return normalize_special_chars(name)
219+
case LiteralType(literal_value):
220+
return literal_value
196221
case NoneTypeExpr():
197222
return "None"
198223
case other:
@@ -223,6 +248,10 @@ def extract_inner_type(value: TypeExpression) -> TypeName:
223248
)
224249
case TypeName(name):
225250
return TypeName(name)
251+
case LiteralType(name):
252+
raise ValueError(
253+
f"Attempting to extract from a literal type: {repr(value)}"
254+
)
226255
case NoneTypeExpr():
227256
raise ValueError(
228257
f"Attempting to extract from a literal 'None': {repr(value)}",

0 commit comments

Comments
 (0)