Skip to content

Commit ef2bac9

Browse files
committed
chore: fix a few typing issues
1 parent 33bf803 commit ef2bac9

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

scim2_models/base.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def validate_model_attribute(model: type["BaseModel"], attribute_base: str) -> N
5555
if sub_attribute_base:
5656
attribute_type = model.get_field_root_type(attribute_name)
5757

58-
if not issubclass(attribute_type, BaseModel):
58+
if not attribute_type or not issubclass(attribute_type, BaseModel):
5959
raise ValueError(
6060
f"Attribute '{attribute_name}' is not a complex attribute, and cannot have a '{sub_attribute_base}' sub-attribute"
6161
)
@@ -429,7 +429,7 @@ def annotation_type_filter(item):
429429
return field_annotation
430430

431431
@classmethod
432-
def get_field_root_type(cls, attribute_name: str) -> type:
432+
def get_field_root_type(cls, attribute_name: str) -> type | None:
433433
"""Extract the root type from a model field.
434434
435435
For example, return 'GroupMember' for
@@ -442,9 +442,8 @@ def get_field_root_type(cls, attribute_name: str) -> type:
442442
attribute_type = get_args(attribute_type)[0]
443443

444444
# extract 'x' from 'List[x]'
445-
if isclass(get_origin(attribute_type)) and issubclass(
446-
get_origin(attribute_type), list
447-
):
445+
origin = get_origin(attribute_type)
446+
if origin and isclass(origin) and issubclass(origin, list):
448447
attribute_type = get_args(attribute_type)[0]
449448

450449
return attribute_type
@@ -637,28 +636,29 @@ def scim_serializer(
637636
) -> Any:
638637
"""Serialize the fields according to mutability indications passed in the serialization context."""
639638
value = handler(value)
639+
scim_ctx = info.context.get("scim") if info.context else None
640640

641-
if info.context.get("scim") and Context.is_request(info.context["scim"]):
641+
if scim_ctx and Context.is_request(scim_ctx):
642642
value = self.scim_request_serializer(value, info)
643643

644-
if info.context.get("scim") and Context.is_response(info.context["scim"]):
644+
if scim_ctx and Context.is_response(scim_ctx):
645645
value = self.scim_response_serializer(value, info)
646646

647647
return value
648648

649649
def scim_request_serializer(self, value: Any, info: SerializationInfo) -> Any:
650650
"""Serialize the fields according to mutability indications passed in the serialization context."""
651651
mutability = self.get_field_annotation(info.field_name, Mutability)
652-
context = info.context.get("scim")
652+
scim_ctx = info.context.get("scim") if info.context else None
653653

654654
if (
655-
context == Context.RESOURCE_CREATION_REQUEST
655+
scim_ctx == Context.RESOURCE_CREATION_REQUEST
656656
and mutability == Mutability.read_only
657657
):
658658
return None
659659

660660
if (
661-
context
661+
scim_ctx
662662
in (
663663
Context.RESOURCE_QUERY_REQUEST,
664664
Context.SEARCH_REQUEST,
@@ -667,7 +667,7 @@ def scim_request_serializer(self, value: Any, info: SerializationInfo) -> Any:
667667
):
668668
return None
669669

670-
if context == Context.RESOURCE_REPLACEMENT_REQUEST and mutability in (
670+
if scim_ctx == Context.RESOURCE_REPLACEMENT_REQUEST and mutability in (
671671
Mutability.immutable,
672672
Mutability.read_only,
673673
):
@@ -679,8 +679,10 @@ def scim_response_serializer(self, value: Any, info: SerializationInfo) -> Any:
679679
"""Serialize the fields according to returnability indications passed in the serialization context."""
680680
returnability = self.get_field_annotation(info.field_name, Returned)
681681
attribute_urn = self.get_attribute_urn(info.field_name)
682-
included_urns = info.context.get("scim_attributes", [])
683-
excluded_urns = info.context.get("scim_excluded_attributes", [])
682+
included_urns = info.context.get("scim_attributes", []) if info.context else []
683+
excluded_urns = (
684+
info.context.get("scim_excluded_attributes", []) if info.context else []
685+
)
684686

685687
attribute_urn = normalize_attribute_name(attribute_urn)
686688
included_urns = [normalize_attribute_name(urn) for urn in included_urns]

scim2_models/rfc7643/schema.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pydantic.alias_generators import to_snake
1818
from pydantic_core import Url
1919

20+
from ..base import BaseModel
2021
from ..base import CaseExact
2122
from ..base import ComplexAttribute
2223
from ..base import ExternalReference
@@ -30,6 +31,7 @@
3031
from ..base import is_complex_attribute
3132
from ..constants import RESERVED_WORDS
3233
from ..utils import normalize_attribute_name
34+
from .resource import Extension
3335
from .resource import Resource
3436

3537

@@ -43,8 +45,10 @@ def make_python_identifier(identifier: str) -> str:
4345

4446

4547
def make_python_model(
46-
obj: Union["Schema", "Attribute"], base: Optional[type] = None, multiple=False
47-
) -> "Resource":
48+
obj: Union["Schema", "Attribute"],
49+
base: Optional[type[BaseModel]] = None,
50+
multiple=False,
51+
) -> "Resource" | "Extension":
4852
"""Build a Python model from a Schema or an Attribute object."""
4953
if isinstance(obj, Attribute):
5054
pydantic_attributes = {

0 commit comments

Comments
 (0)