@@ -16,12 +16,24 @@ class TypeName:
16
16
def __str__ (self ) -> str :
17
17
raise Exception ("Complex type must be put through render_type_expr!" )
18
18
19
+ def __eq__ (self , other : object ) -> bool :
20
+ return isinstance (other , TypeName ) and other .value == self .value
21
+
22
+ def __lt__ (self , other : object ) -> bool :
23
+ return hash (self ) < hash (other )
24
+
19
25
20
26
@dataclass (frozen = True )
21
27
class NoneTypeExpr :
22
28
def __str__ (self ) -> str :
23
29
raise Exception ("Complex type must be put through render_type_expr!" )
24
30
31
+ def __eq__ (self , other : object ) -> bool :
32
+ return isinstance (other , NoneTypeExpr )
33
+
34
+ def __lt__ (self , other : object ) -> bool :
35
+ return hash (self ) < hash (other )
36
+
25
37
26
38
@dataclass (frozen = True )
27
39
class DictTypeExpr :
@@ -30,6 +42,12 @@ class DictTypeExpr:
30
42
def __str__ (self ) -> str :
31
43
raise Exception ("Complex type must be put through render_type_expr!" )
32
44
45
+ def __eq__ (self , other : object ) -> bool :
46
+ return isinstance (other , DictTypeExpr ) and other .nested == self .nested
47
+
48
+ def __lt__ (self , other : object ) -> bool :
49
+ return hash (self ) < hash (other )
50
+
33
51
34
52
@dataclass (frozen = True )
35
53
class ListTypeExpr :
@@ -38,6 +56,12 @@ class ListTypeExpr:
38
56
def __str__ (self ) -> str :
39
57
raise Exception ("Complex type must be put through render_type_expr!" )
40
58
59
+ def __eq__ (self , other : object ) -> bool :
60
+ return isinstance (other , ListTypeExpr ) and other .nested == self .nested
61
+
62
+ def __lt__ (self , other : object ) -> bool :
63
+ return hash (self ) < hash (other )
64
+
41
65
42
66
@dataclass (frozen = True )
43
67
class LiteralTypeExpr :
@@ -46,6 +70,12 @@ class LiteralTypeExpr:
46
70
def __str__ (self ) -> str :
47
71
raise Exception ("Complex type must be put through render_type_expr!" )
48
72
73
+ def __eq__ (self , other : object ) -> bool :
74
+ return isinstance (other , LiteralTypeExpr ) and other .nested == self .nested
75
+
76
+ def __lt__ (self , other : object ) -> bool :
77
+ return hash (self ) < hash (other )
78
+
49
79
50
80
@dataclass (frozen = True )
51
81
class UnionTypeExpr :
@@ -54,6 +84,14 @@ class UnionTypeExpr:
54
84
def __str__ (self ) -> str :
55
85
raise Exception ("Complex type must be put through render_type_expr!" )
56
86
87
+ def __eq__ (self , other : object ) -> bool :
88
+ return isinstance (other , UnionTypeExpr ) and set (other .nested ) == set (
89
+ self .nested
90
+ )
91
+
92
+ def __lt__ (self , other : object ) -> bool :
93
+ return hash (self ) < hash (other )
94
+
57
95
58
96
@dataclass (frozen = True )
59
97
class OpenUnionTypeExpr :
@@ -62,6 +100,12 @@ class OpenUnionTypeExpr:
62
100
def __str__ (self ) -> str :
63
101
raise Exception ("Complex type must be put through render_type_expr!" )
64
102
103
+ def __eq__ (self , other : object ) -> bool :
104
+ return isinstance (other , OpenUnionTypeExpr ) and other .union == self .union
105
+
106
+ def __lt__ (self , other : object ) -> bool :
107
+ return hash (self ) < hash (other )
108
+
65
109
66
110
TypeExpression = (
67
111
TypeName
@@ -117,13 +161,25 @@ def render_type_expr(value: TypeExpression) -> str:
117
161
literals .append (tpe )
118
162
else :
119
163
_other .append (tpe )
164
+
165
+ without_none : list [TypeExpression ] = [
166
+ x for x in _other if not isinstance (x , NoneTypeExpr )
167
+ ]
168
+ has_none = len (_other ) > len (without_none )
169
+ _other = without_none
170
+
120
171
retval : str = " | " .join (render_type_expr (x ) for x in _other )
121
172
if literals :
122
173
_rendered : str = ", " .join (repr (x .nested ) for x in literals )
123
174
if retval :
124
175
retval = f"Literal[{ _rendered } ] | { retval } "
125
176
else :
126
177
retval = f"Literal[{ _rendered } ]"
178
+ if has_none :
179
+ if retval :
180
+ retval = f"{ retval } | None"
181
+ else :
182
+ retval = "None"
127
183
return retval
128
184
case OpenUnionTypeExpr (inner ):
129
185
return (
0 commit comments