diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java index 1a54d272..9c829316 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java @@ -57,6 +57,10 @@ public List getClientPlugins(GenerationContext context) { .build()) // TODO: Initialize with the provider chain? .nullable(true) + .initialize(writer -> { + writer.addImport("smithy_aws_core.credentials_resolvers", "CredentialsResolverChain"); + writer.write("self.aws_credentials_identity_resolver = aws_credentials_identity_resolver or CredentialsResolverChain(config=self)"); + }) .build()) .addConfigProperty(REGION) .authScheme(new Sigv4AuthScheme()) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py index 9d4ace5e..f2b7ffd1 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py @@ -1,10 +1,12 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from .chain import CredentialsResolverChain from .environment import EnvironmentCredentialsResolver from .imds import IMDSCredentialsResolver from .static import StaticCredentialsResolver __all__ = ( + "CredentialsResolverChain", "EnvironmentCredentialsResolver", "IMDSCredentialsResolver", "StaticCredentialsResolver", diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/chain.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/chain.py new file mode 100644 index 00000000..fc9b8ab7 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/chain.py @@ -0,0 +1,57 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence + +from smithy_core.aio.interfaces.identity import IdentityResolver +from smithy_core.exceptions import SmithyIdentityException +from smithy_core.interfaces.identity import IdentityProperties + +from smithy_aws_core.credentials_resolvers.environment import ( + EnvironmentCredentialsSource, +) +from smithy_aws_core.credentials_resolvers.imds import IMDSCredentialsSource +from smithy_aws_core.credentials_resolvers.interfaces import ( + AwsCredentialsConfig, + CredentialsSource, +) +from smithy_aws_core.identity import AWSCredentialsIdentity, AWSCredentialsResolver + +_DEFAULT_SOURCES: Sequence[CredentialsSource] = ( + EnvironmentCredentialsSource(), + IMDSCredentialsSource(), +) + + +class CredentialsResolverChain( + IdentityResolver[AWSCredentialsIdentity, IdentityProperties] +): + """Resolves AWS Credentials from an ordered list of credentials sources.""" + + def __init__( + self, + *, + config: AwsCredentialsConfig, + sources: Sequence[CredentialsSource] = _DEFAULT_SOURCES, + ): + self._config = config + self._sources: Sequence[CredentialsSource] = sources + self._credentials_resolver: AWSCredentialsResolver | None = None + + async def get_identity( + self, *, identity_properties: IdentityProperties + ) -> AWSCredentialsIdentity: + if self._credentials_resolver is not None: + return await self._credentials_resolver.get_identity( + identity_properties=identity_properties + ) + + for source in self._sources: + if source.is_available(config=self._config): + self._credentials_resolver = source.build_resolver(config=self._config) + return await self._credentials_resolver.get_identity( + identity_properties=identity_properties + ) + + raise SmithyIdentityException( + "None of the configured credentials sources were able to resolve credentials." + ) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/environment.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/environment.py index 34cea57a..27e0f803 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/environment.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/environment.py @@ -6,7 +6,12 @@ from smithy_core.exceptions import SmithyIdentityException from smithy_core.interfaces.identity import IdentityProperties -from ..identity import AWSCredentialsIdentity +from smithy_aws_core.credentials_resolvers.interfaces import ( + AwsCredentialsConfig, + CredentialsSource, +) + +from ..identity import AWSCredentialsIdentity, AWSCredentialsResolver class EnvironmentCredentialsResolver( @@ -41,3 +46,13 @@ async def get_identity( ) return self._credentials + + +class EnvironmentCredentialsSource(CredentialsSource): + def is_available(self, config: AwsCredentialsConfig) -> bool: + return ( + "AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ + ) + + def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver: + return EnvironmentCredentialsResolver() diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py index 6ae6fee0..619cca47 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py @@ -17,8 +17,13 @@ from smithy_http.aio import HTTPRequest from smithy_http.aio.interfaces import HTTPClient +from smithy_aws_core.credentials_resolvers.interfaces import ( + AwsCredentialsConfig, + CredentialsSource, +) + from .. import __version__ -from ..identity import AWSCredentialsIdentity +from ..identity import AWSCredentialsIdentity, AWSCredentialsResolver _USER_AGENT_FIELD = Field( name="User-Agent", @@ -235,3 +240,14 @@ async def get_identity( account_id=account_id, ) return self._credentials + + +class IMDSCredentialsSource(CredentialsSource): + def is_available(self, config: AwsCredentialsConfig) -> bool: + # IMDS credentials should always be the last in the chain + # We cannot check if they're available without actually making a call + return True + + def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver: + # TODO: Configure lower number of retries/lower timeout + return IMDSCredentialsResolver(http_client=config.http_client) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/interfaces.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/interfaces.py new file mode 100644 index 00000000..31505072 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/interfaces.py @@ -0,0 +1,23 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Protocol + +from smithy_http.aio.interfaces import HTTPClient + +from smithy_aws_core.identity import AWSCredentialsResolver + + +class AwsCredentialsConfig(Protocol): + """Configuration required for resolving credentials.""" + + http_client: HTTPClient + + +class CredentialsSource(Protocol): + def is_available(self, config: AwsCredentialsConfig) -> bool: + """Returns True if credentials are available from this source.""" + ... + + def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver: + """Builds a credentials resolver for the given configuration.""" + ... diff --git a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_credentials_resolver_chain.py b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_credentials_resolver_chain.py new file mode 100644 index 00000000..e6adbee2 --- /dev/null +++ b/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_credentials_resolver_chain.py @@ -0,0 +1,181 @@ +from dataclasses import dataclass +from unittest.mock import Mock + +import pytest +from smithy_aws_core.credentials_resolvers import ( + CredentialsResolverChain, + IMDSCredentialsResolver, + StaticCredentialsResolver, +) +from smithy_aws_core.credentials_resolvers.environment import ( + EnvironmentCredentialsSource, +) +from smithy_aws_core.credentials_resolvers.interfaces import ( + AwsCredentialsConfig, + CredentialsSource, +) +from smithy_aws_core.identity import AWSCredentialsIdentity, AWSCredentialsResolver +from smithy_core.exceptions import SmithyIdentityException +from smithy_core.interfaces.identity import IdentityProperties +from smithy_http.aio.interfaces import HTTPClient + + +@dataclass +class Config: + http_client: HTTPClient + + def __init__(self): + self.http_client = Mock(spec=HTTPClient) # type: ignore + + +async def test_no_sources_resolve(): + resolver_chain = CredentialsResolverChain(sources=[], config=Config()) + with pytest.raises(SmithyIdentityException): + await resolver_chain.get_identity(identity_properties=IdentityProperties()) + + +async def test_env_credentials_resolver_not_set(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + resolver_chain = CredentialsResolverChain( + sources=[EnvironmentCredentialsSource()], config=Config() + ) + + with pytest.raises(SmithyIdentityException): + await resolver_chain.get_identity(identity_properties=IdentityProperties()) + + +async def test_env_credentials_resolver_partial(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid") + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + resolver_chain = CredentialsResolverChain( + sources=[EnvironmentCredentialsSource()], config=Config() + ) + + with pytest.raises(SmithyIdentityException): + await resolver_chain.get_identity(identity_properties=IdentityProperties()) + + +async def test_default_sources_env_credentials_resolver_success( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret") + resolver_chain = CredentialsResolverChain(config=Config()) + + credentials = await resolver_chain.get_identity( + identity_properties=IdentityProperties() + ) + assert credentials.access_key_id == "akid" + assert credentials.secret_access_key == "secret" + + +async def test_default_sources_imds_resolver_success(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + + async def mock_imds_get_identity( + self: IMDSCredentialsResolver, *, identity_properties: IdentityProperties + ) -> AWSCredentialsIdentity: + return AWSCredentialsIdentity( + access_key_id="akid", + secret_access_key="secret", + ) + + monkeypatch.setattr( + "smithy_aws_core.credentials_resolvers.IMDSCredentialsResolver.get_identity", + mock_imds_get_identity, + ) + + resolver_chain = CredentialsResolverChain(config=Config()) + + credentials = await resolver_chain.get_identity( + identity_properties=IdentityProperties() + ) + assert credentials.access_key_id == "akid" + assert credentials.secret_access_key == "secret" + + +async def test_multiple_sources_one_valid(): + class FailingSource(CredentialsSource): + def is_available(self, config: AwsCredentialsConfig) -> bool: + return False + + def build_resolver( + self, config: AwsCredentialsConfig + ) -> AWSCredentialsResolver: + raise RuntimeError("Should not be called") + + static_credentials = AWSCredentialsIdentity( + access_key_id="valid_akid", secret_access_key="valid_secret" + ) + static_resolver = StaticCredentialsResolver(credentials=static_credentials) + + class ValidSource(CredentialsSource): + def is_available(self, config: AwsCredentialsConfig) -> bool: + return True + + def build_resolver( + self, config: AwsCredentialsConfig + ) -> AWSCredentialsResolver: + return static_resolver + + resolver_chain = CredentialsResolverChain( + sources=[FailingSource(), ValidSource()], config=Config() + ) + + credentials = await resolver_chain.get_identity( + identity_properties=IdentityProperties() + ) + assert credentials.access_key_id == "valid_akid" + assert credentials.secret_access_key == "valid_secret" + + +async def test_cached_resolver_used(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "cached_akid") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "cached_secret") + resolver_chain = CredentialsResolverChain( + sources=[EnvironmentCredentialsSource()], config=Config() + ) + + credentials1 = await resolver_chain.get_identity( + identity_properties=IdentityProperties() + ) + credentials2 = await resolver_chain.get_identity( + identity_properties=IdentityProperties() + ) + + assert credentials1.access_key_id == credentials2.access_key_id == "cached_akid" + assert ( + credentials1.secret_access_key + == credentials2.secret_access_key + == "cached_secret" + ) + + +async def test_custom_sources_with_static_credentials(): + static_credentials = AWSCredentialsIdentity( + access_key_id="static_akid", + secret_access_key="static_secret", + ) + static_resolver = StaticCredentialsResolver(credentials=static_credentials) + + class TestStaticSource(CredentialsSource): + def is_available(self, config: AwsCredentialsConfig) -> bool: + return True + + def build_resolver( + self, config: AwsCredentialsConfig + ) -> AWSCredentialsResolver: + return static_resolver + + resolver_chain = CredentialsResolverChain( + sources=[TestStaticSource()], + config=Config(), # type: ignore + ) + + credentials = await resolver_chain.get_identity( + identity_properties=IdentityProperties() + ) + assert credentials.access_key_id == "static_akid" + assert credentials.secret_access_key == "static_secret" diff --git a/uv.lock b/uv.lock index d716ebc2..6856c99e 100644 --- a/uv.lock +++ b/uv.lock @@ -686,6 +686,7 @@ dependencies = [ [package.optional-dependencies] aiohttp = [ { name = "aiohttp" }, + { name = "yarl" }, ] awscrt = [ { name = "awscrt" }, @@ -693,9 +694,10 @@ awscrt = [ [package.metadata] requires-dist = [ - { name = "aiohttp", marker = "extra == 'aiohttp'", specifier = ">=3.11.12" }, + { name = "aiohttp", marker = "extra == 'aiohttp'", specifier = ">=3.11.12,<4.0" }, { name = "awscrt", marker = "extra == 'awscrt'", specifier = ">=0.23.10" }, { name = "smithy-core", editable = "packages/smithy-core" }, + { name = "yarl", marker = "extra == 'aiohttp'" }, ] provides-extras = ["awscrt", "aiohttp"]