Skip to content

Commit 42de571

Browse files
[bug] Avoiding mixing type rendering and unstructured codegen (#146)
Why === Introduced in #142, `NotRequired[A] | None` doesn't typecheck, but we didn't have any codegen tests that exhausted that path. What changed ============ Move the statically rendered `${render(...)} | None` into `${render(UnionTypeExpr([..., NoneTypeExpr()]))}` so the type renderer knows about and can properly unify/deduplicate/render the structured union. Test plan ========= TODO: Add a test that makes it so we don't regress here in the future.
1 parent c2fe359 commit 42de571

File tree

11 files changed

+1012
-47
lines changed

11 files changed

+1012
-47
lines changed

src/replit_river/codegen/client.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,16 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
643643
"""
644644
)
645645
current_chunks.append(
646-
f" kind: {render_type_expr(type_name)} | None{value}"
646+
f" kind: {
647+
render_type_expr(
648+
UnionTypeExpr(
649+
[
650+
type_name,
651+
NoneTypeExpr(),
652+
]
653+
)
654+
)
655+
}{value}"
647656
)
648657
else:
649658
value = ""
@@ -666,7 +675,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
666675
reindent(
667676
" ",
668677
f"""\
669-
{name}: NotRequired[{render_type_expr(type_name)}] | None
678+
{name}: NotRequired[{
679+
render_type_expr(
680+
UnionTypeExpr([type_name, NoneTypeExpr()])
681+
)
682+
}]
670683
""",
671684
)
672685
)
@@ -675,7 +688,16 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
675688
reindent(
676689
" ",
677690
f"""\
678-
{name}: {render_type_expr(type_name)} | None = None
691+
{name}: {
692+
render_type_expr(
693+
UnionTypeExpr(
694+
[
695+
type_name,
696+
NoneTypeExpr(),
697+
]
698+
)
699+
)
700+
} = None
679701
""",
680702
)
681703
)
@@ -1246,6 +1268,8 @@ def schema_to_river_client_codegen(
12461268
stdout=subprocess.PIPE,
12471269
)
12481270
stdout, _ = popen.communicate(contents.encode())
1271+
if popen.returncode != 0:
1272+
f.write(contents)
12491273
f.write(stdout.decode("utf-8"))
12501274
except:
12511275
f.write(contents)

src/replit_river/codegen/typing.py

+56
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,24 @@ class TypeName:
1616
def __str__(self) -> str:
1717
raise Exception("Complex type must be put through render_type_expr!")
1818

19+
def __eq__(self, other: object) -> bool:
20+
return isinstance(other, TypeName) and other.value == self.value
21+
22+
def __lt__(self, other: object) -> bool:
23+
return hash(self) < hash(other)
24+
1925

2026
@dataclass(frozen=True)
2127
class NoneTypeExpr:
2228
def __str__(self) -> str:
2329
raise Exception("Complex type must be put through render_type_expr!")
2430

31+
def __eq__(self, other: object) -> bool:
32+
return isinstance(other, NoneTypeExpr)
33+
34+
def __lt__(self, other: object) -> bool:
35+
return hash(self) < hash(other)
36+
2537

2638
@dataclass(frozen=True)
2739
class DictTypeExpr:
@@ -30,6 +42,12 @@ class DictTypeExpr:
3042
def __str__(self) -> str:
3143
raise Exception("Complex type must be put through render_type_expr!")
3244

45+
def __eq__(self, other: object) -> bool:
46+
return isinstance(other, DictTypeExpr) and other.nested == self.nested
47+
48+
def __lt__(self, other: object) -> bool:
49+
return hash(self) < hash(other)
50+
3351

3452
@dataclass(frozen=True)
3553
class ListTypeExpr:
@@ -38,6 +56,12 @@ class ListTypeExpr:
3856
def __str__(self) -> str:
3957
raise Exception("Complex type must be put through render_type_expr!")
4058

59+
def __eq__(self, other: object) -> bool:
60+
return isinstance(other, ListTypeExpr) and other.nested == self.nested
61+
62+
def __lt__(self, other: object) -> bool:
63+
return hash(self) < hash(other)
64+
4165

4266
@dataclass(frozen=True)
4367
class LiteralTypeExpr:
@@ -46,6 +70,12 @@ class LiteralTypeExpr:
4670
def __str__(self) -> str:
4771
raise Exception("Complex type must be put through render_type_expr!")
4872

73+
def __eq__(self, other: object) -> bool:
74+
return isinstance(other, LiteralTypeExpr) and other.nested == self.nested
75+
76+
def __lt__(self, other: object) -> bool:
77+
return hash(self) < hash(other)
78+
4979

5080
@dataclass(frozen=True)
5181
class UnionTypeExpr:
@@ -54,6 +84,14 @@ class UnionTypeExpr:
5484
def __str__(self) -> str:
5585
raise Exception("Complex type must be put through render_type_expr!")
5686

87+
def __eq__(self, other: object) -> bool:
88+
return isinstance(other, UnionTypeExpr) and set(other.nested) == set(
89+
self.nested
90+
)
91+
92+
def __lt__(self, other: object) -> bool:
93+
return hash(self) < hash(other)
94+
5795

5896
@dataclass(frozen=True)
5997
class OpenUnionTypeExpr:
@@ -62,6 +100,12 @@ class OpenUnionTypeExpr:
62100
def __str__(self) -> str:
63101
raise Exception("Complex type must be put through render_type_expr!")
64102

103+
def __eq__(self, other: object) -> bool:
104+
return isinstance(other, OpenUnionTypeExpr) and other.union == self.union
105+
106+
def __lt__(self, other: object) -> bool:
107+
return hash(self) < hash(other)
108+
65109

66110
TypeExpression = (
67111
TypeName
@@ -117,13 +161,25 @@ def render_type_expr(value: TypeExpression) -> str:
117161
literals.append(tpe)
118162
else:
119163
_other.append(tpe)
164+
165+
without_none: list[TypeExpression] = [
166+
x for x in _other if not isinstance(x, NoneTypeExpr)
167+
]
168+
has_none = len(_other) > len(without_none)
169+
_other = without_none
170+
120171
retval: str = " | ".join(render_type_expr(x) for x in _other)
121172
if literals:
122173
_rendered: str = ", ".join(repr(x.nested) for x in literals)
123174
if retval:
124175
retval = f"Literal[{_rendered}] | {retval}"
125176
else:
126177
retval = f"Literal[{_rendered}]"
178+
if has_none:
179+
if retval:
180+
retval = f"{retval} | None"
181+
else:
182+
retval = "None"
127183
return retval
128184
case OpenUnionTypeExpr(inner):
129185
return (

src/replit_river/rpc.py

-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
)
1919

2020
import grpc
21-
import grpc.aio
2221
from aiochannel import Channel, ChannelClosed
2322
from opentelemetry.propagators.textmap import Setter
2423
from pydantic import BaseModel, ConfigDict, Field
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from io import StringIO
2+
from pathlib import Path
3+
from typing import Callable, TextIO
4+
5+
from pytest_snapshot.plugin import Snapshot
6+
7+
from replit_river.codegen.client import schema_to_river_client_codegen
8+
9+
10+
class UnclosableStringIO(StringIO):
11+
def close(self) -> None:
12+
pass
13+
14+
15+
def validate_codegen(
16+
*,
17+
snapshot: Snapshot,
18+
read_schema: Callable[[], TextIO],
19+
target_path: str,
20+
client_name: str,
21+
) -> None:
22+
snapshot.snapshot_dir = "tests/codegen/snapshot/snapshots"
23+
files: dict[Path, UnclosableStringIO] = {}
24+
25+
def file_opener(path: Path) -> TextIO:
26+
buffer = UnclosableStringIO()
27+
assert path not in files, "Codegen attempted to write to the same file twice!"
28+
files[path] = buffer
29+
return buffer
30+
31+
schema_to_river_client_codegen(
32+
read_schema=read_schema,
33+
target_path=target_path,
34+
client_name=client_name,
35+
file_opener=file_opener,
36+
typed_dict_inputs=True,
37+
)
38+
for path, file in files.items():
39+
file.seek(0)
40+
snapshot.assert_match(file.read(), Path(snapshot.snapshot_dir, path))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from pydantic import BaseModel
3+
from typing import Literal
4+
5+
import replit_river as river
6+
7+
8+
from .test_service import Test_ServiceService
9+
10+
11+
class PathologicalClient:
12+
def __init__(self, client: river.Client[Literal[None]]):
13+
self.test_service = Test_ServiceService(client)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
from typing import Any
4+
import datetime
5+
6+
from pydantic import TypeAdapter
7+
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
9+
import replit_river as river
10+
11+
12+
from .pathological_method import (
13+
Pathological_MethodInput,
14+
Pathological_MethodInputTypeAdapter,
15+
encode_Pathological_MethodInput,
16+
encode_Pathological_MethodInputObj_Boolean,
17+
encode_Pathological_MethodInputObj_Date,
18+
encode_Pathological_MethodInputObj_Integer,
19+
encode_Pathological_MethodInputObj_Null,
20+
encode_Pathological_MethodInputObj_Number,
21+
encode_Pathological_MethodInputObj_String,
22+
encode_Pathological_MethodInputObj_Uint8Array,
23+
encode_Pathological_MethodInputObj_Undefined,
24+
encode_Pathological_MethodInputReq_Obj_Boolean,
25+
encode_Pathological_MethodInputReq_Obj_Date,
26+
encode_Pathological_MethodInputReq_Obj_Integer,
27+
encode_Pathological_MethodInputReq_Obj_Null,
28+
encode_Pathological_MethodInputReq_Obj_Number,
29+
encode_Pathological_MethodInputReq_Obj_String,
30+
encode_Pathological_MethodInputReq_Obj_Uint8Array,
31+
encode_Pathological_MethodInputReq_Obj_Undefined,
32+
)
33+
34+
boolTypeAdapter: TypeAdapter[Any] = TypeAdapter(bool)
35+
36+
37+
class Test_ServiceService:
38+
def __init__(self, client: river.Client[Any]):
39+
self.client = client
40+
41+
async def pathological_method(
42+
self,
43+
input: Pathological_MethodInput,
44+
timeout: datetime.timedelta,
45+
) -> bool:
46+
return await self.client.send_rpc(
47+
"test_service",
48+
"pathological_method",
49+
input,
50+
encode_Pathological_MethodInput,
51+
lambda x: boolTypeAdapter.validate_python(
52+
x # type: ignore[arg-type]
53+
),
54+
lambda x: RiverErrorTypeAdapter.validate_python(
55+
x # type: ignore[arg-type]
56+
),
57+
timeout,
58+
)

0 commit comments

Comments
 (0)