4343 is_none ,
4444 is_opt ,
4545 is_primitive ,
46+ is_primitive_subclass ,
4647 is_set ,
4748 is_str_serializable ,
4849 is_tuple ,
@@ -245,8 +246,13 @@ def wrap(cls: Type[T]) -> Type[T]:
245246 # We call deserialize and not wrap to make sure that we will use the default serde
246247 # configuration for generating the deserialization function.
247248 deserialize (typ )
248- if is_primitive (typ ) and not is_enum (typ ):
249+
250+ # We don't want to add primitive class e.g "str" into the scope, but primitive
251+ # compatible types such as IntEnum and a subclass of primitives are added,
252+ # so that generated code can use those types.
253+ if is_primitive (typ ) and not is_enum (typ ) and not is_primitive_subclass (typ ):
249254 continue
255+
250256 if is_generic (typ ):
251257 g [typename (typ )] = get_origin (typ )
252258 else :
@@ -626,6 +632,7 @@ class Renderer:
626632 custom : Optional [DeserializeFunc ] = None # Custom class level deserializer.
627633 import_numpy : bool = False
628634 suppress_coerce : bool = False
635+ """ Disable type coercing in codegen """
629636
630637 def render (self , arg : DeField [Any ]) -> str :
631638 """
@@ -657,8 +664,6 @@ def render(self, arg: DeField[Any]) -> str:
657664 elif is_numpy_array (arg .type ):
658665 self .import_numpy = True
659666 res = deserialize_numpy_array (arg )
660- elif is_primitive (arg .type ):
661- res = self .primitive (arg )
662667 elif is_union (arg .type ):
663668 res = self .union_func (arg )
664669 elif is_str_serializable (arg .type ):
@@ -671,6 +676,9 @@ def render(self, arg: DeField[Any]) -> str:
671676 res = "None"
672677 elif is_any (arg .type ) or is_ellipsis (arg .type ):
673678 res = arg .data
679+ elif is_primitive (arg .type ):
680+ # For subclasses for primitives e.g. class FooStr(str), coercing is always enabled
681+ res = self .primitive (arg , not is_primitive_subclass (arg .type ))
674682 elif isinstance (arg .type , TypeVar ):
675683 index = find_generic_arg (self .cls , arg .type )
676684 res = (
@@ -878,6 +886,8 @@ def primitive(self, arg: DeField[Any], suppress_coerce: bool = False) -> str:
878886 """
879887 Render rvalue for primitives.
880888
889+ * `suppress_coerce`: Overrides "suppress_coerce" in the Renderer's field
890+
881891 >>> Renderer('foo').render(DeField(int, 'i', datavar='data'))
882892 'coerce(int, data["i"])'
883893
@@ -892,7 +902,7 @@ def primitive(self, arg: DeField[Any], suppress_coerce: bool = False) -> str:
892902 if arg .alias :
893903 aliases = (f'"{ s } "' for s in [arg .name , * arg .alias ])
894904 dat = f"_get_by_aliases(data, [{ ',' .join (aliases )} ])"
895- if self .suppress_coerce :
905+ if self .suppress_coerce and suppress_coerce :
896906 return dat
897907 else :
898908 return f"coerce({ typ } , { dat } )"
0 commit comments