diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 441d14f95..455c88a29 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -54,6 +54,7 @@ jobs: fi python -m pip install -e . - name: Install test dependencies + # ToDo - remove pip install xgboost when flaml is fixed run: | if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt @@ -63,6 +64,7 @@ jobs: python -m pip install Pygments respx pytest-xdist markdown beautifulsoup4 Pillow async-cache lxml fi python -m pip install "pandas>=1.3.0" "pygeohash>=1.2.0" + python -m pip install "xgboost" - name: Prepare test dummy data run: | mkdir ~/.msticpy diff --git a/.pylintrc b/.pylintrc index 6ac718a29..ffd2524cd 100644 --- a/.pylintrc +++ b/.pylintrc @@ -68,6 +68,7 @@ disable=raw-checker-failed, useless-suppression, deprecated-pragma, use-symbolic-message-instead, + too-many-positional-arguments, # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/docs/source/conf.py b/docs/source/conf.py index 6c06a9620..119faa093 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -228,6 +228,8 @@ "azure.keyvault.secrets", "azure.keyvault", "azure.kusto.data", + "azure.kusto.data.helpers", + "azure.kusto.data.response", "azure.mgmt.compute.models", "azure.mgmt.compute", "azure.mgmt.keyvault.models", @@ -235,8 +237,11 @@ "azure.mgmt.monitor", "azure.mgmt.network", "azure.mgmt.resource", + "azure.mgmt.resource.subscriptions", "azure.mgmt.resourcegraph", + "azure.mgmt.resourcegraph.models", "azure.mgmt.subscription", + "azure.mgmt.subscription.models", "azure.monitor.query", "azure.storage.blob", "azure.storage", @@ -250,11 +255,12 @@ "ipwhois", "IPython", "ipywidgets", + "jwt", "keyring", "Kqlmagic", "matplotlib.pyplot", "matplotlib", - "mo-sql-parsing", + "mo_sql_parsing", "msal", "msal_extensions", "msrest", diff --git a/msticpy/_version.py b/msticpy/_version.py index fb0fc2d93..bf6f6b7db 100644 --- a/msticpy/_version.py +++ b/msticpy/_version.py @@ -1,3 +1,3 @@ """Version file.""" -VERSION = "2.13.1" +VERSION = "2.13.2" diff --git a/msticpy/auth/azure_auth.py b/msticpy/auth/azure_auth.py index bf24492b3..0c59ec054 100644 --- a/msticpy/auth/azure_auth.py +++ b/msticpy/auth/azure_auth.py @@ -4,14 +4,16 @@ # license information. # -------------------------------------------------------------------------- """Azure authentication handling.""" +from __future__ import annotations import os -from typing import List, Optional -from azure.common.exceptions import CloudError from azure.identity import DeviceCodeCredential from azure.mgmt.subscription import SubscriptionClient +from msticpy.common.provider_settings import ProviderSettings + from .._version import VERSION +from ..common.exceptions import MsticpyAzureConnectionError # pylint: enable=unused-import from ..common.provider_settings import get_provider_settings @@ -31,9 +33,11 @@ def az_connect( - auth_methods: Optional[List[str]] = None, - tenant_id: Optional[str] = None, + auth_methods: list[str] | None = None, + tenant_id: str | None = None, + *, silent: bool = False, + cloud: str | None = None, **kwargs, ) -> AzCredentials: """ @@ -68,7 +72,7 @@ def az_connect( Raises ------ - CloudError + MsticpyAzureConnectionError If chained token credential creation fails. See Also @@ -76,14 +80,16 @@ def az_connect( list_auth_methods """ - az_cloud_config = AzureCloudConfig(cloud=kwargs.get("cloud")) + az_cloud_config = AzureCloudConfig(cloud=cloud) # Use auth_methods param or configuration defaults - data_provs = get_provider_settings(config_section="DataProviders") + data_provs: dict[str, ProviderSettings] = get_provider_settings( + config_section="DataProviders" + ) auth_methods = auth_methods or az_cloud_config.auth_methods tenant_id = tenant_id or az_cloud_config.tenant_id # Ignore AzCLI settings except for authentication creds for EnvCred - az_cli_config = data_provs.get("AzureCLI") + az_cli_config: ProviderSettings | None = data_provs.get("AzureCLI") if ( az_cli_config and az_cli_config.args @@ -99,10 +105,11 @@ def az_connect( os.environ[AzureCredEnvNames.AZURE_CLIENT_SECRET] = ( az_cli_config.args.get("clientSecret") or "" ) - credentials = az_connect_core( + credentials: AzCredentials = az_connect_core( auth_methods=auth_methods, tenant_id=tenant_id, silent=silent, + cloud=cloud, **kwargs, ) sub_client = SubscriptionClient( @@ -111,13 +118,19 @@ def az_connect( credential_scopes=[az_cloud_config.token_uri], ) if not sub_client: - raise CloudError("Could not create a Subscription client.") + err_msg: str = "Could not create an Azure Subscription client with credentials." + raise MsticpyAzureConnectionError( + err_msg, + title="Azure authentication error", + ) return credentials def az_user_connect( - tenant_id: Optional[str] = None, silent: bool = False + tenant_id: str | None = None, + *, + silent: bool = False, ) -> AzCredentials: """ Authenticate to the SDK using user based authentication methods, Azure CLI or interactive logon. @@ -132,17 +145,23 @@ def az_user_connect( Returns ------- - AzCredentials + AzCredentials - Dataclass combining two types of Azure credentials: + - legacy (ADAL) credentials + - modern (MSAL) credentials """ return az_connect_core( - auth_methods=["cli", "interactive"], tenant_id=tenant_id, silent=silent + auth_methods=["cli", "interactive"], + tenant_id=tenant_id, + silent=silent, ) def fallback_devicecode_creds( - cloud: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs -): + cloud: str | None = None, + tenant_id: str | None = None, + region: str | None = None, +) -> AzCredentials: """ Authenticate using device code as a fallback method. @@ -158,30 +177,36 @@ def fallback_devicecode_creds( Returns ------- - AzCredentials - Named tuple of: - - legacy (ADAL) credentials - - modern (MSAL) credentials + AzCredentials - Dataclass combining two types of Azure credentials: + - legacy (ADAL) credentials + - modern (MSAL) credentials Raises ------ - CloudError + MsticpyAzureConnectionError If chained token credential creation fails. """ - cloud = cloud or kwargs.pop("region", AzureCloudConfig().cloud) - az_config = AzureCloudConfig(cloud) - aad_uri = az_config.authority_uri + cloud = cloud or region or AzureCloudConfig().cloud + az_config: AzureCloudConfig = AzureCloudConfig(cloud) + aad_uri: str = az_config.authority_uri tenant_id = tenant_id or az_config.tenant_id creds = DeviceCodeCredential(authority=aad_uri, tenant_id=tenant_id) legacy_creds = CredentialWrapper(creds, resource_id=az_config.token_uri) if not creds: - raise CloudError("Could not obtain credentials.") + err_msg: str = ( + f"Could not obtain credentials for tenant {tenant_id}" + "Please check your Azure configuration and try again." + ) + raise MsticpyAzureConnectionError( + err_msg, + title="Azure authentication error", + ) return AzCredentials(legacy_creds, ChainedTokenCredential(creds)) # type: ignore[arg-type] def get_default_resource_name(resource_uri: str) -> str: """Get a default resource name for a resource URI.""" - separator = "" if resource_uri.strip().endswith("/") else "/" + separator: str = "" if resource_uri.strip().endswith("/") else "/" return f"{resource_uri}{separator}.default" diff --git a/msticpy/auth/azure_auth_core.py b/msticpy/auth/azure_auth_core.py index 8cc197895..34169b9fc 100644 --- a/msticpy/auth/azure_auth_core.py +++ b/msticpy/auth/azure_auth_core.py @@ -9,10 +9,10 @@ import logging import os import sys -from dataclasses import dataclass +from dataclasses import asdict, dataclass from datetime import datetime from enum import Enum -from typing import Callable, ClassVar +from typing import Any, Callable, ClassVar, Iterator from azure.common.credentials import get_cli_profile from azure.core.credentials import TokenCredential @@ -55,6 +55,15 @@ class AzCredentials: legacy: TokenCredential modern: ChainedTokenCredential + # Backward compatibility with namedtuple + def __iter__(self) -> Iterator[Any]: + """Iterate over properties.""" + return iter(asdict(self).values()) + + def __getitem__(self, item) -> Any: + """Get item from properties.""" + return list(asdict(self).values())[item] + # pylint: disable=too-few-public-methods class AzureCredEnvNames: @@ -135,25 +144,31 @@ def _build_env_client( return None -def _build_cli_client(**kwargs) -> AzureCliCredential: +def _build_cli_client( + tenant_id: str | None = None, + **kwargs, +) -> AzureCliCredential: """Build a credential from Azure CLI.""" del kwargs + if tenant_id: + return AzureCliCredential(tenant_id=tenant_id) return AzureCliCredential() def _build_msi_client( tenant_id: str | None = None, aad_uri: str | None = None, + client_id: str | None = None, **kwargs, ) -> ManagedIdentityCredential: """Build a credential from Managed Identity.""" - msi_kwargs = kwargs.copy() - if AzureCredEnvNames.AZURE_CLIENT_ID in os.environ: - msi_kwargs["client_id"] = os.environ[AzureCredEnvNames.AZURE_CLIENT_ID] + msi_kwargs: dict[str, Any] = kwargs.copy() + client_id = client_id or os.environ.get(AzureCredEnvNames.AZURE_CLIENT_ID) return ManagedIdentityCredential( tenant_id=tenant_id, authority=aad_uri, + client_id=client_id, **msi_kwargs, ) @@ -213,10 +228,10 @@ def _build_client_secret_client( def _build_certificate_client( tenant_id: str | None = None, aad_uri: str | None = None, + client_id: str | None = None, **kwargs, ) -> CertificateCredential | None: """Build a credential from Certificate.""" - client_id = kwargs.pop("client_id", None) if not client_id: logger.info( "'certificate' credential requested but client_id param not supplied" @@ -236,7 +251,7 @@ def _build_powershell_client(**kwargs) -> AzurePowerShellCredential: return AzurePowerShellCredential() -_CLIENTS: dict[str, Callable] = dict( +_CLIENTS: dict[str, Callable[..., TokenCredential | None]] = dict( { "env": _build_env_client, "cli": _build_cli_client, @@ -405,15 +420,15 @@ def _create_chained_credential( if not requested_clients: requested_clients = ["env", "cli", "msi", "interactive"] logger.info("No auth methods requested defaulting to: %s", requested_clients) - cred_list = [] + cred_list: list[TokenCredential] = [] invalid_cred_types: list[str] = [] unusable_cred_type: list[str] = [] - for cred_type in requested_clients: # type: ignore[union-attr] + for cred_type in requested_clients: if cred_type not in _CLIENTS: invalid_cred_types.append(cred_type) logger.info("Unknown authentication type requested: %s", cred_type) continue - cred_client = _CLIENTS[cred_type]( + cred_client: TokenCredential | None = _CLIENTS[cred_type]( tenant_id=tenant_id, aad_uri=aad_uri, **kwargs, @@ -427,7 +442,7 @@ def _create_chained_credential( ", ".join(cred.__class__.__name__ for cred in cred_list if cred is not None), ) if not cred_list: - exception_args = [ + exception_args: list[str] = [ "Cannot authenticate - no valid credential types.", "At least one valid authentication method required.", f"Configured auth_types: {','.join(requested_clients)}", diff --git a/msticpy/context/ip_utils.py b/msticpy/context/ip_utils.py index cc821d9b6..0f8796a12 100644 --- a/msticpy/context/ip_utils.py +++ b/msticpy/context/ip_utils.py @@ -19,10 +19,10 @@ import re import socket import warnings -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from functools import lru_cache from time import sleep -from typing import Any, Callable +from typing import Any, Callable, Iterator import httpx import pandas as pd @@ -604,6 +604,15 @@ class _IpWhoIsResult: name: str | None = None properties: dict[str, Any] = field(default_factory=dict) + # Backward compatibility with namedtuple + def __iter__(self) -> Iterator[Any]: + """Iterate over properties.""" + return iter(asdict(self).values()) + + def __getitem__(self, item): + """Get item from properties.""" + return list(asdict(self).values())[item] + @lru_cache(maxsize=1024) def _whois_lookup( diff --git a/msticpy/context/lookup.py b/msticpy/context/lookup.py index ba21f2260..33935b348 100644 --- a/msticpy/context/lookup.py +++ b/msticpy/context/lookup.py @@ -19,15 +19,7 @@ import logging import warnings from collections import ChainMap -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Iterable, - Mapping, - Sized, -) +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Mapping, Sized import nest_asyncio import pandas as pd @@ -59,6 +51,8 @@ logger: logging.Logger = logging.getLogger(__name__) +_HTTP_PROVIDER_LEGAL_KWARGS: list[str] = ["timeout", "ApiID", "AuthKey", "Instance"] + class ProgressCounter: """Progress counter for async tasks.""" @@ -811,7 +805,13 @@ def _load_providers( # instantiate class sending args from settings to init try: - provider_instance: Provider = provider_class(**(settings.args)) + # filter out any args that are not valid for the provider + provider_args = { + key: value + for key, value in settings.args.items() + if key in _HTTP_PROVIDER_LEGAL_KWARGS + } + provider_instance: Provider = provider_class(**(provider_args)) except MsticpyConfigError as mp_ex: # If the TI Provider didn't load, raise an exception err_msg: str = ( diff --git a/msticpy/context/provider_base.py b/msticpy/context/provider_base.py index 2ff1ef036..b7348849e 100644 --- a/msticpy/context/provider_base.py +++ b/msticpy/context/provider_base.py @@ -211,6 +211,8 @@ def lookup_items( ) results.append(item_result) + if not results: + return pd.DataFrame() return pd.concat(results) async def lookup_items_async( # noqa:PLR0913 diff --git a/msticpy/context/tilookup.py b/msticpy/context/tilookup.py index 4a85374d2..809e7b886 100644 --- a/msticpy/context/tilookup.py +++ b/msticpy/context/tilookup.py @@ -145,6 +145,8 @@ def lookup_iocs( # pylint: disable=too-many-arguments #noqa: PLR0913 *, start: dt.datetime | None = None, end: dt.datetime | None = None, + col: str | None = None, + column: str | None = None, ) -> pd.DataFrame: """ Lookup Threat Intelligence reports for a collection of IoCs in active providers. @@ -200,7 +202,7 @@ def lookup_iocs( # pylint: disable=too-many-arguments #noqa: PLR0913 return _make_sync( self._lookup_iocs_async( data=data, - ioc_col=ioc_col, + ioc_col=ioc_col or column or col, ioc_type_col=ioc_type_col, ioc_query_type=ioc_query_type, providers=providers, diff --git a/tests/auth/test_azure_auth.py b/tests/auth/test_azure_auth.py new file mode 100644 index 000000000..59c94a4b3 --- /dev/null +++ b/tests/auth/test_azure_auth.py @@ -0,0 +1,94 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from msticpy.auth.azure_auth import ( + az_connect, + az_user_connect, + fallback_devicecode_creds, + get_default_resource_name, +) +from msticpy.auth.azure_auth_core import AzCredentials + + +@pytest.fixture +def mock_az_credentials(): + return MagicMock(spec=AzCredentials) + + +@patch("msticpy.auth.azure_auth.os") +@patch("msticpy.auth.azure_auth.get_provider_settings") +@patch("msticpy.auth.azure_auth.az_connect_core") +@patch("msticpy.auth.azure_auth.SubscriptionClient") +def test_az_connect( + mock_sub_client, + mock_az_connect_core, + mock_get_provider_settings, + mock_os, + mock_az_credentials, +): + mock_az_credentials.modern = MagicMock() + mock_az_credentials.modern.__bool__.return_value = True + mock_az_credentials.legacy = MagicMock() + mock_az_connect_core.return_value = mock_az_credentials + mock_sub_client.return_value = MagicMock() + mock_os.environ = MagicMock() + az_cli_args = MagicMock() + az_cli_args.get.return_value = "test_value" + az_cli_args.__bool__.return_value = True + az_cli_config = MagicMock() + az_cli_config.__bool__.return_value = True + az_cli_config.args = az_cli_args + + data_provs = MagicMock(spec=dict) + data_provs.get.return_value = az_cli_config + mock_get_provider_settings.return_value = data_provs + + result = az_connect(auth_methods=["env"], tenant_id="test_tenant", silent=True) + + assert result == mock_az_credentials + mock_az_connect_core.assert_called_once_with( + auth_methods=["env"], + tenant_id="test_tenant", + silent=True, + cloud=None, + ) + mock_sub_client.assert_called_once() + + +@patch("msticpy.auth.azure_auth.az_connect_core") +def test_az_user_connect(mock_az_connect_core, mock_az_credentials): + mock_az_connect_core.return_value = mock_az_credentials + + result = az_user_connect(tenant_id="test_tenant", silent=True) + + assert result == mock_az_credentials + mock_az_connect_core.assert_called_once_with( + auth_methods=["cli", "interactive"], tenant_id="test_tenant", silent=True + ) + + +@patch("msticpy.auth.azure_auth.AzureCloudConfig") +@patch("msticpy.auth.azure_auth.DeviceCodeCredential") +@patch("msticpy.auth.azure_auth.CredentialWrapper") +def test_fallback_devicecode_creds( + mock_cred_wrapper, mock_device_code_cred, mock_azure_cloud_config +): + mock_azure_cloud_config.return_value = MagicMock() + mock_device_code_cred.return_value = MagicMock() + mock_cred_wrapper.return_value = MagicMock() + + result = fallback_devicecode_creds(cloud="test_cloud", tenant_id="test_tenant") + + assert isinstance(result, AzCredentials) + mock_device_code_cred.assert_called_once() + mock_cred_wrapper.assert_called_once() + + +def test_get_default_resource_name(): + resource_uri = "https://example.com/resource" + expected_result = "https://example.com/resource/.default" + + result = get_default_resource_name(resource_uri) + + assert result == expected_result diff --git a/tests/auth/test_azure_auth_core.py b/tests/auth/test_azure_auth_core.py index 313c7ea2b..c9dc363fa 100644 --- a/tests/auth/test_azure_auth_core.py +++ b/tests/auth/test_azure_auth_core.py @@ -4,19 +4,34 @@ # license information. # -------------------------------------------------------------------------- """Module docstring.""" +from __future__ import annotations +import logging +import os from datetime import datetime, timedelta from unittest.mock import MagicMock, patch import pytest import pytest_check as check +from azure.identity import ChainedTokenCredential, DeviceCodeCredential from msticpy.auth.azure_auth_core import ( AzCredentials, AzureCliStatus, AzureCloudConfig, - DeviceCodeCredential, + MsticpyAzureConfigError, _az_connect_core, + _build_certificate_client, + _build_cli_client, + _build_client_secret_client, + _build_device_code_client, _build_env_client, + _build_interactive_client, + _build_msi_client, + _build_powershell_client, + _build_vscode_client, + _create_chained_credential, + _filter_all_warnings, + _filter_credential_warning, check_cli_credentials, ) from msticpy.auth.cloud_mappings import default_auth_methods @@ -119,7 +134,7 @@ def test_check_cli_credentials(get_cli_profile, test, expected): _CLI_ID = "d8d9d2f2-5d2d-4d7e-9c5c-5d6d9d1d8d9d" _TENANT_ID = "f8d9d2f2-5d2d-4d7e-9c5c-5d6d9d1d8d9e" -_TEST_ENV_VARS = ( +_TEST_ENV_VARS: list[tuple[dict[str, str], bool]] = [ ( { "AZURE_CLIENT_ID": _CLI_ID, @@ -160,7 +175,7 @@ def test_check_cli_credentials(get_cli_profile, test, expected): False, ), ({}, False), -) +] @pytest.mark.parametrize("env_vars, expected", _TEST_ENV_VARS) @@ -192,28 +207,7 @@ def test_build_env_client(env_vars, expected, monkeypatch): ], ) def test_az_connect_core(auth_methods, cloud, tenant_id, silent, region, credential): - """ - Test _az_connect_core function with different parameters. - - Parameters - ---------- - auth_methods : list[str] - List of authentication methods to try. - cloud : str - Azure cloud to connect to. - tenant_id : str - Tenant to authenticate against. - silent : bool - Whether to display any output during auth process. - region : str - Azure region to connect to. - credential : AzCredentials - Azure credential to use directly. - - Returns - ------- - None - """ + """Test _az_connect_core function with different parameters.""" # Call the function with the test parameters result = _az_connect_core( auth_methods=auth_methods, @@ -228,3 +222,428 @@ def test_az_connect_core(auth_methods, cloud, tenant_id, silent, region, credent assert isinstance(result, AzCredentials) assert result.legacy is not None assert result.modern is not None + + +@pytest.mark.parametrize( + "env_vars, expected_credential", + [ + ( + { + "AZURE_CLIENT_ID": "test_client_id", + "AZURE_TENANT_ID": "test_tenant_id", + "AZURE_CLIENT_SECRET": "[PLACEHOLDER]", + }, + "EnvironmentCredential", + ), + ( + { + "AZURE_CLIENT_ID": "test_client_id", + "AZURE_TENANT_ID": "test_tenant_id", + "AZURE_CLIENT_CERTIFICATE_PATH": "[PLACEHOLDER]", + }, + "EnvironmentCredential", + ), + ( + { + "AZURE_CLIENT_ID": "test_client_id", + "AZURE_TENANT_ID": "test_tenant_id", + "AZURE_USERNAME": "test_user", + "AZURE_PASSWORD": "[PLACEHOLDER]", + }, + "EnvironmentCredential", + ), + ( + { + "AZURE_CLIENT_ID": "test_client_id", + "AZURE_CLIENT_CERTIFICATE_PATH": "[PLACEHOLDER]", + }, + None, + ), + ( + { + "AZURE_TENANT_ID": "test_tenant_id", + "AZURE_USERNAME": "test_user", + "AZURE_PASSWORD": "[PLACEHOLDER]", + }, + None, + ), + ({}, None), + ], +) +@patch.dict(os.environ, {}, clear=True) +@patch("msticpy.auth.azure_auth_core.EnvironmentCredential", autospec=True) +def test_build_env_client_alt( + mock_env_credential, env_vars, expected_credential, monkeypatch +): + """Test _build_env_client function.""" + for env_var, env_val in env_vars.items(): + monkeypatch.setenv(env_var, env_val) + result = _build_env_client() + if expected_credential: + # assert isinstance(result, mock_env_credential) + mock_env_credential.assert_called_once() + else: + mock_env_credential.assert_not_called() + assert result is None + + +@patch("msticpy.auth.azure_auth_core.AzureCliCredential", autospec=True) +def test_build_cli_client(mock_cli_credential): + """Test _build_cli_client function.""" + result = _build_cli_client() + # assert isinstance(result, mock_cli_credential) + mock_cli_credential.assert_called_once() + + +@pytest.mark.parametrize( + "env_vars, expected_kwargs, tenant_id, aad_uri, client_id", + [ + ( + {"AZURE_CLIENT_ID": "test_client_id"}, + {}, + "test_tenant_id", + "test_aad_uri", + "test_client_id", + ), + ({}, {}, None, None, None), + ], +) +@patch.dict(os.environ, {}, clear=True) +@patch("msticpy.auth.azure_auth_core.ManagedIdentityCredential", autospec=True) +def test_build_msi_client( + mock_msi_credential, + env_vars, + expected_kwargs, + tenant_id, + aad_uri, + client_id, +): + """Test _build_msi_client function.""" + os.environ.update(env_vars) + result = _build_msi_client( + tenant_id=tenant_id, aad_uri=aad_uri, client_id=client_id + ) + # assert isinstance(result, mock_msi_credential) + mock_msi_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri, client_id=client_id, **expected_kwargs + ) + + +@pytest.mark.parametrize( + "tenant_id, aad_uri", + [ + ("test_tenant_id", "test_aad_uri"), + (None, None), + ], +) +@patch("msticpy.auth.azure_auth_core.VisualStudioCodeCredential", autospec=True) +def test_build_vscode_client(mock_vscode_credential, tenant_id, aad_uri): + """Test _build_vscode_client function.""" + result = _build_vscode_client(tenant_id=tenant_id, aad_uri=aad_uri) + # assert isinstance(result, mock_vscode_credential) + mock_vscode_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri + ) + + +@pytest.mark.parametrize( + "tenant_id, aad_uri, kwargs", + [ + ("test_tenant_id", "test_aad_uri", {"param": "value"}), + (None, None, {}), + ], +) +@patch("msticpy.auth.azure_auth_core.InteractiveBrowserCredential", autospec=True) +def test_build_interactive_client( + mock_interactive_credential, tenant_id, aad_uri, kwargs +): + """Test _build_interactive_client function.""" + _ = _build_interactive_client(tenant_id=tenant_id, aad_uri=aad_uri, **kwargs) + # assert isinstance(result, mock_interactive_credential) + mock_interactive_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri, **kwargs + ) + + +@pytest.mark.parametrize( + "tenant_id, aad_uri, kwargs", + [ + ("test_tenant_id", "test_aad_uri", {"param": "value"}), + (None, None, {}), + ], +) +@patch("msticpy.auth.azure_auth_core.DeviceCodeCredential", autospec=True) +def test_build_device_code_client( + mock_device_code_credential, tenant_id, aad_uri, kwargs +): + """Test _build_device_code_client function.""" + _ = _build_device_code_client(tenant_id=tenant_id, aad_uri=aad_uri, **kwargs) + mock_device_code_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri, **kwargs + ) + + +@pytest.mark.parametrize( + "tenant_id, aad_uri, client_id, client_secret, expected_credential", + [ + ( + "test_tenant_id", + "test_aad_uri", + "test_client_id", + "test_client_secret", + "ClientSecretCredential", + ), + ("test_tenant_id", "test_aad_uri", None, "test_client_secret", None), + ("test_tenant_id", "test_aad_uri", "test_client_id", None, None), + ], +) +@patch("msticpy.auth.azure_auth_core.ClientSecretCredential", autospec=True) +def test_build_client_secret_client( + mock_client_secret_credential, + tenant_id, + aad_uri, + client_id, + client_secret, + expected_credential, +): + """Test _build_client_secret_client function.""" + kwargs = {"client_id": client_id, "client_secret": client_secret} + result = _build_client_secret_client(tenant_id=tenant_id, aad_uri=aad_uri, **kwargs) + if expected_credential: + mock_client_secret_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri, **kwargs + ) + else: + assert result is None + + +@pytest.mark.parametrize( + "tenant_id, aad_uri, client_id, expected_credential", + [ + ("test_tenant_id", "test_aad_uri", "test_client_id", "CertificateCredential"), + ("test_tenant_id", "test_aad_uri", None, None), + ], +) +@patch("msticpy.auth.azure_auth_core.CertificateCredential", autospec=True) +def test_build_certificate_client( + mock_certificate_credential, tenant_id, aad_uri, client_id, expected_credential +): + """Test _build_certificate_client function.""" + kwargs = {"client_id": client_id} + result = _build_certificate_client(tenant_id=tenant_id, aad_uri=aad_uri, **kwargs) + if expected_credential: + mock_certificate_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri, **kwargs + ) + else: + assert result is None + + +@patch("msticpy.auth.azure_auth_core.AzurePowerShellCredential", autospec=True) +def test_build_powershell_client(mock_powershell_credential): + """Test _build_powershell_client function.""" + result = _build_powershell_client() + # assert isinstance(result, mock_powershell_credential) + mock_powershell_credential.assert_called_once() + + +@pytest.mark.parametrize( + "requested_clients, tenant_id, aad_uri, kwargs, expected_cred_types, expected_exception", + [ + ( + None, + "test_tenant_id", + "test_aad_uri", + {}, + [ + "AzureCliCredential", + "ManagedIdentityCredential", + "InteractiveBrowserCredential", + ], + None, + ), + ( + ["env", "cli"], + "test_tenant_id", + "test_aad_uri", + {}, + ["AzureCliCredential"], + None, + ), + ( + ["unknown"], + "test_tenant_id", + "test_aad_uri", + {}, + [], + MsticpyAzureConfigError, + ), + ( + ["env-test", "cli", "invalid"], + "test_tenant_id", + "test_aad_uri", + {}, + ["AzureCliCredential"], + None, + ), + ], +) +@patch("msticpy.auth.azure_auth_core.EnvironmentCredential", autospec=True) +@patch("msticpy.auth.azure_auth_core.AzureCliCredential", autospec=True) +@patch("msticpy.auth.azure_auth_core.ManagedIdentityCredential", autospec=True) +@patch("msticpy.auth.azure_auth_core.InteractiveBrowserCredential", autospec=True) +def test_create_chained_credential( + mock_interactive_credential, + mock_msi_credential, + mock_cli_credential, + mock_env_credential, + requested_clients, + tenant_id, + aad_uri, + kwargs, + expected_cred_types, + expected_exception, +): + """ + Test _create_chained_credential function. + + Parameters + ---------- + mock_interactive_credential : MagicMock + Mocked InteractiveBrowserCredential class. + mock_msi_credential : MagicMock + Mocked ManagedIdentityCredential class. + mock_cli_credential : MagicMock + Mocked AzureCliCredential class. + mock_env_credential : MagicMock + Mocked EnvironmentCredential class. + mock_clients : dict + Mocked _CLIENTS dictionary. + requested_clients : list[str] + List of clients to chain. + tenant_id : str + The tenant ID to connect to. + aad_uri : str + The URI of the Azure AD cloud to connect to. + kwargs : dict + Additional keyword arguments. + expected_cred_types : list[str] + Expected credential types to be included in the chained credential. + expected_exception : Exception + Expected exception to be raised. + + Returns + ------- + None + """ + if expected_exception: + with pytest.raises(expected_exception): + _create_chained_credential( + aad_uri=aad_uri, + requested_clients=requested_clients, + tenant_id=tenant_id, + **kwargs + ) + else: + result = _create_chained_credential( + aad_uri=aad_uri, + requested_clients=requested_clients, + tenant_id=tenant_id, + **kwargs + ) + assert isinstance(result, ChainedTokenCredential) + cred_classes = {cred.__class__.__name__ for cred in result.credentials} + assert all(expected in cred_classes for expected in expected_cred_types) + + +@pytest.mark.parametrize( + "record_name, record_level, record_message, expected_output", + [ + ("azure.identity", logging.WARNING, "EnvironmentCredential.get_token", False), + ("azure.identity", logging.WARNING, "AzureCliCredential.get_token", False), + ( + "azure.identity", + logging.WARNING, + "ManagedIdentityCredential.get_token", + False, + ), + ("azure.identity", logging.WARNING, "SomeOtherCredential.get_token", False), + ("azure.identity", logging.INFO, "EnvironmentCredential.get_token", True), + ("some.other.logger", logging.WARNING, "EnvironmentCredential.get_token", True), + ], +) +def test_filter_credential_warning( + record_name, record_level, record_message, expected_output +): + """ + Test _filter_credential_warning function. + + Parameters + ---------- + record_name : str + The name of the log record. + record_level : int + The level of the log record. + record_message : str + The message of the log record. + expected_output : bool + The expected output of the function. + + Returns + ------- + None + """ + record = MagicMock() + record.name = record_name + record.levelno = record_level + record.getMessage.return_value = record_message + + result = _filter_credential_warning(record) + assert result == expected_output + + +@pytest.mark.parametrize( + "record_name, record_level, record_message, expected_output", + [ + ("azure.identity", logging.WARNING, "EnvironmentCredential.get_token", False), + ("azure.identity", logging.WARNING, "AzureCliCredential.get_token", False), + ( + "azure.identity", + logging.WARNING, + "ManagedIdentityCredential.get_token", + False, + ), + ("azure.identity", logging.WARNING, "SomeOtherCredential.get_token", False), + ("azure.identity", logging.WARNING, "Some other warning message", True), + ("azure.identity", logging.INFO, "EnvironmentCredential.get_token", True), + ("some.other.logger", logging.WARNING, "EnvironmentCredential.get_token", True), + ], +) +def test_filter_all_warnings( + record_name, record_level, record_message, expected_output +): + """ + Test _filter_all_warnings function. + + Parameters + ---------- + record_name : str + The name of the log record. + record_level : int + The level of the log record. + record_message : str + The message of the log record. + expected_output : bool + The expected output of the function. + + Returns + ------- + None + """ + record = MagicMock() + record.name = record_name + record.levelno = record_level + record.getMessage.return_value = record_message + + result = _filter_all_warnings(record) + assert result == expected_output