4343 is_none ,
4444 is_opt ,
4545 is_primitive ,
46+ is_primitive_subclass ,
4647 is_set ,
4748 is_str_serializable ,
4849 is_tuple ,
@@ -245,8 +246,13 @@ def wrap(cls: Type[T]) -> Type[T]:
245246 # We call deserialize and not wrap to make sure that we will use the default serde
246247 # configuration for generating the deserialization function.
247248 deserialize (typ )
248- if is_primitive (typ ) and not is_enum (typ ):
249+
250+ # We don't want to add primitive class e.g "str" into the scope, but primitive
251+ # compatible types such as IntEnum and a subclass of primitives are added,
252+ # so that generated code can use those types.
253+ if is_primitive (typ ) and not is_enum (typ ) and not is_primitive_subclass (typ ):
249254 continue
255+
250256 if is_generic (typ ):
251257 g [typename (typ )] = get_origin (typ )
252258 else :
@@ -624,6 +630,7 @@ class Renderer:
624630 custom : Optional [DeserializeFunc ] = None # Custom class level deserializer.
625631 import_numpy : bool = False
626632 suppress_coerce : bool = False
633+ """ Disable type coercing in codegen """
627634
628635 def render (self , arg : DeField [Any ]) -> str :
629636 """
@@ -655,8 +662,6 @@ def render(self, arg: DeField[Any]) -> str:
655662 elif is_numpy_array (arg .type ):
656663 self .import_numpy = True
657664 res = deserialize_numpy_array (arg )
658- elif is_primitive (arg .type ):
659- res = self .primitive (arg )
660665 elif is_union (arg .type ):
661666 res = self .union_func (arg )
662667 elif is_str_serializable (arg .type ):
@@ -669,6 +674,9 @@ def render(self, arg: DeField[Any]) -> str:
669674 res = "None"
670675 elif is_any (arg .type ) or is_ellipsis (arg .type ):
671676 res = arg .data
677+ elif is_primitive (arg .type ):
678+ # For subclasses for primitives e.g. class FooStr(str), coercing is always enabled
679+ res = self .primitive (arg , not is_primitive_subclass (arg .type ))
672680 elif isinstance (arg .type , TypeVar ):
673681 index = find_generic_arg (self .cls , arg .type )
674682 res = (
@@ -876,6 +884,8 @@ def primitive(self, arg: DeField[Any], suppress_coerce: bool = False) -> str:
876884 """
877885 Render rvalue for primitives.
878886
887+ * `suppress_coerce`: Overrides "suppress_coerce" in the Renderer's field
888+
879889 >>> Renderer('foo').render(DeField(int, 'i', datavar='data'))
880890 'coerce(int, data["i"])'
881891
@@ -890,7 +900,7 @@ def primitive(self, arg: DeField[Any], suppress_coerce: bool = False) -> str:
890900 if arg .alias :
891901 aliases = (f'"{ s } "' for s in [arg .name , * arg .alias ])
892902 dat = f"_get_by_aliases(data, [{ ',' .join (aliases )} ])"
893- if self .suppress_coerce :
903+ if self .suppress_coerce and suppress_coerce :
894904 return dat
895905 else :
896906 return f"coerce({ typ } , { dat } )"
0 commit comments