Skip to content

Commit 0073127

Browse files
committed
feat: Implement Resource.get_extension_model
1 parent 298b090 commit 0073127

File tree

3 files changed

+56
-4
lines changed

3 files changed

+56
-4
lines changed

doc/changelog.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
Changelog
22
=========
33

4+
[0.2.9] - Unreleased
5+
--------------------
6+
7+
Added
8+
^^^^^
9+
- Implement :meth:`Resource.get_extension_model <scim2_models.Resource.get_extension_model>`.
10+
411
[0.2.8] - 2024-12-02
512
--------------------
613

@@ -13,7 +20,7 @@ Added
1320

1421
Added
1522
^^^^^
16-
- Implement :meth:`ResourceType.from_resource`.
23+
- Implement :meth:`ResourceType.from_resource <scim2_models.ResourceType.from_resource>`.
1724

1825
[0.2.6] - 2024-11-29
1926
--------------------

scim2_models/rfc7643/resource.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def __setitem__(self, item: Any, value: "Resource"):
177177
setattr(self, item.__name__, value)
178178

179179
@classmethod
180-
def get_extension_models(cls) -> dict[str, type]:
180+
def get_extension_models(cls) -> dict[str, type[Extension]]:
181181
"""Return extension a dict associating extension models with their schemas."""
182182
extension_models = cls.__pydantic_generic_metadata__.get("args", [])
183183
extension_models = (
@@ -191,6 +191,14 @@ def get_extension_models(cls) -> dict[str, type]:
191191
}
192192
return by_schema
193193

194+
@classmethod
195+
def get_extension_model(cls, name_or_schema) -> Optional[type[Extension]]:
196+
"""Return an extension by its name or schema."""
197+
for schema, extension in cls.get_extension_models().items():
198+
if schema == name_or_schema or extension.__name__ == name_or_schema:
199+
return extension
200+
return None
201+
194202
@staticmethod
195203
def get_by_schema(
196204
resource_types: list[type[BaseModel]], schema: str, with_extensions=True

tests/test_resource_extension.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def test_invalid_setitem():
201201

202202

203203
class SuperHero(Extension):
204-
schemas: list[str] = ["urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"]
204+
schemas: list[str] = ["example:extensions:SuperHero"]
205205

206206
superpower: Optional[str] = None
207207
"""The superhero superpower."""
@@ -217,8 +217,9 @@ def test_multiple_extensions_union():
217217
"schemas": [
218218
"urn:ietf:params:scim:schemas:core:2.0:User",
219219
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
220+
"example:extensions:SuperHero",
220221
],
221-
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User": {
222+
"example:extensions:SuperHero": {
222223
"superpower": "flight",
223224
},
224225
}
@@ -268,3 +269,39 @@ def test_validate_items_without_extension():
268269
User[EnterpriseUser].model_validate(
269270
payload, scim_ctx=Context.RESOURCE_CREATION_RESPONSE
270271
)
272+
273+
274+
def test_get_extension_model():
275+
assert User[EnterpriseUser].get_extension_model("EnterpriseUser") == EnterpriseUser
276+
assert (
277+
User[EnterpriseUser].get_extension_model(
278+
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
279+
)
280+
== EnterpriseUser
281+
)
282+
283+
assert (
284+
User[Union[EnterpriseUser, SuperHero]].get_extension_model("EnterpriseUser")
285+
== EnterpriseUser
286+
)
287+
assert (
288+
User[Union[EnterpriseUser, SuperHero]].get_extension_model(
289+
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
290+
)
291+
== EnterpriseUser
292+
)
293+
294+
assert User[SuperHero].get_extension_model("EnterpriseUser") is None
295+
assert (
296+
User[SuperHero].get_extension_model(
297+
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
298+
)
299+
is None
300+
)
301+
assert User.get_extension_model("EnterpriseUser") is None
302+
assert (
303+
User.get_extension_model(
304+
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
305+
)
306+
is None
307+
)

0 commit comments

Comments
 (0)