Skip to content

Commit

Permalink
feat: Implement Resource.get_extension_model
Browse files Browse the repository at this point in the history
  • Loading branch information
azmeuk committed Dec 2, 2024
1 parent 298b090 commit 0073127
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 4 deletions.
9 changes: 8 additions & 1 deletion doc/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Changelog
=========

[0.2.9] - Unreleased
--------------------

Added
^^^^^
- Implement :meth:`Resource.get_extension_model <scim2_models.Resource.get_extension_model>`.

[0.2.8] - 2024-12-02
--------------------

Expand All @@ -13,7 +20,7 @@ Added

Added
^^^^^
- Implement :meth:`ResourceType.from_resource`.
- Implement :meth:`ResourceType.from_resource <scim2_models.ResourceType.from_resource>`.

[0.2.6] - 2024-11-29
--------------------
Expand Down
10 changes: 9 additions & 1 deletion scim2_models/rfc7643/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __setitem__(self, item: Any, value: "Resource"):
setattr(self, item.__name__, value)

@classmethod
def get_extension_models(cls) -> dict[str, type]:
def get_extension_models(cls) -> dict[str, type[Extension]]:
"""Return extension a dict associating extension models with their schemas."""
extension_models = cls.__pydantic_generic_metadata__.get("args", [])
extension_models = (
Expand All @@ -191,6 +191,14 @@ def get_extension_models(cls) -> dict[str, type]:
}
return by_schema

@classmethod
def get_extension_model(cls, name_or_schema) -> Optional[type[Extension]]:
"""Return an extension by its name or schema."""
for schema, extension in cls.get_extension_models().items():
if schema == name_or_schema or extension.__name__ == name_or_schema:
return extension
return None

@staticmethod
def get_by_schema(
resource_types: list[type[BaseModel]], schema: str, with_extensions=True
Expand Down
41 changes: 39 additions & 2 deletions tests/test_resource_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def test_invalid_setitem():


class SuperHero(Extension):
schemas: list[str] = ["urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"]
schemas: list[str] = ["example:extensions:SuperHero"]

superpower: Optional[str] = None
"""The superhero superpower."""
Expand All @@ -217,8 +217,9 @@ def test_multiple_extensions_union():
"schemas": [
"urn:ietf:params:scim:schemas:core:2.0:User",
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
"example:extensions:SuperHero",
],
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User": {
"example:extensions:SuperHero": {
"superpower": "flight",
},
}
Expand Down Expand Up @@ -268,3 +269,39 @@ def test_validate_items_without_extension():
User[EnterpriseUser].model_validate(
payload, scim_ctx=Context.RESOURCE_CREATION_RESPONSE
)


def test_get_extension_model():
assert User[EnterpriseUser].get_extension_model("EnterpriseUser") == EnterpriseUser
assert (
User[EnterpriseUser].get_extension_model(
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
== EnterpriseUser
)

assert (
User[Union[EnterpriseUser, SuperHero]].get_extension_model("EnterpriseUser")
== EnterpriseUser
)
assert (
User[Union[EnterpriseUser, SuperHero]].get_extension_model(
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
== EnterpriseUser
)

assert User[SuperHero].get_extension_model("EnterpriseUser") is None
assert (
User[SuperHero].get_extension_model(
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
is None
)
assert User.get_extension_model("EnterpriseUser") is None
assert (
User.get_extension_model(
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
is None
)

0 comments on commit 0073127

Please sign in to comment.