Skip to content

Commit 73130ce

Browse files
committed
Add AwsSystemsManagerParameterStoreSettingsSource
1 parent 818d56e commit 73130ce

File tree

3 files changed

+307
-0
lines changed

3 files changed

+307
-0
lines changed

docs/index.md

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,63 @@ class AzureKeyVaultSettings(BaseSettings):
13061306
)
13071307
```
13081308

1309+
## AWS Systems Manager Parameter Store
1310+
1311+
You must set the following parameters:
1312+
1313+
- `ssm_client`: An initialized [`boto3` SSM Client](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm.html#client).
1314+
1315+
Optionally, you may specify the following parameters:
1316+
1317+
- `ssm_path`: The hierarchy for the parameter. Hierarchies start with a forward slash (/). The hierarchy is the parameter name except the last part of the parameter. Under the hood, we make use of the [`get_parameters_by_path` method](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm/client/get_parameters_by_path.html) to recursively retrieve all parameters within the a specified path hierarchy.
1318+
1319+
```py
1320+
import os
1321+
from typing import Tuple, Type
1322+
1323+
import boto3
1324+
from pydantic import BaseModel
1325+
1326+
from pydantic_settings import (
1327+
AwsSystemsManagerParameterStoreSettingsSource,
1328+
BaseSettings,
1329+
PydanticBaseSettingsSource,
1330+
)
1331+
1332+
1333+
class SubModel(BaseModel):
1334+
a: str
1335+
1336+
1337+
class AzureKeyVaultSettings(BaseSettings):
1338+
foo: str
1339+
bar: int
1340+
sub: SubModel
1341+
1342+
@classmethod
1343+
def settings_customise_sources(
1344+
cls,
1345+
settings_cls: Type[BaseSettings],
1346+
init_settings: PydanticBaseSettingsSource,
1347+
env_settings: PydanticBaseSettingsSource,
1348+
dotenv_settings: PydanticBaseSettingsSource,
1349+
file_secret_settings: PydanticBaseSettingsSource,
1350+
) -> Tuple[PydanticBaseSettingsSource, ...]:
1351+
client = boto3.client('ssm')
1352+
ssm_param_store_settings = AwsSystemsManagerParameterStoreSettingsSource(
1353+
settings_cls,
1354+
ssm_client=client,
1355+
ssm_path=os.environ.get('SSM_PREFIX', '/api/dev/'),
1356+
)
1357+
return (
1358+
init_settings,
1359+
env_settings,
1360+
dotenv_settings,
1361+
file_secret_settings,
1362+
ssm_param_store_settings,
1363+
)
1364+
``` -->
1365+
13091366
## Other settings source
13101367

13111368
Other settings sources are available for common configuration files:

pydantic_settings/sources.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2014,6 +2014,62 @@ def __repr__(self) -> str:
20142014
return f'AzureKeyVaultSettingsSource(url={self._url!r}, ' f'env_nested_delimiter={self.env_nested_delimiter!r})'
20152015

20162016

2017+
class AwsSystemsManagerParameterStoreSettingsSource(EnvSettingsSource):
2018+
_ssm_client: "SSMClient" # type: ignore
2019+
_ssm_path: str
2020+
2021+
def __init__(
2022+
self,
2023+
settings_cls: type[BaseSettings],
2024+
ssm_client: "SSMClient", # type: ignore
2025+
ssm_path: str = "/",
2026+
case_sensitive: bool | None = None,
2027+
env_prefix: str | None = None,
2028+
env_nested_delimiter: str = "/",
2029+
env_ignore_empty: bool | None = None,
2030+
env_parse_none_str: str | None = None,
2031+
env_parse_enums: bool | None = None,
2032+
) -> None:
2033+
self._ssm_client = ssm_client
2034+
self._ssm_path = ssm_path
2035+
super().__init__(
2036+
settings_cls,
2037+
case_sensitive,
2038+
env_prefix,
2039+
env_nested_delimiter,
2040+
env_ignore_empty,
2041+
env_parse_none_str,
2042+
env_parse_enums,
2043+
)
2044+
2045+
def _load_env_vars(self) -> Mapping[str, Optional[str]]:
2046+
paginator = self._ssm_client.get_paginator("get_parameters_by_path")
2047+
response_iterator = paginator.paginate(
2048+
Path=self._ssm_path, WithDecryption=True, Recursive=True
2049+
)
2050+
2051+
output = {}
2052+
try:
2053+
for page in response_iterator:
2054+
for parameter in page["Parameters"]:
2055+
name = Path(parameter["Name"])
2056+
key = name.relative_to(self._ssm_path).as_posix()
2057+
2058+
if not self.case_sensitive:
2059+
first_key, *rest = key.split(self.env_nested_delimiter)
2060+
key = self.env_nested_delimiter.join([first_key.lower(), *rest])
2061+
2062+
output[key] = parameter["Value"]
2063+
2064+
except self._ssm_client.exceptions.ClientError as e:
2065+
warnings.warn(f"Unable to get parameters from {self._ssm_path!r}: {e}")
2066+
2067+
return output
2068+
2069+
def __repr__(self) -> str:
2070+
return f"AwsSystemsManagerParameterStoreSettingsSource(ssm_path={self._ssm_path!r})"
2071+
2072+
20172073
def _get_env_var_key(key: str, case_sensitive: bool = False) -> str:
20182074
return key if case_sensitive else key.lower()
20192075

tests/test_sources.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from pydantic_settings.main import BaseSettings, SettingsConfigDict
1212
from pydantic_settings.sources import (
13+
AwsSystemsManagerParameterStoreSettingsSource,
1314
AzureKeyVaultSettingsSource,
1415
PydanticBaseSettingsSource,
1516
PyprojectTomlConfigSettingsSource,
@@ -31,6 +32,12 @@
3132
except ImportError:
3233
azure_key_vault = False
3334

35+
try:
36+
aws = True
37+
import boto3
38+
except ImportError:
39+
aws = False
40+
3441
if TYPE_CHECKING:
3542
from pathlib import Path
3643

@@ -210,3 +217,190 @@ def _raise_resource_not_found_when_getting_parent_secret_name(self, secret_name:
210217
raise ResourceNotFoundError()
211218

212219
return key_vault_secret
220+
221+
222+
@pytest.mark.skipif(not aws, reason="boto3 is not installed")
223+
class TestAwsSystemsManagerParameterStoreSettingsSource:
224+
"""Test AwsSystemsManagerParameterStoreSettingsSource."""
225+
226+
def test___init__(self, mocker: MockerFixture) -> None:
227+
"""Test __init__."""
228+
229+
class AwsSettings(BaseSettings):
230+
"""AWS settings."""
231+
232+
mock_parameters = []
233+
paginator_mock = mocker.Mock()
234+
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]
235+
236+
client_mock = mocker.Mock()
237+
client_mock.get_paginator.return_value = paginator_mock
238+
client_mock.exceptions.ClientError = Exception
239+
240+
AwsSystemsManagerParameterStoreSettingsSource(
241+
settings_cls=AwsSettings, ssm_client=client_mock, ssm_path='/my/path'
242+
)
243+
244+
def test___call__case_sensitive(self, mocker: MockerFixture) -> None:
245+
"""Test __call__."""
246+
247+
class SqlServer(BaseModel):
248+
password: str = Field(..., alias='Password')
249+
250+
class AwsSettings(BaseSettings):
251+
"""AWS settings."""
252+
253+
SqlServerUser: str
254+
sql_server_user: str = Field(..., alias='SqlServerUser')
255+
sql_server: SqlServer = Field(..., alias='SqlServer')
256+
257+
mock_parameters = [
258+
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
259+
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
260+
]
261+
paginator_mock = mocker.Mock()
262+
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]
263+
264+
client_mock = mocker.Mock()
265+
client_mock.get_paginator.return_value = paginator_mock
266+
client_mock.exceptions.ClientError = Exception
267+
268+
obj = AwsSystemsManagerParameterStoreSettingsSource(
269+
settings_cls=AwsSettings,
270+
ssm_client=client_mock,
271+
ssm_path='/my/path',
272+
case_sensitive=True,
273+
)
274+
275+
settings = obj()
276+
277+
assert settings['SqlServerUser'] == 'SecretValue'
278+
assert settings['SqlServer']['Password'] == 'SecretValue'
279+
280+
def test___call__case_insensitive(self, mocker: MockerFixture) -> None:
281+
"""Test __call__."""
282+
283+
class SqlServer(BaseModel):
284+
password: str = Field(..., alias='Password')
285+
286+
class AwsSettings(BaseSettings):
287+
"""AWS settings."""
288+
289+
SqlServerUser: str
290+
sql_server_user: str = Field(..., alias='SqlServerUser')
291+
sql_server: SqlServer = Field(..., alias='SqlServer')
292+
293+
mock_parameters = [
294+
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
295+
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
296+
]
297+
paginator_mock = mocker.Mock()
298+
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]
299+
300+
client_mock = mocker.Mock()
301+
client_mock.get_paginator.return_value = paginator_mock
302+
client_mock.exceptions.ClientError = Exception
303+
304+
obj = AwsSystemsManagerParameterStoreSettingsSource(
305+
settings_cls=AwsSettings,
306+
ssm_client=client_mock,
307+
ssm_path='/my/path',
308+
case_sensitive=False,
309+
)
310+
settings = obj()
311+
312+
assert settings['SqlServerUser'] == 'SecretValue'
313+
assert settings['SqlServer']['Password'] == 'SecretValue'
314+
315+
def test_aws_ssm_settings_source(self, mocker: MockerFixture) -> None:
316+
"""Test AwsSystemsManagerParameterStoreSettingsSource."""
317+
mock_parameters = [
318+
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
319+
{'Name': '/my/path/SqlServer/Password', 'Value': 'SecretValue'},
320+
]
321+
paginator_mock = mocker.Mock()
322+
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]
323+
324+
client_mock = mocker.Mock()
325+
client_mock.get_paginator.return_value = paginator_mock
326+
client_mock.exceptions.ClientError = Exception
327+
328+
class SqlServer(BaseModel):
329+
password: str = Field(..., alias='Password')
330+
331+
class AwsSettings(BaseSettings):
332+
"""AWS settings."""
333+
334+
SqlServerUser: str
335+
sql_server_user: str = Field(..., alias='SqlServerUser')
336+
sql_server: SqlServer = Field(..., alias='SqlServer')
337+
338+
@classmethod
339+
def settings_customise_sources(
340+
cls,
341+
settings_cls: type[BaseSettings],
342+
init_settings: PydanticBaseSettingsSource,
343+
env_settings: PydanticBaseSettingsSource,
344+
dotenv_settings: PydanticBaseSettingsSource,
345+
file_secret_settings: PydanticBaseSettingsSource,
346+
) -> tuple[PydanticBaseSettingsSource, ...]:
347+
return (
348+
AwsSystemsManagerParameterStoreSettingsSource(
349+
settings_cls=AwsSettings,
350+
ssm_client=client_mock,
351+
ssm_path='/my/path',
352+
),
353+
)
354+
355+
settings = AwsSettings() # type: ignore
356+
357+
assert settings.SqlServerUser == 'SecretValue'
358+
assert settings.sql_server_user == 'SecretValue'
359+
assert settings.sql_server.password == 'SecretValue'
360+
361+
def test_aws_ssm_settings_source__delimiter(self, mocker: MockerFixture) -> None:
362+
"""Test AwsSystemsManagerParameterStoreSettingsSource."""
363+
mock_parameters = [
364+
{'Name': '/my/path/SqlServerUser', 'Value': 'SecretValue'},
365+
{'Name': '/my/path/SqlServer__Password', 'Value': 'SecretValue'},
366+
]
367+
paginator_mock = mocker.Mock()
368+
paginator_mock.paginate.return_value = [{'Parameters': mock_parameters}]
369+
370+
client_mock = mocker.Mock()
371+
client_mock.get_paginator.return_value = paginator_mock
372+
client_mock.exceptions.ClientError = Exception
373+
374+
class SqlServer(BaseModel):
375+
password: str = Field(..., alias='Password')
376+
377+
class AwsSettings(BaseSettings):
378+
"""AWS settings."""
379+
380+
SqlServerUser: str
381+
sql_server_user: str = Field(..., alias='SqlServerUser')
382+
sql_server: SqlServer = Field(..., alias='SqlServer')
383+
384+
@classmethod
385+
def settings_customise_sources(
386+
cls,
387+
settings_cls: type[BaseSettings],
388+
init_settings: PydanticBaseSettingsSource,
389+
env_settings: PydanticBaseSettingsSource,
390+
dotenv_settings: PydanticBaseSettingsSource,
391+
file_secret_settings: PydanticBaseSettingsSource,
392+
) -> tuple[PydanticBaseSettingsSource, ...]:
393+
return (
394+
AwsSystemsManagerParameterStoreSettingsSource(
395+
settings_cls=AwsSettings,
396+
ssm_client=client_mock,
397+
ssm_path='/my/path',
398+
env_nested_delimiter='__',
399+
),
400+
)
401+
402+
settings = AwsSettings() # type: ignore
403+
404+
assert settings.SqlServerUser == 'SecretValue'
405+
assert settings.sql_server_user == 'SecretValue'
406+
assert settings.sql_server.password == 'SecretValue'

0 commit comments

Comments
 (0)