diff --git a/doc/changelog.rst b/doc/changelog.rst index 1286e9b..f3ebfc4 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -4,6 +4,10 @@ Changelog [0.3.0] - Unreleased -------------------- +Added +^^^^^ +- :meth:`Attribute.get_attribute ` can be called with brackets. + Changed ^^^^^^^ - Add a :paramref:`~scim2_models.BaseModel.model_validate.original` diff --git a/scim2_models/rfc7643/schema.py b/scim2_models/rfc7643/schema.py index 2390f5e..8e49276 100644 --- a/scim2_models/rfc7643/schema.py +++ b/scim2_models/rfc7643/schema.py @@ -245,6 +245,12 @@ def get_attribute(self, attribute_name: str) -> Optional["Attribute"]: return sub_attribute return None + def __getitem__(self, name): + """Find an attribute by its name.""" + if attribute := self.get_attribute(name): + return attribute + raise KeyError(f"This attribute has no '{name}' sub-attribute") + class Schema(Resource): schemas: Annotated[list[str], Required.true] = [ @@ -280,3 +286,9 @@ def get_attribute(self, attribute_name: str) -> Optional[Attribute]: if attribute.name == attribute_name: return attribute return None + + def __getitem__(self, name): + """Find an attribute by its name.""" + if attribute := self.get_attribute(name): + return attribute + raise KeyError(f"This schema has no '{name}' attribute") diff --git a/tests/test_schema.py b/tests/test_schema.py index 8d2bba9..6721940 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -92,13 +92,18 @@ def test_get_schema_attribute(load_sample): payload = load_sample("rfc7643-8.7.1-schema-user.json") schema = Schema.model_validate(payload) assert schema.get_attribute("invalid") is None + with pytest.raises(KeyError): + schema["invalid"] assert schema.attributes[0].name == "userName" assert schema.attributes[0].mutability == Mutability.read_write - schema.get_attribute("userName").mutability = Mutability.read_only + schema.get_attribute("userName").mutability = Mutability.read_only assert schema.attributes[0].mutability == Mutability.read_only + schema["userName"].mutability = Mutability.read_write + assert schema.attributes[0].mutability == Mutability.read_write + def test_get_attribute_attribute(load_sample): """Test the Schema.get_attribute method.""" @@ -107,9 +112,14 @@ def test_get_attribute_attribute(load_sample): attribute = schema.get_attribute("members") assert attribute.get_attribute("invalid") is None + with pytest.raises(KeyError): + attribute["invalid"] assert attribute.sub_attributes[0].name == "value" assert attribute.sub_attributes[0].mutability == Mutability.immutable - attribute.get_attribute("value").mutability = Mutability.read_only + attribute.get_attribute("value").mutability = Mutability.read_only assert attribute.sub_attributes[0].mutability == Mutability.read_only + + attribute["value"].mutability = Mutability.read_write + assert attribute.sub_attributes[0].mutability == Mutability.read_write