diff --git a/pyproject.toml b/pyproject.toml index 0e2ed128ef..e72250b3b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ pip: tests that rely on pip install operations""" timeout = 300 [tool.ruff] -target-version = "py39" +target-version = "py310" line-length = 100 [tool.ruff.lint] diff --git a/setup.py b/setup.py index f7e09c0d05..29cf61d1c2 100644 --- a/setup.py +++ b/setup.py @@ -145,7 +145,7 @@ "ape_pm=ape_pm._cli:cli", ], }, - python_requires=">=3.9,<4", + python_requires=">=3.10,<4", extras_require=extras_require, py_modules=list(_MODULES), license="Apache-2.0", diff --git a/src/ape/_cli.py b/src/ape/_cli.py index 9d6e78df61..21ed246ad1 100644 --- a/src/ape/_cli.py +++ b/src/ape/_cli.py @@ -6,7 +6,7 @@ from gettext import gettext from importlib.metadata import entry_points from pathlib import Path -from typing import Any, Optional +from typing import Any from warnings import catch_warnings, simplefilter import click @@ -153,7 +153,7 @@ def commands(self) -> dict: def list_commands(self, ctx) -> list[str]: return [k for k in self.commands] - def get_command(self, ctx, name) -> Optional[click.Command]: + def get_command(self, ctx, name) -> click.Command | None: try: return self.commands[name]() except Exception as err: diff --git a/src/ape/api/accounts.py b/src/ape/api/accounts.py index 685c1f5a4d..97bb2acd19 100644 --- a/src/ape/api/accounts.py +++ b/src/ape/api/accounts.py @@ -4,7 +4,7 @@ from contextlib import contextmanager from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import click from eip712.messages import EIP712Message @@ -73,14 +73,14 @@ def __dir__(self) -> list[str]: ] @property - def alias(self) -> Optional[str]: + def alias(self) -> str | None: """ A shortened-name for quicker access to the account. """ return None @property - def public_key(self) -> Optional["HexBytes"]: + def public_key(self) -> "HexBytes | None": """ The public key for the account. @@ -95,7 +95,7 @@ def prepare_transaction(self, txn: "TransactionAPI", **kwargs) -> "TransactionAP prepared_tx = super().prepare_transaction(txn, **kwargs) return (self.sign_transaction(prepared_tx) or prepared_tx) if sign else prepared_tx - def sign_raw_msghash(self, msghash: "HexBytes") -> Optional[MessageSignature]: + def sign_raw_msghash(self, msghash: "HexBytes") -> MessageSignature | None: """ Sign a raw message hash. @@ -115,19 +115,19 @@ def sign_raw_msghash(self, msghash: "HexBytes") -> Optional[MessageSignature]: def sign_authorization( self, address: Any, - chain_id: Optional[int] = None, - nonce: Optional[int] = None, - ) -> Optional[MessageSignature]: + chain_id: int | None = None, + nonce: int | None = None, + ) -> MessageSignature | None: """ Sign an `EIP-7702 `__ Authorization. Args: address (Any): A delegate address to sign the authorization for. - chain_id (Optional[int]): + chain_id (int | None): The chain ID that the authorization should be valid for. A value of ``0`` means that the authorization is valid for **any chain**. Default tells implementation to use the currently connected network's ``chain_id``. - nonce (Optional[int]): + nonce (int | None): The nonce to use to sign authorization with. Defaults to account's current nonce. Returns: @@ -146,7 +146,7 @@ def sign_authorization( ) @abstractmethod - def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]: + def sign_message(self, msg: Any, **signer_options) -> MessageSignature | None: """ Sign a message. @@ -164,7 +164,7 @@ def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature] """ @abstractmethod - def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[TransactionAPI]: + def sign_transaction(self, txn: TransactionAPI, **signer_options) -> TransactionAPI | None: """ Sign a transaction. @@ -259,9 +259,9 @@ def call( def transfer( self, - account: Union[str, AddressType, BaseAddress], - value: Optional[Union[str, int]] = None, - data: Optional[Union[bytes, str]] = None, + account: str | AddressType | BaseAddress, + value: str | int | None = None, + data: bytes | str | None = None, private: bool = False, **kwargs, ) -> ReceiptAPI: @@ -273,9 +273,9 @@ def transfer( and using a provider that does not support private transactions. Args: - account (Union[str, AddressType, BaseAddress]): The receiver of the funds. - value (Optional[Union[str, int]]): The amount to send. - data (Optional[Union[bytes, str]]): Extra data to include in the transaction. + account (str | AddressType | BaseAddress): The receiver of the funds. + value (str | int | None): The amount to send. + data (bytes | str | None): Extra data to include in the transaction. private (bool): ``True`` asks the provider to make the transaction private. For example, EVM providers typically use the RPC ``eth_sendPrivateTransaction`` to achieve this. Local providers may ignore @@ -412,8 +412,8 @@ def declare(self, contract: "ContractContainer", *args, **kwargs) -> ReceiptAPI: def check_signature( self, - data: Union[SignableMessage, TransactionAPI, str, EIP712Message, int, bytes], - signature: Optional[MessageSignature] = None, # TransactionAPI doesn't need it + data: SignableMessage | TransactionAPI | str | EIP712Message | int | bytes, + signature: MessageSignature | None = None, # TransactionAPI doesn't need it recover_using_eip191: bool = True, ) -> bool: """ @@ -422,7 +422,7 @@ def check_signature( Args: data (Union[:class:`~ape.types.signatures.SignableMessage`, :class:`~ape.api.transactions.TransactionAPI`]): # noqa: E501 The message or transaction to verify. - signature (Optional[:class:`~ape.types.signatures.MessageSignature`]): + signature (MessageSignature | None): The signature to check. Defaults to ``None`` and is not needed when the first argument is a transaction class. recover_using_eip191 (bool): @@ -459,7 +459,7 @@ def check_signature( else: raise AccountsError(f"Unsupported message type: {type(data)}.") - def get_deployment_address(self, nonce: Optional[int] = None) -> AddressType: + def get_deployment_address(self, nonce: int | None = None) -> AddressType: """ Get a contract address before it is deployed. This is useful when you need to pass the contract address to another contract @@ -481,7 +481,7 @@ def get_deployment_address(self, nonce: Optional[int] = None) -> AddressType: nonce = self.nonce if nonce is None else nonce return ecosystem.get_deployment_address(self.address, nonce) - def set_delegate(self, contract: Union[BaseAddress, AddressType, str], **txn_kwargs): + def set_delegate(self, contract: BaseAddress | AddressType | str, **txn_kwargs): """ Have the account class override the value of its ``delegate``. For plugins that support this feature, the way they choose to handle it can vary. For example, it could be a call to @@ -526,9 +526,9 @@ def remove_delegate(self, **txn_kwargs): @contextmanager def delegate_to( self, - new_delegate: Union[BaseAddress, AddressType, str], - set_txn_kwargs: Optional[dict] = None, - reset_txn_kwargs: Optional[dict] = None, + new_delegate: BaseAddress | AddressType | str, + set_txn_kwargs: dict | None = None, + reset_txn_kwargs: dict | None = None, **txn_kwargs, ) -> Iterator[BaseAddress]: """ @@ -799,7 +799,7 @@ def get_test_account(self, index: int) -> "TestAccountAPI": # type: ignore[empt """ @abstractmethod - def generate_account(self, index: Optional[int] = None) -> "TestAccountAPI": + def generate_account(self, index: int | None = None) -> "TestAccountAPI": """ Generate a new test account. """ @@ -832,10 +832,10 @@ class ImpersonatedAccount(AccountAPI): def address(self) -> AddressType: return self.raw_address - def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]: + def sign_message(self, msg: Any, **signer_options) -> MessageSignature | None: raise APINotImplementedError("This account cannot sign messages") - def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[TransactionAPI]: + def sign_transaction(self, txn: TransactionAPI, **signer_options) -> TransactionAPI | None: # Returns input transaction unsigned (since it doesn't have access to the key) return txn diff --git a/src/ape/api/address.py b/src/ape/api/address.py index 4238b6ec6e..36596442c6 100644 --- a/src/ape/api/address.py +++ b/src/ape/api/address.py @@ -1,6 +1,6 @@ from abc import abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ape.exceptions import AccountsError, ConversionError from ape.types.address import AddressType @@ -179,7 +179,7 @@ def is_contract(self) -> bool: return self.codesize > 0 @property - def delegate(self) -> Optional["BaseAddress"]: + def delegate(self) -> "BaseAddress | None": """ Check and see if Account has a "delegate" contract, which is a contract that this account delegates functionality to. This could be from many contexts, such as a Smart Wallet like diff --git a/src/ape/api/compiler.py b/src/ape/api/compiler.py index eec9c47265..b3e222e141 100644 --- a/src/ape/api/compiler.py +++ b/src/ape/api/compiler.py @@ -2,7 +2,7 @@ from collections.abc import Iterable, Iterator from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from ape.exceptions import APINotImplementedError, ContractLogicError from ape.utils.basemodel import BaseInterfaceModel @@ -44,12 +44,12 @@ def name(self) -> str: The name of the compiler. """ - def get_config(self, project: Optional["ProjectManager"] = None) -> "PluginConfig": + def get_config(self, project: "ProjectManager | None" = None) -> "PluginConfig": """ The combination of settings from ``ape-config.yaml`` and ``.compiler_settings``. Args: - project (Optional[:class:`~ape.managers.project.ProjectManager`]): Optionally provide + project ("ProjectManager | None"): Optionally provide the project containing the base paths and full source set. Defaults to the local project. Dependencies will change this value to their respective projects. @@ -77,7 +77,7 @@ def get_versions(self, all_paths: Iterable[Path]) -> set[str]: # type: ignore[e def get_compiler_settings( # type: ignore[empty-body] self, contract_filepaths: Iterable[Path], - project: Optional["ProjectManager"] = None, + project: "ProjectManager | None" = None, **overrides, ) -> dict["Version", dict]: """ @@ -86,7 +86,7 @@ def get_compiler_settings( # type: ignore[empty-body] Args: contract_filepaths (Iterable[pathlib.Path]): The list of paths. - project (Optional[:class:`~ape.managers.project.ProjectManager`]): Optionally provide + project ("ProjectManager | None"): Optionally provide the project containing the base paths and full source set. Defaults to the local project. Dependencies will change this value to their respective projects. **overrides: Settings overrides. @@ -99,18 +99,18 @@ def get_compiler_settings( # type: ignore[empty-body] def compile( self, contract_filepaths: Iterable[Path], - project: Optional["ProjectManager"], - settings: Optional[dict] = None, + project: "ProjectManager | None", + settings: dict | None = None, ) -> Iterator["ContractType"]: """ Compile the given source files. All compiler plugins must implement this function. Args: contract_filepaths (Iterable[pathlib.Path]): A list of source file paths to compile. - project (Optional[:class:`~ape.managers.project.ProjectManager`]): Optionally provide + project ("ProjectManager | None"): Optionally provide the project containing the base paths and full source set. Defaults to the local project. Dependencies will change this value to their respective projects. - settings (Optional[dict]): Adhoc compiler settings. + settings (dict | None): Adhoc compiler settings. Returns: list[:class:`~ape.type.contract.ContractType`] @@ -120,8 +120,8 @@ def compile( def compile_code( # type: ignore[empty-body] self, code: str, - project: Optional["ProjectManager"], - settings: Optional[dict] = None, + project: "ProjectManager | None", + settings: dict | None = None, **kwargs, ) -> "ContractType": """ @@ -129,7 +129,7 @@ def compile_code( # type: ignore[empty-body] Args: code (str): The code to compile. - project (Optional[:class:`~ape.managers.project.ProjectManager`]): Optionally provide + project ("ProjectManager | None"): Optionally provide the project containing the base paths and full source set. Defaults to the local project. Dependencies will change this value to their respective projects. settings (Optional[Dict]): Adhoc compiler settings. @@ -141,7 +141,7 @@ def compile_code( # type: ignore[empty-body] @raises_not_implemented def get_imports( # type: ignore[empty-body] - self, contract_filepaths: Iterable[Path], project: Optional["ProjectManager"] + self, contract_filepaths: Iterable[Path], project: "ProjectManager | None" ) -> dict[str, list[str]]: """ Returns a list of imports as source_ids for each contract's source_id in a given @@ -149,7 +149,7 @@ def get_imports( # type: ignore[empty-body] Args: contract_filepaths (Iterable[pathlib.Path]): A list of source file paths to compile. - project (Optional[:class:`~ape.managers.project.ProjectManager`]): Optionally provide + project ("ProjectManager | None"): Optionally provide the project containing the base paths and full source set. Defaults to the local project. Dependencies will change this value to their respective projects. @@ -161,7 +161,7 @@ def get_imports( # type: ignore[empty-body] def get_version_map( # type: ignore[empty-body] self, contract_filepaths: Iterable[Path], - project: Optional["ProjectManager"] = None, + project: "ProjectManager | None" = None, ) -> dict["Version", set[Path]]: """ Get a map of versions to source paths. @@ -169,7 +169,7 @@ def get_version_map( # type: ignore[empty-body] Args: contract_filepaths (Iterable[Path]): Input source paths. Defaults to all source paths per compiler. - project (Optional[:class:`~ape.managers.project.ProjectManager`]): Optionally provide + project ("ProjectManager | None"): Optionally provide the project containing the base paths and full source set. Defaults to the local project. Dependencies will change this value to their respective projects. @@ -238,7 +238,7 @@ def trace_source( # type: ignore[empty-body] @raises_not_implemented def flatten_contract( # type: ignore[empty-body] - self, path: Path, project: Optional["ProjectManager"] = None, **kwargs + self, path: Path, project: "ProjectManager | None" = None, **kwargs ) -> "Content": """ Get the content of a flattened contract via its source path. @@ -247,7 +247,7 @@ def flatten_contract( # type: ignore[empty-body] Args: path (``pathlib.Path``): The source path of the contract. - project (Optional[:class:`~ape.managers.project.ProjectManager`]): Optionally provide + project ("ProjectManager | None"): Optionally provide the project containing the base paths and full source set. Defaults to the local project. Dependencies will change this value to their respective projects. **kwargs (Any): Additional compiler-specific settings. See specific diff --git a/src/ape/api/config.py b/src/ape/api/config.py index d4b1d3ee4b..9ddf3b2a21 100644 --- a/src/ape/api/config.py +++ b/src/ape/api/config.py @@ -3,7 +3,7 @@ from enum import Enum from functools import cached_property from pathlib import Path -from typing import Any, Optional, TypeVar, cast +from typing import Any, TypeVar, cast import yaml from ethpm_types import PackageManifest, PackageMeta, Source @@ -73,7 +73,7 @@ class PluginConfig(BaseSettings): @classmethod def from_overrides( - cls, overrides: dict, plugin_name: Optional[str] = None, project_path: Optional[Path] = None + cls, overrides: dict, plugin_name: str | None = None, project_path: Path | None = None ) -> "PluginConfig": default_values = cls().model_dump() @@ -100,8 +100,8 @@ def update(root: dict, value_map: dict): @classmethod def _find_plugin_config_problems( - cls, err: ValidationError, plugin_name: str, project_path: Optional[Path] = None - ) -> Optional[str]: + cls, err: ValidationError, plugin_name: str, project_path: Path | None = None + ) -> str | None: # Attempt showing line-nos for failed plugin config validation. # This is trickier than root-level data since by this time, we # no longer are aware of which files are responsible for which config. @@ -131,7 +131,7 @@ def _find_plugin_config_problems( @classmethod def _find_plugin_config_problems_from_file( cls, err: ValidationError, base_path: Path - ) -> Optional[str]: + ) -> str | None: cfg_files = _find_config_yaml_files(base_path) for cfg_file in cfg_files: if problems := _get_problem_with_config(err.errors(), cfg_file): @@ -174,7 +174,7 @@ def __str__(self) -> str: ) return yaml.safe_dump(data) - def get(self, key: str, default: Optional[ConfigItemType] = None) -> ConfigItemType: + def get(self, key: str, default: ConfigItemType | None = None) -> ConfigItemType: extra: dict = self.__pydantic_extra__ or {} return self.__dict__.get(key, extra.get(key, default)) # type: ignore @@ -202,7 +202,7 @@ class DeploymentConfig(PluginConfig): """ -def _get_problem_with_config(errors: list, path: Path) -> Optional[str]: +def _get_problem_with_config(errors: list, path: Path) -> str | None: # Attempt to find line numbers in the config matching. cfg_content = Source(content=path.read_text(encoding="utf8")).content if not cfg_content: @@ -303,7 +303,7 @@ def __init__(self, *args, **kwargs): # NOTE: Cannot reference `self` at all until after super init. self._project_path = project_path - contracts_folder: Optional[str] = None + contracts_folder: str | None = None """ The path to the folder containing the contract source files. **NOTE**: Non-absolute paths are relative to the project-root. @@ -352,7 +352,7 @@ def __init__(self, *args, **kwargs): The name of the project. """ - base_path: Optional[str] = None + base_path: str | None = None """ Use this when the project's base-path is not the root of the project. @@ -527,7 +527,7 @@ def model_dump(self, *args, **kwargs): # shouldn't. Figure out why. return {k: v for k, v in res.items() if not k.startswith("_")} - def get(self, name: str) -> Optional[Any]: + def get(self, name: str) -> Any | None: return self.__getattr__(name) def get_config(self, plugin_name: str) -> PluginConfig: @@ -538,7 +538,7 @@ def get_config(self, plugin_name: str) -> PluginConfig: or self.get_unknown_config(name) ) - def get_plugin_config(self, name: str) -> Optional[PluginConfig]: + def get_plugin_config(self, name: str) -> PluginConfig | None: name = name.replace("-", "_") cfg = self._plugin_configs.get(name, {}) if cfg and not isinstance(cfg, dict): @@ -568,7 +568,7 @@ def _get_config_plugin_classes(self): # NOTE: Abstracted for easily mocking in tests. return self.plugin_manager.config_class - def get_custom_ecosystem_config(self, name: str) -> Optional[PluginConfig]: + def get_custom_ecosystem_config(self, name: str) -> PluginConfig | None: name = name.replace("-", "_") if not (networks := self.get_plugin_config("networks")): # Shouldn't happen. diff --git a/src/ape/api/explorers.py b/src/ape/api/explorers.py index e1901d4a0f..1aeaa16788 100644 --- a/src/ape/api/explorers.py +++ b/src/ape/api/explorers.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from ape.api.networks import NetworkAPI from ape.utils.basemodel import BaseInterfaceModel @@ -44,7 +44,7 @@ def get_transaction_url(self, transaction_hash: str) -> str: """ @abstractmethod - def get_contract_type(self, address: "AddressType") -> Optional["ContractType"]: + def get_contract_type(self, address: "AddressType") -> "ContractType | None": """ Get the contract type for a given address if it has been published to this explorer. diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 490ad44aa8..65c16a6cfb 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -3,7 +3,7 @@ from collections.abc import Collection, Iterator, Sequence from functools import cached_property, partial from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar from eth_pydantic_types import HexBytes from eth_utils import keccak @@ -90,7 +90,7 @@ def __repr__(self) -> str: return " Optional["MethodABI"]: + def abi(self) -> "MethodABI | None": """ Some proxies have special ABIs which may not exist in their contract-types by default, such as Safe's ``masterCopy()``. @@ -119,7 +119,7 @@ class EcosystemAPI(ExtraAttributesMixin, BaseInterfaceModel): fee_token_decimals: int = 18 """The number of the decimals the fee token has.""" - _default_network: Optional[str] = None + _default_network: str | None = None """The default network of the ecosystem, such as ``local``.""" @model_validator(mode="after") @@ -480,8 +480,8 @@ def decode_logs(self, logs: Sequence[dict], *events: "EventABI") -> Iterator["Co @raises_not_implemented def decode_primitive_value( # type: ignore[empty-body] - self, value: Any, output_type: Union[str, tuple, list] - ) -> Union[str, HexBytes, tuple]: + self, value: Any, output_type: str | tuple | list + ) -> str | HexBytes | tuple: """ Decode a primitive value-type given its ABI type as a ``str`` and the value itself. This method is a hook for converting @@ -490,10 +490,10 @@ def decode_primitive_value( # type: ignore[empty-body] Args: value (Any): The value to decode. - output_type (Union[str, tuple, list]): The value type. + output_type (str | tuple | list): The value type. Returns: - Union[str, HexBytes, tuple] + str | HexBytes | tuple """ @abstractmethod @@ -509,12 +509,12 @@ def create_transaction(self, **kwargs) -> "TransactionAPI": """ @abstractmethod - def decode_calldata(self, abi: Union["ConstructorABI", "MethodABI"], calldata: bytes) -> dict: + def decode_calldata(self, abi: "ConstructorABI | MethodABI", calldata: bytes) -> dict: """ Decode method calldata. Args: - abi (Union[ConstructorABI, MethodABI]): The method called. + abi ("ConstructorABI | MethodABI"): The method called. calldata (bytes): The raw calldata bytes. Returns: @@ -524,12 +524,12 @@ def decode_calldata(self, abi: Union["ConstructorABI", "MethodABI"], calldata: b """ @abstractmethod - def encode_calldata(self, abi: Union["ConstructorABI", "MethodABI"], *args: Any) -> HexBytes: + def encode_calldata(self, abi: "ConstructorABI | MethodABI", *args: Any) -> HexBytes: """ Encode method calldata. Args: - abi (Union[ConstructorABI, MethodABI]): The ABI of the method called. + abi ("ConstructorABI | MethodABI"): The ABI of the method called. *args (Any): The arguments given to the method. Returns: @@ -587,7 +587,7 @@ def get_network(self, network_name: str) -> "NetworkAPI": raise NetworkNotFoundError(network_name, ecosystem=self.name, options=networks) def get_network_data( - self, network_name: str, provider_filter: Optional[Collection[str]] = None + self, network_name: str, provider_filter: Collection[str] | None = None ) -> dict: """ Get a dictionary of data about providers in the network. @@ -597,7 +597,7 @@ def get_network_data( Args: network_name (str): The name of the network to get provider data from. - provider_filter (Optional[Collection[str]]): Optionally filter the providers + provider_filter (Collection[str] | None): Optionally filter the providers by name. Returns: @@ -629,7 +629,7 @@ def get_network_data( return data - def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfoAPI]: + def get_proxy_info(self, address: AddressType) -> ProxyInfoAPI | None: """ Information about a proxy contract such as proxy type and implementation address. @@ -637,7 +637,7 @@ def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfoAPI]: address (:class:`~ape.types.address.AddressType`): The address of the contract. Returns: - Optional[:class:`~ape.api.networks.ProxyInfoAPI`]: Returns ``None`` if the contract + ProxyInfoAPI | None: Returns ``None`` if the contract does not use any known proxy pattern. """ return None @@ -684,7 +684,7 @@ def enrich_trace(self, trace: "TraceAPI", **kwargs) -> "TraceAPI": @raises_not_implemented def get_python_types( # type: ignore[empty-body] self, abi_type: "ABIType" - ) -> Union[type, Sequence]: + ) -> type | Sequence: """ Get the Python types for a given ABI type. @@ -692,7 +692,7 @@ def get_python_types( # type: ignore[empty-body] abi_type (``ABIType``): The ABI type to get the Python types for. Returns: - Union[Type, Sequence]: The Python types for the given ABI type. + type | Sequence: The Python types for the given ABI type. """ @raises_not_implemented @@ -701,7 +701,7 @@ def decode_custom_error( data: HexBytes, address: AddressType, **kwargs, - ) -> Optional[CustomError]: + ) -> CustomError | None: """ Decode a custom error class from an ABI defined in a contract. @@ -713,7 +713,7 @@ def decode_custom_error( **kwargs: Additional init kwargs for the custom error class. Returns: - Optional[CustomError]: If it able to decode one, else ``None``. + CustomError | None: If it able to decode one, else ``None``. """ def _get_request_headers(self) -> RPCHeaders: @@ -760,7 +760,7 @@ class ProviderContextManager(ManagerAccessMixin): # due to an exception, when interactive mode is set. If we don't hold on # to a reference to this object, the provider is dropped and reconnecting results # in losing state when using a spawned local provider - _recycled_provider: ClassVar[Optional["ProviderAPI"]] = None + _recycled_provider: ClassVar["ProviderAPI | None"] = None def __init__( self, @@ -1035,7 +1035,7 @@ def transaction_acceptance_timeout(self) -> int: ) @cached_property - def explorer(self) -> Optional["ExplorerAPI"]: + def explorer(self) -> "ExplorerAPI | None": """ The block-explorer for the given network. @@ -1076,7 +1076,7 @@ def is_mainnet(self) -> bool: """ True when the network is the mainnet network for the ecosystem. """ - cfg_is_mainnet: Optional[bool] = self.config.get("is_mainnet") + cfg_is_mainnet: bool | None = self.config.get("is_mainnet") if cfg_is_mainnet is not None: return cfg_is_mainnet @@ -1180,8 +1180,8 @@ def _get_plugin_provider_names(self) -> Iterator[str]: def get_provider( self, - provider_name: Optional[str] = None, - provider_settings: Optional[dict] = None, + provider_name: str | None = None, + provider_settings: dict | None = None, connect: bool = False, ): """ @@ -1239,8 +1239,8 @@ def get_provider( def use_provider( self, - provider: Union[str, "ProviderAPI"], - provider_settings: Optional[dict] = None, + provider: "str | ProviderAPI", + provider_settings: dict | None = None, disconnect_after: bool = False, disconnect_on_exit: bool = True, ) -> ProviderContextManager: @@ -1289,7 +1289,7 @@ def use_provider( ) @property - def default_provider_name(self) -> Optional[str]: + def default_provider_name(self) -> str | None: """ The name of the default provider or ``None``. @@ -1314,7 +1314,7 @@ def default_provider_name(self) -> Optional[str]: return None @property - def default_provider(self) -> Optional["ProviderAPI"]: + def default_provider(self) -> "ProviderAPI | None": if (name := self.default_provider_name) and name in self.providers: return self.get_provider(name) @@ -1342,7 +1342,7 @@ def set_default_provider(self, provider_name: str): def use_default_provider( self, - provider_settings: Optional[dict] = None, + provider_settings: dict | None = None, disconnect_after: bool = False, ) -> ProviderContextManager: """ diff --git a/src/ape/api/projects.py b/src/ape/api/projects.py index 6a4cf254f7..22b3ffcc48 100644 --- a/src/ape/api/projects.py +++ b/src/ape/api/projects.py @@ -1,7 +1,6 @@ from abc import abstractmethod from functools import cached_property from pathlib import Path -from typing import Optional from pydantic import Field, field_validator @@ -103,7 +102,7 @@ def extract_config(self, **overrides) -> "ApeConfig": """ @classmethod - def attempt_validate(cls, **kwargs) -> Optional["ProjectAPI"]: + def attempt_validate(cls, **kwargs) -> "ProjectAPI | None": try: instance = cls(**kwargs) except ValueError: diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index 647d979150..58486bcae2 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -14,7 +14,7 @@ from pathlib import Path from signal import SIGINT, SIGTERM, signal from subprocess import DEVNULL, PIPE, Popen -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from eth_pydantic_types import HexBytes from eth_utils import to_hex @@ -71,12 +71,12 @@ class BlockAPI(BaseInterfaceModel): The number of transactions in the block. """ - hash: Optional[Any] = None # NOTE: pending block does not have a hash + hash: Any | None = None # NOTE: pending block does not have a hash """ The block hash identifier. """ - number: Optional[HexInt] = None # NOTE: pending block does not have a number + number: HexInt | None = None # NOTE: pending block does not have a number """ The block number identifier. """ @@ -94,7 +94,7 @@ class BlockAPI(BaseInterfaceModel): NOTE: The pending block uses the current timestamp. """ - _size: Optional[HexInt] = None + _size: HexInt | None = None @log_instead_of_fail(default="") def __repr__(self) -> str: @@ -189,7 +189,7 @@ class CallResult(BaseModel, ManagerAccessMixin): NOTE: Currently, you only get this for reverted calls from ``ProviderAPI``. """ - revert: Optional[ContractLogicError] = None + revert: ContractLogicError | None = None """ The revert, if the call reverted. """ @@ -220,7 +220,7 @@ def reverted(self) -> bool: return self.revert is not None @property - def revert_message(self) -> Optional[str]: + def revert_message(self) -> str | None: """ The revert message, if the call reverted. """ @@ -247,7 +247,7 @@ class ProviderAPI(BaseInterfaceModel): """ # TODO: In 0.9, make not optional. - NAME: ClassVar[Optional[str]] = None + NAME: ClassVar[str | None] = None # TODO: Remove in 0.9 and have NAME be defined at the class-level (in plugins). name: str @@ -312,21 +312,21 @@ def disconnect(self): """ @property - def ipc_path(self) -> Optional[Path]: + def ipc_path(self) -> Path | None: """ Return the IPC path for the provider, if supported. """ return None @property - def http_uri(self) -> Optional[str]: + def http_uri(self) -> str | None: """ Return the raw HTTP/HTTPS URI to connect to this provider, if supported. """ return None @property - def ws_uri(self) -> Optional[str]: + def ws_uri(self) -> str | None: """ Return the raw WS/WSS URI to connect to this provider, if supported. """ @@ -342,7 +342,7 @@ def settings(self) -> "PluginConfig": return CustomConfig.model_validate(data) @property - def connection_id(self) -> Optional[str]: + def connection_id(self) -> str | None: """ A connection ID to uniquely identify and manage multiple connections to providers, especially when working with multiple @@ -382,7 +382,7 @@ def chain_id(self) -> int: """ @abstractmethod - def get_balance(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: + def get_balance(self, address: "AddressType", block_id: "BlockID | None" = None) -> int: """ Get the balance of an account. @@ -397,14 +397,14 @@ def get_balance(self, address: "AddressType", block_id: Optional["BlockID"] = No @abstractmethod def get_code( - self, address: "AddressType", block_id: Optional["BlockID"] = None + self, address: "AddressType", block_id: "BlockID | None" = None ) -> "ContractCode": """ Get the bytes a contract. Args: address (:class:`~ape.types.address.AddressType`): The address of the contract. - block_id (Optional[:class:`~ape.types.BlockID`]): The block ID + block_id ("BlockID | None"): The block ID for checking a previous account nonce. Returns: @@ -427,7 +427,7 @@ def network_choice(self) -> str: return f"{self.network.choice}:{self.name}" @abstractmethod - def make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any: + def make_request(self, rpc: str, parameters: Iterable | None = None) -> Any: """ Make a raw RPC request to the provider. Advanced features such as tracing may utilize this to by-pass unnecessary @@ -460,7 +460,7 @@ def get_storage_at(self, *args, **kwargs) -> HexBytes: @raises_not_implemented def get_storage( # type: ignore[empty-body] - self, address: "AddressType", slot: int, block_id: Optional["BlockID"] = None + self, address: "AddressType", slot: int, block_id: "BlockID | None" = None ) -> HexBytes: """ Gets the raw value of a storage slot of a contract. @@ -468,7 +468,7 @@ def get_storage( # type: ignore[empty-body] Args: address (AddressType): The address of the contract. slot (int): Storage slot to read the value of. - block_id (Optional[:class:`~ape.types.BlockID`]): The block ID + block_id ("BlockID | None"): The block ID for checking a previous storage value. Returns: @@ -476,13 +476,13 @@ def get_storage( # type: ignore[empty-body] """ @abstractmethod - def get_nonce(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: + def get_nonce(self, address: "AddressType", block_id: "BlockID | None" = None) -> int: """ Get the number of times an account has transacted. Args: address (AddressType): The address of the account. - block_id (Optional[:class:`~ape.types.BlockID`]): The block ID + block_id ("BlockID | None"): The block ID for checking a previous account nonce. Returns: @@ -490,14 +490,14 @@ def get_nonce(self, address: "AddressType", block_id: Optional["BlockID"] = None """ @abstractmethod - def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional["BlockID"] = None) -> int: + def estimate_gas_cost(self, txn: TransactionAPI, block_id: "BlockID | None" = None) -> int: """ Estimate the cost of gas for a transaction. Args: txn (:class:`~ape.api.transactions.TransactionAPI`): The transaction to estimate the gas for. - block_id (Optional[:class:`~ape.types.BlockID`]): The block ID + block_id ("BlockID | None"): The block ID to use when estimating the transaction. Useful for checking a past estimation cost of a transaction. @@ -585,8 +585,8 @@ def get_block(self, block_id: "BlockID") -> BlockAPI: def send_call( self, txn: TransactionAPI, - block_id: Optional["BlockID"] = None, - state: Optional[dict] = None, + block_id: "BlockID | None" = None, + state: dict | None = None, **kwargs, ) -> HexBytes: # Return value of function """ @@ -595,7 +595,7 @@ def send_call( Args: txn: :class:`~ape.api.transactions.TransactionAPI` - block_id (Optional[:class:`~ape.types.BlockID`]): The block ID + block_id ("BlockID | None"): The block ID to use to send a call at a historical point of a contract. Useful for checking a past estimation cost of a transaction. state (Optional[dict]): Modify the state of the blockchain @@ -831,7 +831,7 @@ def relock_account(self, address: "AddressType"): @raises_not_implemented def get_transaction_trace( # type: ignore[empty-body] - self, txn_hash: Union[HexBytes, str] + self, txn_hash: HexBytes | str ) -> "TraceAPI": """ Provide a detailed description of opcodes. @@ -847,9 +847,9 @@ def get_transaction_trace( # type: ignore[empty-body] @raises_not_implemented def poll_blocks( # type: ignore[empty-body] self, - stop_block: Optional[int] = None, - required_confirmations: Optional[int] = None, - new_block_timeout: Optional[int] = None, + stop_block: int | None = None, + required_confirmations: int | None = None, + new_block_timeout: int | None = None, ) -> Iterator[BlockAPI]: """ Poll new blocks. @@ -878,12 +878,12 @@ def poll_blocks( # type: ignore[empty-body] @raises_not_implemented def poll_logs( # type: ignore[empty-body] self, - stop_block: Optional[int] = None, - address: Optional["AddressType"] = None, - topics: Optional[list[Union[str, list[str]]]] = None, - required_confirmations: Optional[int] = None, - new_block_timeout: Optional[int] = None, - events: Optional[list["EventABI"]] = None, + stop_block: int | None = None, + address: "AddressType | None" = None, + topics: list[str | list[str]] | None = None, + required_confirmations: int | None = None, + new_block_timeout: int | None = None, + events: list["EventABI"] | None = None, ) -> Iterator["ContractLog"]: """ Poll new blocks. Optionally set a start block to include historical blocks. @@ -1054,12 +1054,12 @@ class SubprocessProvider(ProviderAPI): PROCESS_WAIT_TIMEOUT: int = 15 background: bool = False - process: Optional[Popen] = None + process: Popen | None = None allow_start: bool = True is_stopping: bool = False - stdout_queue: Optional[JoinableQueue] = None - stderr_queue: Optional[JoinableQueue] = None + stdout_queue: JoinableQueue | None = None + stderr_queue: JoinableQueue | None = None @property @abstractmethod @@ -1097,7 +1097,7 @@ def _stderr_logger(self) -> Logger: return self._get_process_output_logger("stderr", self.stderr_logs_path) @property - def connection_id(self) -> Optional[str]: + def connection_id(self) -> str | None: cmd_id = ",".join(self.build_command()) return f"{self.network_choice}:{cmd_id}" diff --git a/src/ape/api/query.py b/src/ape/api/query.py index 5198bd5324..ae63678a60 100644 --- a/src/ape/api/query.py +++ b/src/ape/api/query.py @@ -1,7 +1,7 @@ from abc import abstractmethod from collections.abc import Iterator, Sequence from functools import cache, cached_property -from typing import Any, Optional, Union +from typing import Any from ethpm_types.abi import EventABI, MethodABI from pydantic import NonNegativeInt, PositiveInt, model_validator @@ -11,14 +11,9 @@ from ape.types.address import AddressType from ape.utils.basemodel import BaseInterface, BaseInterfaceModel, BaseModel -QueryType = Union[ - "BlockQuery", - "BlockTransactionQuery", - "AccountTransactionQuery", - "ContractCreationQuery", - "ContractEventQuery", - "ContractMethodQuery", -] +QueryType = ( + "BlockQuery | BlockTransactionQuery | AccountTransactionQuery | ContractCreationQuery | ContractEventQuery | ContractMethodQuery" +) @cache @@ -195,7 +190,7 @@ class ContractCreation(BaseModel, BaseInterface): The contract deployer address. """ - factory: Optional[AddressType] = None + factory: AddressType | None = None """ The address of the factory contract, if there is one and it is known (depends on the query provider!). @@ -234,9 +229,9 @@ class ContractEventQuery(_BaseBlockQuery): logs emitted by ``contract`` between ``start_block`` and ``stop_block``. """ - contract: Union[list[AddressType], AddressType] + contract: list[AddressType] | AddressType event: EventABI - search_topics: Optional[dict[str, Any]] = None + search_topics: dict[str, Any] | None = None class ContractMethodQuery(_BaseBlockQuery): @@ -252,7 +247,7 @@ class ContractMethodQuery(_BaseBlockQuery): class QueryAPI(BaseInterface): @abstractmethod - def estimate_query(self, query: QueryType) -> Optional[int]: + def estimate_query(self, query: QueryType) -> int | None: """ Estimation of time needed to complete the query. The estimation is returned as an int representing milliseconds. A value of None indicates that the diff --git a/src/ape/api/trace.py b/src/ape/api/trace.py index 009f65e3a8..c586b24a94 100644 --- a/src/ape/api/trace.py +++ b/src/ape/api/trace.py @@ -1,7 +1,7 @@ import sys from abc import abstractmethod from collections.abc import Iterator, Sequence -from typing import IO, TYPE_CHECKING, Any, Optional +from typing import IO, TYPE_CHECKING, Any from ape.utils.basemodel import BaseInterfaceModel @@ -23,7 +23,7 @@ def show(self, verbose: bool = False, file: IO[str] = sys.stdout): @abstractmethod def get_gas_report( - self, exclude: Optional[Sequence["ContractFunctionPath"]] = None + self, exclude: Sequence["ContractFunctionPath"] | None = None ) -> "GasReport": """ Get the gas report. @@ -44,7 +44,7 @@ def return_value(self) -> Any: @property @abstractmethod - def revert_message(self) -> Optional[str]: + def revert_message(self) -> str | None: """ The revert message deduced from the trace. """ diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index eecbc3d6f8..f6f82d58da 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -4,7 +4,7 @@ from collections.abc import Iterator from datetime import datetime as datetime_type from functools import cached_property -from typing import IO, TYPE_CHECKING, Any, NoReturn, Optional, Union +from typing import IO, TYPE_CHECKING, Any, NoReturn from eth_pydantic_types import HexBytes, HexStr from eth_utils import humanize_hexstr, is_hex, to_hex, to_int @@ -48,21 +48,21 @@ class TransactionAPI(BaseInterfaceModel): such as typed-transactions from `EIP-1559 `__. """ - chain_id: Optional[HexInt] = Field(default=0, alias="chainId") - receiver: Optional[AddressType] = Field(default=None, alias="to") - sender: Optional[AddressType] = Field(default=None, alias="from") - gas_limit: Optional[HexInt] = Field(default=None, alias="gas") - nonce: Optional[HexInt] = None # NOTE: `Optional` only to denote using default behavior + chain_id: HexInt | None = Field(default=0, alias="chainId") + receiver: AddressType | None = Field(default=None, alias="to") + sender: AddressType | None = Field(default=None, alias="from") + gas_limit: HexInt | None = Field(default=None, alias="gas") + nonce: HexInt | None = None # NOTE: `None` only to denote using default behavior value: HexInt = 0 data: HexBytes = HexBytes("") type: HexInt - max_fee: Optional[HexInt] = None - max_priority_fee: Optional[HexInt] = None + max_fee: HexInt | None = None + max_priority_fee: HexInt | None = None # If left as None, will get set to the network's default required confirmations. - required_confirmations: Optional[HexInt] = Field(default=None, exclude=True) + required_confirmations: HexInt | None = Field(default=None, exclude=True) - signature: Optional[TransactionSignature] = Field(default=None, exclude=True) + signature: TransactionSignature | None = Field(default=None, exclude=True) model_config = ConfigDict(populate_by_name=True) @@ -99,7 +99,7 @@ def validate_gas_limit(cls, value): return value @property - def gas(self) -> Optional[int]: + def gas(self) -> int | None: """ Alias for ``.gas_limit``. """ @@ -149,7 +149,7 @@ def hash(self) -> HexBytes: return self.txn_hash @property - def receipt(self) -> Optional["ReceiptAPI"]: + def receipt(self) -> "ReceiptAPI | None": """ This transaction's associated published receipt, if it exists. """ @@ -196,7 +196,7 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.to_string() - def to_string(self, calldata_repr: Optional["CalldataRepr"] = None) -> str: + def to_string(self, calldata_repr: "CalldataRepr | None" = None) -> str: """ Get the stringified representation of the transaction. @@ -233,7 +233,7 @@ def _get_calldata_repr_str(self, calldata_repr: "CalldataRepr") -> str: else calldata.to_0x_hex() ) - def _decoded_call(self) -> Optional[str]: + def _decoded_call(self) -> str | None: if not self.receiver: return "constructor()" @@ -326,14 +326,14 @@ class ReceiptAPI(ExtraAttributesMixin, BaseInterfaceModel): a :class:`ape.contracts.base.ContractInstance`. """ - contract_address: Optional[AddressType] = None + contract_address: AddressType | None = None block_number: HexInt gas_used: HexInt logs: list[dict] = [] status: HexInt txn_hash: HexStr transaction: TransactionAPI - _error: Optional[TransactionError] = None + _error: TransactionError | None = None @log_instead_of_fail(default="") def __repr__(self) -> str: @@ -375,7 +375,7 @@ def debug_logs_lines(self) -> list[str]: return [" ".join(map(str, ln)) for ln in self.debug_logs_typed] @property - def error(self) -> Optional[TransactionError]: + def error(self) -> TransactionError | None: return self._error @error.setter @@ -432,7 +432,7 @@ def trace(self) -> "TraceAPI": return self.provider.get_transaction_trace(self.txn_hash) @property - def _explorer(self) -> Optional["ExplorerAPI"]: + def _explorer(self) -> "ExplorerAPI | None": return self.provider.network.explorer @property @@ -471,9 +471,9 @@ def events(self) -> "ContractLogContainer": @abstractmethod def decode_logs( self, - abi: Optional[ - Union[list[Union["EventABI", "ContractEvent"]], Union["EventABI", "ContractEvent"]] - ] = None, + abi: ( + "list[EventABI | ContractEvent] | EventABI | ContractEvent | None" + ) = None, ) -> "ContractLogContainer": """ Decode the logs on the receipt. @@ -485,7 +485,7 @@ def decode_logs( list[:class:`~ape.types.ContractLog`] """ - def raise_for_status(self) -> Optional[NoReturn]: + def raise_for_status(self) -> NoReturn | None: """ Handle provider-specific errors regarding a non-successful :class:`~api.providers.TransactionStatusEnum`. @@ -562,7 +562,7 @@ def _await_confirmations(self): time.sleep(time_to_sleep) @property - def method_called(self) -> Optional["MethodABI"]: + def method_called(self) -> "MethodABI | None": """ The method ABI of the method called to produce this receipt. """ diff --git a/src/ape/cli/arguments.py b/src/ape/cli/arguments.py index 2934af0606..b097470449 100644 --- a/src/ape/cli/arguments.py +++ b/src/ape/cli/arguments.py @@ -1,6 +1,6 @@ from collections.abc import Iterable from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import click from click import BadArgumentUsage @@ -48,7 +48,7 @@ class _ContractPaths: Helper callback class for handling CLI-given contract paths. """ - def __init__(self, value, project: Optional["ProjectManager"] = None): + def __init__(self, value, project: "ProjectManager | None" = None): from ape.utils.basemodel import ManagerAccessMixin self.value = value @@ -111,10 +111,10 @@ def exclude_patterns(self) -> set[str]: return access.config_manager.get_config("compile").exclude or set() - def do_exclude(self, path: Union[Path, str]) -> bool: + def do_exclude(self, path: Path | str) -> bool: return self.project.sources.is_excluded(path) - def compiler_is_unknown(self, path: Union[Path, str]) -> bool: + def compiler_is_unknown(self, path: Path | str) -> bool: from ape.utils.basemodel import ManagerAccessMixin from ape.utils.os import get_full_extension @@ -127,7 +127,7 @@ def compiler_is_unknown(self, path: Union[Path, str]) -> bool: return bool(unknown_compiler) - def lookup(self, path_iter: Iterable, path_set: Optional[set] = None) -> set[Path]: + def lookup(self, path_iter: Iterable, path_set: set | None = None) -> set[Path]: path_set = path_set or set() given_paths = [p for p in path_iter] # Handle iterators w/o losing it. diff --git a/src/ape/cli/choices.py b/src/ape/cli/choices.py index bc8a2e8e9c..ebd85eb632 100644 --- a/src/ape/cli/choices.py +++ b/src/ape/cli/choices.py @@ -3,7 +3,7 @@ from enum import Enum from functools import cache, cached_property from importlib import import_module -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import click from click import Choice, Context, Parameter @@ -19,9 +19,9 @@ from ape.api.accounts import AccountAPI from ape.api.providers import ProviderAPI -_ACCOUNT_TYPE_FILTER = Union[ - None, Sequence["AccountAPI"], type["AccountAPI"], Callable[["AccountAPI"], bool] -] +_ACCOUNT_TYPE_FILTER = ( + None | Sequence["AccountAPI"] | type["AccountAPI"] | Callable[["AccountAPI"], bool] +) def _get_accounts(key: _ACCOUNT_TYPE_FILTER) -> list["AccountAPI"]: @@ -105,7 +105,7 @@ def cmd(choice): click.echo(f"__expected_{choice}") """ - def __init__(self, choices: Sequence[str], name: Optional[str] = None): + def __init__(self, choices: Sequence[str], name: str | None = None): self.choices = choices # Since we purposely skip the super() constructor, we need to make # sure the class still has a name. @@ -125,8 +125,8 @@ def print_choices(self): click.echo() def convert( - self, value: Any, param: Optional[Parameter], ctx: Optional[Context] - ) -> Optional[Any]: + self, value: Any, param: Parameter | None, ctx: Context | None + ) -> Any | None: # noinspection PyBroadException try: choice_index = int(value) @@ -156,7 +156,7 @@ def select(self) -> str: def select_account( - prompt_message: Optional[str] = None, key: _ACCOUNT_TYPE_FILTER = None + prompt_message: str | None = None, key: _ACCOUNT_TYPE_FILTER = None ) -> "AccountAPI": """ Prompt the user to pick from their accounts and return that account. @@ -165,8 +165,8 @@ def select_account( :meth:`~ape.cli.options.account_option`. Args: - prompt_message (Optional[str]): Customize the prompt message. - key (Union[None, type[AccountAPI], Callable[[AccountAPI], bool]]): + prompt_message (str | None): Customize the prompt message. + key (None | type[AccountAPI] | Callable[[AccountAPI], bool]): If given, the user may only select a matching account. You can provide a list of accounts, an account class type, or a callable for filtering the accounts. @@ -194,7 +194,7 @@ class AccountAliasPromptChoice(PromptChoice): def __init__( self, key: _ACCOUNT_TYPE_FILTER = None, - prompt_message: Optional[str] = None, + prompt_message: str | None = None, name: str = "account", ): # NOTE: we purposely skip the constructor of `PromptChoice` @@ -209,8 +209,8 @@ def choices(self) -> Sequence[str]: # type: ignore[override] return _LazySequence(self._choices_iterator) def convert( - self, value: Any, param: Optional[Parameter], ctx: Optional[Context] - ) -> Optional["AccountAPI"]: + self, value: Any, param: Parameter | None, ctx: Context | None + ) -> "AccountAPI | None": if value is None: return None @@ -295,7 +295,7 @@ def fail_from_invalid_choice(self, param): return self.fail("Invalid choice. Type the number or the alias.", param=param) -_NETWORK_FILTER = Optional[Union[list[str], str]] +_NETWORK_FILTER = list[str] | str | None _NONE_NETWORK = "__NONE_NETWORK__" @@ -363,8 +363,8 @@ def __init__( ecosystem: _NETWORK_FILTER = None, network: _NETWORK_FILTER = None, provider: _NETWORK_FILTER = None, - base_type: Optional[Union[type, str]] = None, - callback: Optional[Callable] = None, + base_type: type | str | None = None, + callback: Callable | None = None, ): self._base_type = base_type self.callback = callback @@ -375,7 +375,7 @@ def __init__( # NOTE: Purposely avoid super().init for performance reasons. @property - def base_type(self) -> Union[type["ProviderAPI"], str]: + def base_type(self) -> type["ProviderAPI"] | str: if self._base_type is not None: return self._base_type @@ -395,7 +395,7 @@ def choices(self) -> Sequence[Any]: # type: ignore[override] def get_metavar(self, param): return "[ecosystem-name][:[network-name][:[provider-name]]]" - def convert(self, value: Any, param: Optional[Parameter], ctx: Optional[Context]) -> Any: + def convert(self, value: Any, param: Parameter | None, ctx: Context | None) -> Any: if not value or value.lower() in ("none", "null"): return self.callback(ctx, param, _NONE_NETWORK) if self.callback else _NONE_NETWORK @@ -426,7 +426,7 @@ class OutputFormat(Enum): """A standard .yaml format of the data.""" -def output_format_choice(options: Optional[list[OutputFormat]] = None) -> Choice: +def output_format_choice(options: list[OutputFormat] | None = None) -> Choice: """ Returns a ``click.Choice()`` type for the given options. diff --git a/src/ape/cli/commands.py b/src/ape/cli/commands.py index 3cf62b7d3f..df199125a5 100644 --- a/src/ape/cli/commands.py +++ b/src/ape/cli/commands.py @@ -1,5 +1,5 @@ import inspect -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import click @@ -13,7 +13,7 @@ from ape.api.providers import ProviderAPI -def get_param_from_ctx(ctx: "Context", param: str) -> Optional[Any]: +def get_param_from_ctx(ctx: "Context", param: str) -> Any | None: if value := ctx.params.get(param): return value @@ -24,7 +24,7 @@ def get_param_from_ctx(ctx: "Context", param: str) -> Optional[Any]: return None -def parse_network(ctx: "Context") -> Optional["ProviderContextManager"]: +def parse_network(ctx: "Context") -> "ProviderContextManager | None": from ape.api.providers import ProviderAPI from ape.utils.basemodel import ManagerAccessMixin as access @@ -72,7 +72,7 @@ def __init__(self, *args, **kwargs): def parse_args(self, ctx: "Context", args: list[str]) -> list[str]: arguments = args # Renamed for better pdb support. - base_type: Optional[type] = None if self._use_cls_types else str + base_type: type | None = None if self._use_cls_types else str if existing_option := next( iter( x @@ -112,7 +112,7 @@ def invoke(self, ctx: "Context") -> Any: else: return self._invoke(ctx) - def _invoke(self, ctx: "Context", provider: Optional["ProviderAPI"] = None): + def _invoke(self, ctx: "Context", provider: "ProviderAPI | None" = None): # Will be put back with correct value if needed. # Else, causes issues. ctx.params.pop("network", None) diff --git a/src/ape/cli/options.py b/src/ape/cli/options.py index 3e0d24da73..2985255cab 100644 --- a/src/ape/cli/options.py +++ b/src/ape/cli/options.py @@ -3,7 +3,7 @@ from collections.abc import Callable from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union +from typing import TYPE_CHECKING, Any, NoReturn import click from click import Option @@ -51,7 +51,7 @@ def __getattr__(self, item: str) -> Any: return getattr(ManagerAccessMixin, item) @staticmethod - def abort(msg: str, base_error: Optional[Exception] = None) -> NoReturn: + def abort(msg: str, base_error: Exception | None = None) -> NoReturn: """ End execution of the current command invocation. @@ -69,9 +69,9 @@ def abort(msg: str, base_error: Optional[Exception] = None) -> NoReturn: def verbosity_option( - cli_logger: Optional[ApeLogger] = None, - default: Optional[Union[str, int, LogLevel]] = None, - callback: Optional[Callable] = None, + cli_logger: ApeLogger | None = None, + default: str | int | LogLevel | None = None, + callback: Callable | None = None, **kwargs, ) -> Callable: """A decorator that adds a `--verbosity, -v` option to the decorated @@ -97,9 +97,9 @@ def verbosity_option( def _create_verbosity_kwargs( - _logger: Optional[ApeLogger] = None, - default: Optional[Union[str, int, LogLevel]] = None, - callback: Optional[Callable] = None, + _logger: ApeLogger | None = None, + default: str | int | LogLevel | None = None, + callback: Callable | None = None, **kwargs, ) -> dict: default = logger.level if default is None else default @@ -137,7 +137,7 @@ def set_level(ctx, param, value): def ape_cli_context( - default_log_level: Optional[Union[str, int, LogLevel]] = None, + default_log_level: str | int | LogLevel | None = None, obj_type: type = ApeCliContextObject, ) -> Callable: """ @@ -230,10 +230,10 @@ def fn(): def network_option( - default: Optional[Union[str, Callable]] = "auto", - ecosystem: Optional[Union[list[str], str]] = None, - network: Optional[Union[list[str], str]] = None, - provider: Optional[Union[list[str], str]] = None, + default: str | Callable | None = "auto", + ecosystem: list[str] | str | None = None, + network: list[str] | str | None = None, + provider: list[str] | str | None = None, required: bool = False, **kwargs, ) -> Callable: @@ -423,7 +423,7 @@ def skip_confirmation_option(help="") -> Callable: def account_option( *param_decls, account_type: _ACCOUNT_TYPE_FILTER = None, - prompt: Optional[Union[str, bool]] = AccountAliasPromptChoice.DEFAULT_PROMPT, + prompt: str | bool | None = AccountAliasPromptChoice.DEFAULT_PROMPT, ) -> Callable: """ A CLI option that accepts either the account alias or the account number. @@ -446,7 +446,7 @@ def _account_callback(ctx, param, value): ) -def _load_contracts(ctx, param, value) -> Optional[Union["ContractType", list["ContractType"]]]: +def _load_contracts(ctx, param, value) -> "ContractType | list[ContractType] | None": if not value: return None diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index 3c00d49da2..154b096d69 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -4,7 +4,7 @@ from functools import cached_property, partial, singledispatchmethod from itertools import islice from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import click from eth_pydantic_types import HexBytes @@ -352,7 +352,7 @@ def estimate_gas_cost(self, *args, **kwargs) -> int: return self.transact.estimate_gas_cost(*arguments, **kwargs) -def _select_method_abi(abis: list["MethodABI"], args: Union[tuple, list]) -> "MethodABI": +def _select_method_abi(abis: list["MethodABI"], args: tuple | list) -> "MethodABI": args = args or [] selected_abi = None for abi in abis: @@ -501,7 +501,7 @@ class ContractEvent(BaseInterfaceModel): contract: "ContractTypeWrapper" abi: EventABI - _logs: Optional[list[ContractLog]] = None + _logs: list[ContractLog] | None = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -555,7 +555,7 @@ def log_filter(self) -> LogFilter: return LogFilter.from_event(event=self.abi, addresses=addresses, start_block=0) @singledispatchmethod - def __getitem__(self, value) -> Union[ContractLog, list[ContractLog]]: # type: ignore[override] + def __getitem__(self, value) -> ContractLog | list[ContractLog]: # type: ignore[override] raise NotImplementedError(f"Cannot use '{type(value)}' to access logs.") @__getitem__.register @@ -673,9 +673,9 @@ def query( self, *columns: str, start_block: int = 0, - stop_block: Optional[int] = None, + stop_block: int | None = None, step: int = 1, - engine_to_use: Optional[str] = None, + engine_to_use: str | None = None, ) -> "DataFrame": """ Iterate through blocks for log events @@ -685,11 +685,11 @@ def query( return. start_block (int): The first block, by number, to include in the query. Defaults to ``0``. - stop_block (Optional[int]): The last block, by number, to include + stop_block (int | None): The last block, by number, to include in the query. Defaults to the latest block. step (int): The number of blocks to iterate between block numbers. Defaults to ``1``. - engine_to_use (Optional[str]): query engine to use, bypasses query + engine_to_use (str | None): query engine to use, bypasses query engine selection algorithm. Returns: @@ -734,9 +734,9 @@ def query( def range( self, start_or_stop: int, - stop: Optional[int] = None, - search_topics: Optional[dict[str, Any]] = None, - extra_addresses: Optional[list] = None, + stop: int | None = None, + search_topics: dict[str, Any] | None = None, + extra_addresses: list | None = None, ) -> Iterator[ContractLog]: """ Search through the logs for this event using the given filter parameters. @@ -745,11 +745,11 @@ def range( start_or_stop (int): When also given ``stop``, this is the earliest block number in the desired log set. Otherwise, it is the total amount of blocks to get starting from ``0``. - stop (Optional[int]): The latest block number in the + stop (int | None): The latest block number in the desired log set. Defaults to delegating to provider. - search_topics (Optional[dict]): Search topics, such as indexed event inputs, + search_topics (dict | None): Search topics, such as indexed event inputs, to query by. Defaults to getting all events. - extra_addresses (Optional[list[:class:`~ape.types.address.AddressType`]]): + extra_addresses (list | None): Additional contract addresses containing the same event type. Defaults to Additional contract addresses containing the same event type. Defaults to only looking at the contract instance where this event is defined. @@ -851,11 +851,11 @@ def from_receipt(self, receipt: "ReceiptAPI") -> list[ContractLog]: def poll_logs( self, - start_block: Optional[int] = None, - stop_block: Optional[int] = None, - required_confirmations: Optional[int] = None, - new_block_timeout: Optional[int] = None, - search_topics: Optional[dict[str, Any]] = None, + start_block: int | None = None, + stop_block: int | None = None, + required_confirmations: int | None = None, + new_block_timeout: int | None = None, + search_topics: dict[str, Any] | None = None, **search_topic_kwargs: dict[str, Any], ) -> Iterator[ContractLog]: """ @@ -870,17 +870,17 @@ def poll_logs( print(f"New event log found: block_number={new_log.block_number}") Args: - start_block (Optional[int]): The block number to start with. Defaults to the pending + start_block (int | None): The block number to start with. Defaults to the pending block number. - stop_block (Optional[int]): Optionally set a future block number to stop at. + stop_block (int | None): Optionally set a future block number to stop at. Defaults to never-ending. - required_confirmations (Optional[int]): The amount of confirmations to wait + required_confirmations (int | None): The amount of confirmations to wait before yielding the block. The more confirmations, the less likely a reorg will occur. Defaults to the network's configured required confirmations. - new_block_timeout (Optional[int]): The amount of time to wait for a new block before + new_block_timeout (int | None): The amount of time to wait for a new block before quitting. Defaults to 10 seconds for local networks or ``50 * block_time`` for live networks. - search_topics (Optional[dict[str, Any]]): A dictionary of search topics to use when + search_topics (dict[str, Any] | None): A dictionary of search topics to use when constructing a polling filter. Overrides the value of `**search_topics_kwargs`. search_topics_kwargs: Search topics to use when constructing a polling filter. Allows easily specifying topic filters using kwarg syntax but can be used w/ `search_topics` @@ -952,7 +952,7 @@ def __call__(self, *args, **kwargs) -> MockContractLog: class ContractTypeWrapper(ManagerAccessMixin): contract_type: "ContractType" - base_path: Optional[Path] = None + base_path: Path | None = None @property def selector_identifiers(self) -> dict[str, str]: @@ -971,7 +971,7 @@ def identifier_lookup(self) -> dict[str, "ABI_W_SELECTOR_T"]: return self.contract_type.identifier_lookup @property - def source_path(self) -> Optional[Path]: + def source_path(self) -> Path | None: """ Returns the path to the local contract if determined that this container belongs to the active project by cross-checking source_id. @@ -1078,7 +1078,7 @@ def __init__( self, address: "AddressType", contract_type: "ContractType", - txn_hash: Optional[Union[str, HexBytes]] = None, + txn_hash: str | HexBytes | None = None, ) -> None: super().__init__() self._address = address @@ -1137,7 +1137,7 @@ def from_receipt( return instance @property - def creation_metadata(self) -> Optional[ContractCreation]: + def creation_metadata(self) -> ContractCreation | None: """ Contract creation details: txn_hash, block, deployer, factory, receipt. See :class:`~ape.api.query.ContractCreation` for more details. @@ -1558,9 +1558,9 @@ def deployments(self): def at( self, address: "AddressType", - txn_hash: Optional[Union[str, HexBytes]] = None, + txn_hash: str | HexBytes | None = None, fetch_from_explorer: bool = True, - proxy_info: Optional["ProxyInfoAPI"] = None, + proxy_info: "ProxyInfoAPI | None" = None, detect_proxy: bool = True, ) -> ContractInstance: """ @@ -1577,7 +1577,7 @@ def at( **NOTE**: Things will not work as expected if the contract is not actually deployed to this address or if the contract at the given address has a different ABI than :attr:`~ape.contracts.ContractContainer.contract_type`. - txn_hash (Union[str, HexBytes]): The hash of the transaction that deployed the + txn_hash (str | HexBytes): The hash of the transaction that deployed the contract, if available. Defaults to ``None``. fetch_from_explorer (bool): Set to ``False`` to avoid fetching from an explorer. proxy_info (:class:`~ape.api.networks.ProxyInfoAPI` | None): Proxy info object to set @@ -1743,7 +1743,7 @@ def __repr__(self) -> str: return f"<{self.name}>" @only_raise_attribute_error - def __getattr__(self, item: str) -> Union[ContractContainer, "ContractNamespace"]: + def __getattr__(self, item: str) -> "ContractContainer | ContractNamespace": """ Access the next contract container or namespace. @@ -1751,7 +1751,7 @@ def __getattr__(self, item: str) -> Union[ContractContainer, "ContractNamespace" item (str): The name of the next node. Returns: - Union[:class:`~ape.contracts.base.ContractContainer`, + :class:`~ape.contracts.base.ContractContainer` | :class:`~ape.contracts.base.ContractNamespace`] """ _assert_not_ipython_check(item) diff --git a/src/ape/exceptions.py b/src/ape/exceptions.py index 9053f015dc..bccfe8bcbe 100644 --- a/src/ape/exceptions.py +++ b/src/ape/exceptions.py @@ -8,7 +8,7 @@ from inspect import getframeinfo, stack from pathlib import Path from types import CodeType, TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable import click from rich import print as rich_print @@ -30,7 +30,7 @@ from ape.types.vm import BlockID, SnapshotID -FailedTxn = Union["TransactionAPI", "ReceiptAPI"] +FailedTxn = "TransactionAPI | ReceiptAPI" class ApeException(Exception): @@ -74,7 +74,7 @@ class SignatureError(AccountsError): Raised when there are issues with signing. """ - def __init__(self, message: str, transaction: Optional["TransactionAPI"] = None): + def __init__(self, message: str, transaction: "TransactionAPI | None" = None): self.transaction = transaction super().__init__(message) @@ -112,7 +112,7 @@ class ArgumentsLengthError(ContractDataError): def __init__( self, arguments_length: int, - inputs: Union["MethodABI", "ConstructorABI", int, list, None] = None, + inputs: "MethodABI | ConstructorABI | int | list | None" = None, **kwargs, ): prefix = ( @@ -123,7 +123,7 @@ def __init__( super().__init__(f"{prefix}.") return - inputs_ls: list[Union[MethodABI, ConstructorABI, int]] = ( + inputs_ls: list[MethodABI | ConstructorABI | int] = ( inputs if isinstance(inputs, list) else [inputs] ) if not inputs_ls: @@ -158,7 +158,7 @@ class DecodingError(ContractDataError): a contract call, transaction, or event. """ - def __init__(self, message: Optional[str] = None): + def __init__(self, message: str | None = None): message = message or "Output corrupted." super().__init__(message) @@ -169,10 +169,10 @@ class MethodNonPayableError(ContractDataError): """ -_TRACE_ARG = Optional[Union["TraceAPI", Callable[[], Optional["TraceAPI"]]]] -_SOURCE_TRACEBACK_ARG = Optional[ - Union["SourceTraceback", Callable[[], Optional["SourceTraceback"]]] -] +_TRACE_ARG = "TraceAPI | Callable[[], TraceAPI | None] | None" +_SOURCE_TRACEBACK_ARG = ( + "SourceTraceback | Callable[[], SourceTraceback | None] | None" +) class TransactionError(ApeException): @@ -184,14 +184,14 @@ class TransactionError(ApeException): def __init__( self, - message: Optional[str] = None, - base_err: Optional[Exception] = None, - code: Optional[int] = None, - txn: Optional[FailedTxn] = None, - trace: _TRACE_ARG = None, - contract_address: Optional["AddressType"] = None, - source_traceback: _SOURCE_TRACEBACK_ARG = None, - project: Optional["ProjectManager"] = None, + message: str | None = None, + base_err: Exception | None = None, + code: int | None = None, + txn: "FailedTxn | None" = None, + trace: "_TRACE_ARG" = None, + contract_address: "AddressType | None" = None, + source_traceback: "_SOURCE_TRACEBACK_ARG" = None, + project: "ProjectManager | None" = None, set_ape_traceback: bool = False, # Overridden in ContractLogicError ): message = message or (str(base_err) if base_err else self.DEFAULT_MESSAGE) @@ -213,7 +213,7 @@ def __init__( self.with_ape_traceback() @property - def address(self) -> Optional["AddressType"]: + def address(self) -> "AddressType | None": if addr := self.contract_address: return addr @@ -226,7 +226,7 @@ def address(self) -> Optional["AddressType"]: return receiver @cached_property - def contract_type(self) -> Optional["ContractType"]: + def contract_type(self) -> "ContractType | None": if not (address := self.address): # Contract address not found. return None @@ -240,7 +240,7 @@ def contract_type(self) -> Optional["ContractType"]: return None @property - def trace(self) -> Optional["TraceAPI"]: + def trace(self) -> "TraceAPI | None": tr = self._trace if callable(tr): result = tr() @@ -254,9 +254,9 @@ def trace(self, value): self._trace = value @property - def source_traceback(self) -> Optional["SourceTraceback"]: + def source_traceback(self) -> "SourceTraceback | None": tb = self._source_traceback - result: Optional[SourceTraceback] + result: SourceTraceback | None if not self._attempted_source_traceback and tb is None and self.txn is not None: result = _get_ape_traceback_from_tx(self.txn) # Prevent re-trying. @@ -273,7 +273,7 @@ def source_traceback(self) -> Optional["SourceTraceback"]: def source_traceback(self, value): self._source_traceback = value - def _get_ape_traceback(self) -> Optional[TracebackType]: + def _get_ape_traceback(self) -> TracebackType | None: if src_tb := self.source_traceback: # Create a custom Pythonic traceback using lines from the sources # found from analyzing the trace of the transaction. @@ -300,13 +300,13 @@ class ContractLogicError(VirtualMachineError): def __init__( self, - revert_message: Optional[str] = None, - txn: Optional[FailedTxn] = None, - trace: _TRACE_ARG = None, - contract_address: Optional["AddressType"] = None, - source_traceback: _SOURCE_TRACEBACK_ARG = None, - base_err: Optional[Exception] = None, - project: Optional["ProjectManager"] = None, + revert_message: str | None = None, + txn: "FailedTxn | None" = None, + trace: "_TRACE_ARG" = None, + contract_address: "AddressType | None" = None, + source_traceback: "_SOURCE_TRACEBACK_ARG" = None, + base_err: Exception | None = None, + project: "ProjectManager | None" = None, set_ape_traceback: bool = True, # Overridden default. ): self.txn = txn @@ -339,7 +339,7 @@ def revert_message(self, value): self.args = tuple([value, *args[1:]]) @property - def dev_message(self) -> Optional[str]: + def dev_message(self) -> str | None: """ The dev-string message of the exception. @@ -368,9 +368,9 @@ class OutOfGasError(VirtualMachineError): def __init__( self, - code: Optional[int] = None, - txn: Optional[FailedTxn] = None, - base_err: Optional[Exception] = None, + code: int | None = None, + txn: "FailedTxn | None" = None, + base_err: Exception | None = None, set_ape_traceback: bool = False, ): super().__init__( @@ -393,7 +393,7 @@ class EcosystemNotFoundError(NetworkError): Raised when the ecosystem with the given name was not found. """ - def __init__(self, ecosystem: str, options: Optional[Collection[str]] = None): + def __init__(self, ecosystem: str, options: Collection[str] | None = None): self.ecosystem = ecosystem self.options = options message = f"No ecosystem named '{ecosystem}'." @@ -417,8 +417,8 @@ class NetworkNotFoundError(NetworkError): def __init__( self, network: str, - ecosystem: Optional[str] = None, - options: Optional[Collection[str]] = None, + ecosystem: str | None = None, + options: Collection[str] | None = None, ): self.network = network options = options or [] @@ -458,9 +458,9 @@ class ProviderNotFoundError(NetworkError): def __init__( self, provider: str, - network: Optional[str] = None, - ecosystem: Optional[str] = None, - options: Optional[Collection[str]] = None, + network: str | None = None, + ecosystem: str | None = None, + options: Collection[str] | None = None, ): self.provider = provider self.network = network @@ -501,7 +501,7 @@ class ApeAttributeError(ProjectError, AttributeError): Raised when trying to access items via ``.`` access. """ - def __init__(self, msg: str, base_err: Optional[Exception] = None): + def __init__(self, msg: str, base_err: Exception | None = None): self.base_err = base_err super().__init__(msg) @@ -532,7 +532,7 @@ class BlockNotFoundError(ProviderError): Raised when unable to find a block. """ - def __init__(self, block_id: "BlockID", reason: Optional[str] = None): + def __init__(self, block_id: "BlockID", reason: str | None = None): if isinstance(block_id, bytes): block_id_str = block_id.hex() if not block_id_str.startswith("0x"): @@ -556,7 +556,7 @@ class TransactionNotFoundError(ProviderError): Raised when unable to find a transaction. """ - def __init__(self, transaction_hash: Optional[str] = None, error_message: Optional[str] = None): + def __init__(self, transaction_hash: str | None = None, error_message: str | None = None): message = ( f"Transaction '{transaction_hash}' not found." if transaction_hash @@ -681,9 +681,9 @@ class SubprocessTimeoutError(SubprocessError): def __init__( self, provider: "SubprocessProvider", - message: Optional[str] = None, - seconds: Optional[int] = None, - exception: Optional[Exception] = None, + message: str | None = None, + seconds: int | None = None, + exception: Exception | None = None, *args, **kwargs, ): @@ -691,8 +691,8 @@ def __init__( self._message = message or "Timed out waiting for process." self._seconds = seconds self._exception = exception - self._start_time: Optional[float] = None - self._is_running: Optional[bool] = None + self._start_time: float | None = None + self._is_running: bool | None = None def __enter__(self): self.start() @@ -756,8 +756,8 @@ class RPCTimeoutError(SubprocessTimeoutError): def __init__( self, provider: "SubprocessProvider", - seconds: Optional[int] = None, - exception: Optional[Exception] = None, + seconds: int | None = None, + exception: Exception | None = None, *args, **kwargs, ): @@ -786,7 +786,7 @@ class PluginVersionError(PluginInstallError): """ def __init__( - self, operation: str, reason: Optional[str] = None, resolution: Optional[str] = None + self, operation: str, reason: str | None = None, resolution: str | None = None ): message = f"Unable to {operation} plugin." if reason: @@ -797,7 +797,7 @@ def __init__( super().__init__(message) -def handle_ape_exception(err: ApeException, base_paths: Iterable[Union[Path, str]]) -> bool: +def handle_ape_exception(err: ApeException, base_paths: Iterable[Path | str]) -> bool: """ Handle a transaction error by showing relevant stack frames, including custom contract frames added to the exception. @@ -807,7 +807,7 @@ def handle_ape_exception(err: ApeException, base_paths: Iterable[Union[Path, str Args: err (:class:`~ape.exceptions.TransactionError`): The transaction error being handled. - base_paths (Optional[Iterable[Union[Path, str]]]): Optionally include additional + base_paths (Iterable[Path | str] | None): Optionally include additional source-path prefixes to use when finding relevant frames. Returns: @@ -825,7 +825,7 @@ def handle_ape_exception(err: ApeException, base_paths: Iterable[Union[Path, str return True -def _get_relevant_frames(base_paths: Iterable[Union[Path, str]]): +def _get_relevant_frames(base_paths: Iterable[Path | str]): # Abstracted for testing easement. tb = traceback.extract_tb(sys.exc_info()[2]) if relevant_tb := [f for f in tb if any(str(p) in f.filename for p in base_paths)]: @@ -841,7 +841,7 @@ class Abort(click.ClickException): useful for all user-facing errors. """ - def __init__(self, message: Optional[str] = None): + def __init__(self, message: str | None = None): if not message: caller = getframeinfo(stack()[1][0]) file_path = Path(caller.filename) @@ -851,7 +851,7 @@ def __init__(self, message: Optional[str] = None): super().__init__(message) @classmethod - def from_ape_exception(cls, exc: ApeException, show_traceback: Optional[bool] = None): + def from_ape_exception(cls, exc: ApeException, show_traceback: bool | None = None): show_traceback = ( logger.level == LogLevel.DEBUG.value if show_traceback is None else show_traceback ) @@ -881,11 +881,11 @@ def __init__( self, abi: "ErrorABI", inputs: dict[str, Any], - txn: Optional[FailedTxn] = None, - trace: _TRACE_ARG = None, - contract_address: Optional["AddressType"] = None, - base_err: Optional[Exception] = None, - source_traceback: _SOURCE_TRACEBACK_ARG = None, + txn: "FailedTxn | None" = None, + trace: "_TRACE_ARG" = None, + contract_address: "AddressType | None" = None, + base_err: Exception | None = None, + source_traceback: "_SOURCE_TRACEBACK_ARG" = None, ): self.abi = abi self.inputs = inputs @@ -918,7 +918,7 @@ def __repr__(self) -> str: return f"{name}({calldata})" -def _get_ape_traceback_from_tx(txn: FailedTxn) -> Optional["SourceTraceback"]: +def _get_ape_traceback_from_tx(txn: "FailedTxn") -> "SourceTraceback | None": from ape.api.transactions import ReceiptAPI try: @@ -944,8 +944,8 @@ def _get_ape_traceback_from_tx(txn: FailedTxn) -> Optional["SourceTraceback"]: def _get_custom_python_traceback( err: TransactionError, ape_traceback: "SourceTraceback", - project: Optional["ProjectManager"] = None, -) -> Optional[TracebackType]: + project: "ProjectManager | None" = None, +) -> TracebackType | None: # Manipulate python traceback to show lines from contract. # Help received from Jinja lib: # https://github.com/pallets/jinja/blob/main/src/jinja2/debug.py#L142 diff --git a/src/ape/logging.py b/src/ape/logging.py index 5b9792c5ce..45fa5f6cf2 100644 --- a/src/ape/logging.py +++ b/src/ape/logging.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from enum import IntEnum from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, Optional, Union +from typing import IO, TYPE_CHECKING, Any from urllib.parse import urlparse, urlunparse import click @@ -70,7 +70,7 @@ def _isatty(stream: IO) -> bool: class ApeColorFormatter(logging.Formatter): - def __init__(self, fmt: Optional[str] = None): + def __init__(self, fmt: str | None = None): fmt = fmt or DEFAULT_LOG_FORMAT super().__init__(fmt=fmt) @@ -98,7 +98,7 @@ def format(self, record): class ClickHandler(logging.Handler): def __init__( - self, echo_kwargs: dict, handlers: Optional[Sequence[Callable[[str], str]]] = None + self, echo_kwargs: dict, handlers: Sequence[Callable[[str], str]] | None = None ): super().__init__() self.echo_kwargs = echo_kwargs @@ -140,17 +140,17 @@ def __init__( self.fmt = fmt @classmethod - def create(cls, fmt: Optional[str] = None) -> "ApeLogger": + def create(cls, fmt: str | None = None) -> "ApeLogger": fmt = fmt or DEFAULT_LOG_FORMAT _logger = get_logger("ape", fmt=fmt) return cls(_logger, fmt) - def format(self, fmt: Optional[str] = None): + def format(self, fmt: str | None = None): self.fmt = fmt or DEFAULT_LOG_FORMAT fmt = fmt or DEFAULT_LOG_FORMAT _format_logger(self._logger, fmt) - def _load_from_sys_argv(self, default: Optional[Union[str, int, LogLevel]] = None): + def _load_from_sys_argv(self, default: str | int | LogLevel | None = None): """ Load from sys.argv to beat race condition with `click`. """ @@ -187,7 +187,7 @@ def _load_from_sys_argv(self, default: Optional[Union[str, int, LogLevel]] = Non def level(self) -> int: return self._logger.level - def set_level(self, level: Union[str, int, LogLevel]): + def set_level(self, level: str | int | LogLevel): """ Change the global ape logger log-level. @@ -207,12 +207,12 @@ def set_level(self, level: Union[str, int, LogLevel]): _logger.setLevel(level) @contextmanager - def at_level(self, level: Union[str, int, LogLevel]) -> Iterator: + def at_level(self, level: str | int | LogLevel) -> Iterator: """ Change the log-level in a context. Args: - level (Union[str, int, LogLevel]): The level to use. + level (str | int | LogLevel): The level to use. Returns: Iterator @@ -273,7 +273,7 @@ def log_debug_stack_trace(self): self._logger.debug(stack_trace) def create_logger( - self, new_name: str, handlers: Optional[Sequence[Callable[[str], str]]] = None + self, new_name: str, handlers: Sequence[Callable[[str], str]] | None = None ) -> logging.Logger: _logger = get_logger(new_name, fmt=self.fmt, handlers=handlers) _logger.setLevel(self.level) @@ -282,7 +282,7 @@ def create_logger( def _format_logger( - _logger: logging.Logger, fmt: str, handlers: Optional[Sequence[Callable[[str], str]]] = None + _logger: logging.Logger, fmt: str, handlers: Sequence[Callable[[str], str]] | None = None ): handler = ClickHandler(echo_kwargs=CLICK_ECHO_KWARGS, handlers=handlers) formatter = ApeColorFormatter(fmt=fmt) @@ -298,17 +298,17 @@ def _format_logger( def get_logger( name: str, - fmt: Optional[str] = None, - handlers: Optional[Sequence[Callable[[str], str]]] = None, + fmt: str | None = None, + handlers: Sequence[Callable[[str], str]] | None = None, ) -> logging.Logger: """ Get a logger with the given ``name`` and configure it for usage with Ape. Args: name (str): The name of the logger. - fmt (Optional[str]): The format of the logger. Defaults to the Ape + fmt (str | None): The format of the logger. Defaults to the Ape logger's default format: ``"%(levelname)s%(plugin)s: %(message)s"``. - handlers (Optional[Sequence[Callable[[str], str]]]): Additional log message handlers. + handlers (Sequence[Callable[[str], str]] | None): Additional log message handlers. Returns: ``logging.Logger`` @@ -318,7 +318,7 @@ def get_logger( return _logger -def _get_level(level: Optional[Union[str, int, LogLevel]] = None) -> str: +def _get_level(level: str | int | LogLevel | None = None) -> str: if level is None: return DEFAULT_LOG_LEVEL elif isinstance(level, LogLevel): @@ -350,7 +350,7 @@ def sanitize_url(url: str) -> str: class _RichConsoleFactory: rich_console_map: dict[str, "RichConsole"] = {} - def get_console(self, file: Optional[IO[str]] = None, **kwargs) -> "RichConsole": + def get_console(self, file: IO[str] | None = None, **kwargs) -> "RichConsole": # Configure custom file console file_id = str(file) if file_id not in self.rich_console_map: @@ -365,12 +365,12 @@ def get_console(self, file: Optional[IO[str]] = None, **kwargs) -> "RichConsole" _factory = _RichConsoleFactory() -def get_rich_console(file: Optional[IO[str]] = None, **kwargs) -> "RichConsole": +def get_rich_console(file: IO[str] | None = None, **kwargs) -> "RichConsole": """ Get an Ape-configured rich console. Args: - file (Optional[IO[str]]): The file to output to. Will default + file (IO[str] | None): The file to output to. Will default to using stdout. Returns: diff --git a/src/ape/managers/_contractscache.py b/src/ape/managers/_contractscache.py index 69e526e77f..8726a23d68 100644 --- a/src/ape/managers/_contractscache.py +++ b/src/ape/managers/_contractscache.py @@ -4,7 +4,7 @@ from contextlib import contextmanager from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Generic, TypeVar from ethpm_types import ABI, ContractType from ethpm_types.contract_type import ABIList @@ -56,7 +56,7 @@ def __init__( super().__init__(base_path / key) - def __getitem__(self, key: str) -> Optional[_BASE_MODEL]: # type: ignore + def __getitem__(self, key: str) -> _BASE_MODEL | None: # type: ignore return self.get_type(key) def __setitem__(self, key: str, value: _BASE_MODEL): # type: ignore @@ -77,7 +77,7 @@ def __contains__(self, key: str) -> bool: except KeyError: return False - def get_type(self, key: str) -> Optional[_BASE_MODEL]: + def get_type(self, key: str) -> _BASE_MODEL | None: if model := self.memory.get(key): return model @@ -130,8 +130,8 @@ def _get_data_cache( self, key: str, model_type: type, - ecosystem_key: Optional[str] = None, - network_key: Optional[str] = None, + ecosystem_key: str | None = None, + network_key: str | None = None, ): ecosystem_name = ecosystem_key or self.provider.network.ecosystem.name network_name = network_key or self.provider.network.name.replace("-fork", "") @@ -152,7 +152,7 @@ def deployments(self) -> DeploymentDiskCache: return DeploymentDiskCache() def __setitem__( - self, address: AddressType, item: Union[ContractType, ProxyInfoAPI, ContractCreation] + self, address: AddressType, item: ContractType | ProxyInfoAPI | ContractCreation ): """ Cache the given contract type. Contracts are cached in memory per session. @@ -182,8 +182,8 @@ def cache_contract_type( self, address: AddressType, contract_type: ContractType, - ecosystem_key: Optional[str] = None, - network_key: Optional[str] = None, + ecosystem_key: str | None = None, + network_key: str | None = None, ): """ Cache a contract type at the given address for the given network. @@ -214,8 +214,8 @@ def cache_contract_creation( self, address: AddressType, contract_creation: ContractCreation, - ecosystem_key: Optional[str] = None, - network_key: Optional[str] = None, + ecosystem_key: str | None = None, + network_key: str | None = None, ): """ Cache a contract creation object. @@ -274,7 +274,7 @@ def __contains__(self, address: AddressType) -> bool: def cache_deployment( self, contract_instance: ContractInstance, - proxy_info: Optional[ProxyInfoAPI] = None, + proxy_info: ProxyInfoAPI | None = None, detect_proxy: bool = True, ): """ @@ -283,7 +283,7 @@ def cache_deployment( Args: contract_instance (:class:`~ape.contracts.base.ContractInstance`): The contract to cache. - proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known, to + proxy_info (ProxyInfoAPI | None): Pass in the proxy info, if it is known, to avoid the potentially expensive look-up. detect_proxy (bool): Set to ``False`` to avoid detecting if the contract is a proxy. @@ -364,7 +364,7 @@ def cache_blueprint(self, blueprint_id: str, contract_type: ContractType): """ self.blueprints[blueprint_id] = contract_type - def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfoAPI]: + def get_proxy_info(self, address: AddressType) -> ProxyInfoAPI | None: """ Get proxy information about a contract using its address, either from a local cache, a disk cache, or the provider. @@ -373,11 +373,11 @@ def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfoAPI]: address (AddressType): The address of the proxy contract. Returns: - Optional[:class:`~ape.api.networks.ProxyInfoAPI`] + :class:`~ape.api.networks.ProxyInfoAPI` | None """ return self.proxy_infos[address] - def get_creation_metadata(self, address: AddressType) -> Optional[ContractCreation]: + def get_creation_metadata(self, address: AddressType) -> ContractCreation | None: """ Get contract creation metadata containing txn_hash, deployer, factory, block. @@ -385,7 +385,7 @@ def get_creation_metadata(self, address: AddressType) -> Optional[ContractCreati address (AddressType): The address of the contract. Returns: - Optional[:class:`~ape.api.query.ContractCreation`] + :class:`~ape.api.query.ContractCreation` | None """ if creation := self.contract_creations[address]: return creation @@ -404,7 +404,7 @@ def get_creation_metadata(self, address: AddressType) -> Optional[ContractCreati self.contract_creations[address] = creation return creation - def get_blueprint(self, blueprint_id: str) -> Optional[ContractType]: + def get_blueprint(self, blueprint_id: str) -> ContractType | None: """ Get a cached blueprint contract type. @@ -418,7 +418,7 @@ def get_blueprint(self, blueprint_id: str) -> Optional[ContractType]: return self.blueprints[blueprint_id] def _get_errors( - self, address: AddressType, chain_id: Optional[int] = None + self, address: AddressType, chain_id: int | None = None ) -> set[type[CustomError]]: if chain_id is None and self.network_manager.active_provider is not None: chain_id = self.provider.chain_id @@ -435,7 +435,7 @@ def _get_errors( return set() def _cache_error( - self, address: AddressType, error: type[CustomError], chain_id: Optional[int] = None + self, address: AddressType, error: type[CustomError], chain_id: int | None = None ): if chain_id is None and self.network_manager.active_provider is not None: chain_id = self.provider.chain_id @@ -462,14 +462,14 @@ def __getitem__(self, address: AddressType) -> ContractType: return contract_type def get_multiple( - self, addresses: Collection[AddressType], concurrency: Optional[int] = None + self, addresses: Collection[AddressType], concurrency: int | None = None ) -> dict[AddressType, ContractType]: """ Get contract types for all given addresses. Args: addresses (list[AddressType): A list of addresses to get contract types for. - concurrency (Optional[int]): The number of threads to use. Defaults to + concurrency (int | None): The number of threads to use. Defaults to ``min(4, len(addresses))``. Returns: @@ -518,11 +518,11 @@ def get_contract_type(addr: AddressType): def get( self, address: AddressType, - default: Optional[ContractType] = None, + default: ContractType | None = None, fetch_from_explorer: bool = True, - proxy_info: Optional[ProxyInfoAPI] = None, + proxy_info: ProxyInfoAPI | None = None, detect_proxy: bool = True, - ) -> Optional[ContractType]: + ) -> ContractType | None: """ Get a contract type by address. If the contract is cached, it will return the contract from the cache. @@ -531,17 +531,17 @@ def get( Args: address (AddressType): The address of the contract. - default (Optional[ContractType]): A default contract when none is found. + default (ContractType | None): A default contract when none is found. Defaults to ``None``. fetch_from_explorer (bool): Set to ``False`` to avoid fetching from an explorer. Defaults to ``True``. Only fetches if it needs to (uses disk & memory caching otherwise). - proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known, + proxy_info (ProxyInfoAPI | None): Pass in the proxy info, if it is known, to avoid the potentially expensive look-up. detect_proxy (bool): Set to ``False`` to avoid detecting if it is a proxy. Returns: - Optional[ContractType]: The contract type if it was able to get one, + ContractType | None: The contract type if it was able to get one, otherwise the default parameter. """ try: @@ -603,8 +603,8 @@ def _get_proxy_contract_type( address: AddressType, proxy_info: ProxyInfoAPI, fetch_from_explorer: bool = True, - default: Optional[ContractType] = None, - ) -> Optional[ContractType]: + default: ContractType | None = None, + ) -> ContractType | None: """ Combines the discoverable ABIs from the proxy contract and its implementation. """ @@ -643,8 +643,8 @@ def _get_contract_type( self, address: AddressType, fetch_from_explorer: bool = True, - default: Optional[ContractType] = None, - ) -> Optional[ContractType]: + default: ContractType | None = None, + ) -> ContractType | None: """ Get the _exact_ ContractType for a given address. For proxy contracts, returns the proxy ABIs if there are any and not the implementation ABIs. @@ -673,12 +673,12 @@ def get_container(cls, contract_type: ContractType) -> ContractContainer: def instance_at( self, - address: Union[str, AddressType], - contract_type: Optional[ContractType] = None, - txn_hash: Optional[Union[str, "HexBytes"]] = None, - abi: Optional[Union[list[ABI], dict, str, Path]] = None, + address: str | AddressType, + contract_type: ContractType | None = None, + txn_hash: "str | HexBytes | None" = None, + abi: list[ABI] | dict | str | Path | None = None, fetch_from_explorer: bool = True, - proxy_info: Optional[ProxyInfoAPI] = None, + proxy_info: ProxyInfoAPI | None = None, detect_proxy: bool = True, ) -> ContractInstance: """ @@ -693,18 +693,18 @@ def instance_at( :class:`~ape.exceptions.ContractNotFoundError`: When the contract type is not found. Args: - address (Union[str, AddressType]): The address of the plugin. If you are using the ENS + address (str | AddressType): The address of the plugin. If you are using the ENS plugin, you can also provide an ENS domain name. - contract_type (Optional[``ContractType``]): Optionally provide the contract type + contract_type (ContractType | None): Optionally provide the contract type in case it is not already known. - txn_hash (Optional[Union[str, HexBytes]]): The hash of the transaction responsible for + txn_hash (str | HexBytes | None): The hash of the transaction responsible for deploying the contract, if known. Useful for publishing. Defaults to ``None``. - abi (Optional[Union[list[ABI], dict, str, Path]]): Use an ABI str, dict, path, + abi (list[ABI] | dict | str | Path | None): Use an ABI str, dict, path, or ethpm models to create a contract instance class. fetch_from_explorer (bool): Set to ``False`` to avoid fetching from the explorer. Defaults to ``True``. Won't fetch unless it needs to (uses disk & memory caching first). - proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known, to avoid + proxy_info (ProxyInfoAPI | None): Pass in the proxy info, if it is known, to avoid the potentially expensive look-up. detect_proxy (bool): Set to ``False`` to avoid detecting if the contract is a proxy. @@ -883,7 +883,7 @@ def clear_local_caches(self): self.deployments.clear_local() - def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[ContractType]: + def _get_contract_type_from_explorer(self, address: AddressType) -> ContractType | None: if not self.provider.network.explorer: return None diff --git a/src/ape/managers/_deploymentscache.py b/src/ape/managers/_deploymentscache.py index 12f21078c1..e31ff7a00e 100644 --- a/src/ape/managers/_deploymentscache.py +++ b/src/ape/managers/_deploymentscache.py @@ -1,6 +1,5 @@ from contextlib import contextmanager from pathlib import Path -from typing import Optional from ape.managers.base import BaseManager from ape.types.address import AddressType @@ -14,7 +13,7 @@ class Deployment(BaseModel): """ address: AddressType - transaction_hash: Optional[str] = None + transaction_hash: str | None = None def __getitem__(self, key: str): # Mainly exists for backwards compat. @@ -89,8 +88,8 @@ def __contains__(self, contract_name: str): def get_deployments( self, contract_name: str, - ecosystem_key: Optional[str] = None, - network_key: Optional[str] = None, + ecosystem_key: str | None = None, + network_key: str | None = None, ) -> list[Deployment]: """ Get the deployments of the given contract on the currently connected network. @@ -121,9 +120,9 @@ def cache_deployment( self, address: AddressType, contract_name: str, - transaction_hash: Optional[str] = None, - ecosystem_key: Optional[str] = None, - network_key: Optional[str] = None, + transaction_hash: str | None = None, + ecosystem_key: str | None = None, + network_key: str | None = None, ): """ Update the deployments cache with a new contract. @@ -165,8 +164,8 @@ def _set_deployments( self, contract_name: str, deployments: list[Deployment], - ecosystem_key: Optional[str] = None, - network_key: Optional[str] = None, + ecosystem_key: str | None = None, + network_key: str | None = None, ): ecosystem_name = ecosystem_key or self.provider.network.ecosystem.name network_name = network_key or self.provider.network.name.replace("-fork", "") diff --git a/src/ape/managers/accounts.py b/src/ape/managers/accounts.py index ef00e01831..e997d4c425 100644 --- a/src/ape/managers/accounts.py +++ b/src/ape/managers/accounts.py @@ -2,7 +2,7 @@ from collections.abc import Generator, Iterator from contextlib import AbstractContextManager as ContextManager from functools import cached_property, singledispatchmethod -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from eth_utils import is_hex @@ -27,7 +27,7 @@ @contextlib.contextmanager def _use_sender( - account: Union[AccountAPI, TestAccountAPI], + account: AccountAPI | TestAccountAPI, ) -> "Generator[AccountAPI, TestAccountAPI, None]": try: _DEFAULT_SENDERS.append(account) @@ -199,7 +199,7 @@ def stop_impersonating(self, address: AddressType): def generate_test_account(self, container_name: str = "test") -> TestAccountAPI: return self.containers[container_name].generate_account() - def use_sender(self, account_id: Union[TestAccountAPI, AddressType, int]) -> "ContextManager": + def use_sender(self, account_id: TestAccountAPI | AddressType | int) -> "ContextManager": account = account_id if isinstance(account_id, TestAccountAPI) else self[account_id] return _use_sender(account) @@ -236,7 +236,7 @@ class AccountManager(BaseManager): _alias_to_account_cache: dict[str, AccountAPI] = {} @property - def default_sender(self) -> Optional[AccountAPI]: + def default_sender(self) -> AccountAPI | None: return _DEFAULT_SENDERS[-1] if _DEFAULT_SENDERS else None @cached_property @@ -450,7 +450,7 @@ def __contains__(self, address: AddressType) -> bool: def use_sender( self, - account_id: Union[AccountAPI, AddressType, str, int], + account_id: AccountAPI | AddressType | str | int, ) -> "ContextManager": if not isinstance(account_id, AccountAPI): if isinstance(account_id, int) or is_hex(account_id): @@ -470,8 +470,8 @@ def init_test_account( return self.test_accounts.init_test_account(index, address, private_key) def resolve_address( - self, account_id: Union["BaseAddress", AddressType, str, int, bytes] - ) -> Optional[AddressType]: + self, account_id: "BaseAddress | AddressType | str | int | bytes" + ) -> AddressType | None: """ Resolve the given input to an address. diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index 3d72787f85..aca9270c2d 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from functools import cached_property, partial, singledispatchmethod from statistics import mean, median -from typing import IO, TYPE_CHECKING, Optional, Union, cast +from typing import IO, TYPE_CHECKING, cast import pandas as pd from rich.box import SIMPLE @@ -121,9 +121,9 @@ def query( self, *columns: str, start_block: int = 0, - stop_block: Optional[int] = None, + stop_block: int | None = None, step: int = 1, - engine_to_use: Optional[str] = None, + engine_to_use: str | None = None, ) -> pd.DataFrame: """ A method for querying blocks and returning an Iterator. If you @@ -139,11 +139,11 @@ def query( *columns (str): columns in the DataFrame to return start_block (int): The first block, by number, to include in the query. Defaults to 0. - stop_block (Optional[int]): The last block, by number, to include + stop_block (int | None): The last block, by number, to include in the query. Defaults to the latest block. step (int): The number of blocks to iterate between block numbers. Defaults to ``1``. - engine_to_use (Optional[str]): query engine to use, bypasses query + engine_to_use (str | None): query engine to use, bypasses query engine selection algorithm. Returns: @@ -182,9 +182,9 @@ def query( def range( self, start_or_stop: int, - stop: Optional[int] = None, + stop: int | None = None, step: int = 1, - engine_to_use: Optional[str] = None, + engine_to_use: str | None = None, ) -> Iterator[BlockAPI]: """ Iterate over blocks. Works similarly to python ``range()``. @@ -203,12 +203,12 @@ def range( start_or_stop (int): When given just a single value, it is the stop. Otherwise, it is the start. This mimics the behavior of ``range`` built-in Python function. - stop (Optional[int]): The block number to stop before. Also the total + stop (int | None): The block number to stop before. Also the total number of blocks to get. If not setting a start value, is set by the first argument. - step (Optional[int]): The value to increment by. Defaults to ``1``. + step (int | None): The value to increment by. Defaults to ``1``. number of blocks to get. Defaults to the latest block. - engine_to_use (Optional[str]): query engine to use, bypasses query + engine_to_use (str | None): query engine to use, bypasses query engine selection algorithm. Returns: @@ -241,10 +241,10 @@ def range( def poll_blocks( self, - start_block: Optional[int] = None, - stop_block: Optional[int] = None, - required_confirmations: Optional[int] = None, - new_block_timeout: Optional[int] = None, + start_block: int | None = None, + stop_block: int | None = None, + required_confirmations: int | None = None, + new_block_timeout: int | None = None, ) -> Iterator[BlockAPI]: """ Poll new blocks. Optionally set a start block to include historical blocks. @@ -264,14 +264,14 @@ def poll_blocks( print(f"New block found: number={new_block.number}") Args: - start_block (Optional[int]): The block number to start with. Defaults to the pending + start_block (int | None): The block number to start with. Defaults to the pending block number. - stop_block (Optional[int]): Optionally set a future block number to stop at. + stop_block (int | None): Optionally set a future block number to stop at. Defaults to never-ending. - required_confirmations (Optional[int]): The amount of confirmations to wait + required_confirmations (int | None): The amount of confirmations to wait before yielding the block. The more confirmations, the less likely a reorg will occur. Defaults to the network's configured required confirmations. - new_block_timeout (Optional[float]): The amount of time to wait for a new block before + new_block_timeout (float | None): The amount of time to wait for a new block before timing out. Defaults to 10 seconds for local networks or ``50 * block_time`` for live networks. @@ -350,8 +350,8 @@ def query( self, *columns: str, start_nonce: int = 0, - stop_nonce: Optional[int] = None, - engine_to_use: Optional[str] = None, + stop_nonce: int | None = None, + engine_to_use: str | None = None, ) -> pd.DataFrame: """ A method for querying transactions made by an account and returning an Iterator. @@ -367,9 +367,9 @@ def query( *columns (str): columns in the DataFrame to return start_nonce (int): The first transaction, by nonce, to include in the query. Defaults to 0. - stop_nonce (Optional[int]): The last transaction, by nonce, to include + stop_nonce (int | None): The last transaction, by nonce, to include in the query. Defaults to the latest transaction. - engine_to_use (Optional[str]): query engine to use, bypasses query + engine_to_use (str | None): query engine to use, bypasses query engine selection algorithm. Returns: @@ -515,7 +515,7 @@ def __getitem_base_address(self, address: BaseAddress) -> AccountHistory: return self._get_account_history(address) @__getitem__.register - def __getitem_str(self, account_or_hash: str) -> Union[AccountHistory, ReceiptAPI]: + def __getitem_str(self, account_or_hash: str) -> AccountHistory | ReceiptAPI: """ Get a receipt from the history by its transaction hash. If the receipt is not currently cached, will use the provider @@ -528,7 +528,7 @@ def __getitem_str(self, account_or_hash: str) -> Union[AccountHistory, ReceiptAP :class:`~ape.api.transactions.ReceiptAPI`: The receipt. """ - def _get_receipt() -> Optional[ReceiptAPI]: + def _get_receipt() -> ReceiptAPI | None: try: return self._get_receipt(account_or_hash) except Exception: @@ -612,7 +612,7 @@ def revert_to_block(self, block_number: int): for account_history in self._account_history_cache.values(): account_history.revert_to_block(block_number) - def _get_account_history(self, address: Union[BaseAddress, AddressType]) -> AccountHistory: + def _get_account_history(self, address: BaseAddress | AddressType) -> AccountHistory: address_key: AddressType = self.conversion_manager.convert(address, AddressType) if address_key not in self._account_history_cache: @@ -629,7 +629,7 @@ class ReportManager(BaseManager): **NOTE**: This class is not part of the public API. """ - def show_gas(self, report: "GasReport", file: Optional[IO[str]] = None): + def show_gas(self, report: "GasReport", file: IO[str] | None = None): tables: list[Table] = [] for contract_id, method_calls in report.items(): @@ -670,7 +670,7 @@ def show_gas(self, report: "GasReport", file: Optional[IO[str]] = None): self.echo(*tables, file=file) def echo( - self, *rich_items, file: Optional[IO[str]] = None, console: Optional["RichConsole"] = None + self, *rich_items, file: IO[str] | None = None, console: "RichConsole | None" = None ): console = console or get_rich_console(file) console.print(*rich_items) @@ -678,8 +678,8 @@ def echo( def show_source_traceback( self, traceback: "SourceTraceback", - file: Optional[IO[str]] = None, - console: Optional["RichConsole"] = None, + file: IO[str] | None = None, + console: "RichConsole | None" = None, failing: bool = True, ): console = console or get_rich_console(file) @@ -687,7 +687,7 @@ def show_source_traceback( console.print(str(traceback), style=style) def show_events( - self, events: list, file: Optional[IO[str]] = None, console: Optional["RichConsole"] = None + self, events: list, file: IO[str] | None = None, console: "RichConsole | None" = None ): console = console or get_rich_console(file) console.print("Events emitted:") @@ -799,7 +799,7 @@ def pending_timestamp(self) -> int: return self.provider.get_block("pending").timestamp @pending_timestamp.setter - def pending_timestamp(self, new_value: Union[int, str]): + def pending_timestamp(self, new_value: int | str): self.provider.set_timestamp(self.conversion_manager.convert(new_value, int)) @log_instead_of_fail(default="") @@ -828,7 +828,7 @@ def snapshot(self) -> "SnapshotID": return snapshot_id - def restore(self, snapshot_id: Optional["SnapshotID"] = None): + def restore(self, snapshot_id: "SnapshotID | None" = None): """ Regress the current call using the given snapshot ID. Allows developers to go back to a previous state. @@ -840,7 +840,7 @@ def restore(self, snapshot_id: Optional["SnapshotID"] = None): :class:`~ape.exceptions.ChainError`: When there are no snapshot IDs to select from. Args: - snapshot_id (Optional[:class:`~ape.types.SnapshotID`]): The snapshot ID. Defaults + snapshot_id (:class:`~ape.types.SnapshotID` | None): The snapshot ID. Defaults to the most recent snapshot ID. """ chain_id = self.provider.chain_id @@ -912,8 +912,8 @@ def isolate(self): def mine( self, num_blocks: int = 1, - timestamp: Optional[int] = None, - deltatime: Optional[int] = None, + timestamp: int | None = None, + deltatime: int | None = None, ) -> None: """ Mine any given number of blocks. @@ -924,9 +924,9 @@ def mine( Args: num_blocks (int): Choose the number of blocks to mine. Defaults to 1 block. - timestamp (Optional[int]): Designate a time (in seconds) to begin mining. + timestamp (int | None): Designate a time (in seconds) to begin mining. Defaults to None. - deltatime (Optional[int]): Designate a change in time (in seconds) to begin mining. + deltatime (int | None): Designate a change in time (in seconds) to begin mining. Defaults to None. """ if timestamp and deltatime: @@ -938,7 +938,7 @@ def mine( self.provider.mine(num_blocks) def get_balance( - self, address: Union[BaseAddress, AddressType, str], block_id: Optional["BlockID"] = None + self, address: BaseAddress | AddressType | str, block_id: "BlockID | None" = None ) -> int: """ Get the balance of the given address. If ``ape-ens`` is installed, @@ -959,7 +959,7 @@ def get_balance( return self.provider.get_balance(address, block_id=block_id) - def set_balance(self, account: Union[BaseAddress, AddressType, str], amount: Union[int, str]): + def set_balance(self, account: BaseAddress | AddressType | str, amount: int | str): """ Set an account balance, only works on development chains. @@ -996,7 +996,7 @@ def get_receipt(self, transaction_hash: str) -> ReceiptAPI: return receipt def get_code( - self, address: AddressType, block_id: Optional["BlockID"] = None + self, address: AddressType, block_id: "BlockID | None" = None ) -> "ContractCode": network = self.provider.network @@ -1018,7 +1018,7 @@ def get_code( self._code[network.ecosystem.name][network.name][address] = code return code - def get_delegate(self, address: AddressType) -> Optional[BaseAddress]: + def get_delegate(self, address: AddressType) -> BaseAddress | None: ecosystem = self.provider.network.ecosystem if not (proxy_info := ecosystem.get_proxy_info(address)): diff --git a/src/ape/managers/compilers.py b/src/ape/managers/compilers.py index 59d9df64c0..a7ab147ad4 100644 --- a/src/ape/managers/compilers.py +++ b/src/ape/managers/compilers.py @@ -2,7 +2,7 @@ from collections.abc import Iterable, Iterator, Sequence from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from eth_pydantic_types import HexBytes @@ -81,7 +81,7 @@ def registered_compilers(self) -> dict[str, "CompilerAPI"]: return registered_compilers - def get_compiler(self, name: str, settings: Optional[dict] = None) -> Optional["CompilerAPI"]: + def get_compiler(self, name: str, settings: dict | None = None) -> "CompilerAPI | None": for compiler in self.registered_compilers.values(): if compiler.name != name: continue @@ -96,10 +96,10 @@ def get_compiler(self, name: str, settings: Optional[dict] = None) -> Optional[" def compile( self, - contract_filepaths: Union[Path, str, Iterable[Union[Path, str]]], - project: Optional["ProjectManager"] = None, - settings: Optional[dict] = None, - excluded_compilers: Optional[list[str]] = None, + contract_filepaths: Path | str | Iterable[Path | str], + project: "ProjectManager | None" = None, + settings: dict | None = None, + excluded_compilers: list[str] | None = None, ) -> Iterator["ContractType"]: """ Invoke :meth:`ape.ape.compiler.CompilerAPI.compile` for each of the given files. @@ -111,11 +111,11 @@ def compile( file-extension as well as when there are contract-type collisions across compilers. Args: - contract_filepaths (Union[Path, str, Iterable[Union[Path, str]]]): The files to + contract_filepaths (Path | str | Iterable[Path | str]): The files to compile, as ``pathlib.Path`` objects or path-strs. project (Optional[:class:`~ape.managers.project.ProjectManager`]): Optionally compile a different project that the one from the current-working directory. - settings (Optional[Dict]): Adhoc compiler settings. Defaults to None. + settings (dict | None): Adhoc compiler settings. Defaults to None. Ensure the compiler name key is present in the dict for it to work. Returns: @@ -177,8 +177,8 @@ def compile_source( self, compiler_name: str, code: str, - project: Optional["ProjectManager"] = None, - settings: Optional[dict] = None, + project: "ProjectManager | None" = None, + settings: dict | None = None, **kwargs, ) -> ContractContainer: """ @@ -198,7 +198,7 @@ def compile_source( code (str): The source code to compile. project (Optional[:class:`~ape.managers.project.ProjectManager`]): Optionally compile a different project that the one from the current-working directory. - settings (Optional[dict]): Compiler settings. + settings (dict | None): Compiler settings. **kwargs (Any): Additional overrides for the ``ethpm_types.ContractType`` model. Returns: @@ -214,7 +214,7 @@ def compile_source( def get_imports( self, contract_filepaths: Sequence[Path], - project: Optional["ProjectManager"] = None, + project: "ProjectManager | None" = None, ) -> dict[str, list[str]]: """ Combine import dicts from all compilers, where the key is a contract's source_id @@ -222,7 +222,7 @@ def get_imports( Args: contract_filepaths (Sequence[pathlib.Path]): A list of source file paths to compile. - project (Optional[:class:`~ape.managers.project.ProjectManager`]): Optionally provide + project (:class:`~ape.managers.project.ProjectManager` | None): Optionally provide the project. Returns: @@ -299,7 +299,7 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: # No further enrichment. return err - def get_custom_error(self, err: ContractLogicError) -> Optional[CustomError]: + def get_custom_error(self, err: ContractLogicError) -> CustomError | None: """ Get a custom error for the given contract logic error using the contract-type found from address-data in the error. Returns ``None`` if the given error is @@ -311,7 +311,7 @@ def get_custom_error(self, err: ContractLogicError) -> Optional[CustomError]: as a custom error. Returns: - Optional[:class:`~ape.exceptions.CustomError`] + :class:`~ape.exceptions.CustomError` | None """ message = err.revert_message if not message.startswith("0x"): diff --git a/src/ape/managers/config.py b/src/ape/managers/config.py index d4b6a6025d..3ebe2dcde4 100644 --- a/src/ape/managers/config.py +++ b/src/ape/managers/config.py @@ -4,7 +4,7 @@ from contextlib import contextmanager from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from ape.api.config import ApeConfig from ape.logging import logger @@ -35,7 +35,7 @@ class ConfigManager(ExtraAttributesMixin, BaseManager): definitions, see :class:`~ape.api.config.ApeConfig`. """ - def __init__(self, data_folder: Optional[Path] = None, request_header: Optional[dict] = None): + def __init__(self, data_folder: Path | None = None, request_header: dict | None = None): if not data_folder and "APE_DATA_FOLDER" in os.environ: self.DATA_FOLDER = Path(os.environ["APE_DATA_FOLDER"]) else: @@ -119,7 +119,7 @@ def extract_config(cls, manifest: "PackageManifest", **overrides) -> ApeConfig: @contextmanager def isolate_data_folder( - self, keep: Optional[Union[Iterable[str], str]] = None + self, keep: Iterable[str] | str | None = None ) -> Iterator[Path]: """ Change Ape's DATA_FOLDER to point a temporary path, @@ -127,7 +127,7 @@ def isolate_data_folder( cached to disk will not persist. Args: - keep (Optional[Union[Iterable[str], str]]): Optionally, pass in + keep (Iterable[str] | str | None): Optionally, pass in a key of subdirectory names to include in the new isolated data folder. For example, pass ing ``"packages"`` to avoid having to re-download dependencies in an isolated environment. diff --git a/src/ape/managers/converters.py b/src/ape/managers/converters.py index 8801a9b040..5d8c533655 100644 --- a/src/ape/managers/converters.py +++ b/src/ape/managers/converters.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta, timezone from decimal import Decimal from functools import cached_property -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from cchecksum import to_checksum_address from dateutil.parser import parse @@ -161,7 +161,7 @@ class IntAddressConverter(ConverterAPI): A converter that converts an integer address to an :class:`~ape.types.address.AddressType`. """ - _cache: dict[int, Union[AddressType, bool]] = {} + _cache: dict[int, AddressType | bool] = {} def is_convertible(self, value: Any) -> bool: if not isinstance(value, int): @@ -191,7 +191,7 @@ def convert(self, value: Any) -> AddressType: return res - def _convert(self, value: int) -> Union[AddressType, bool]: + def _convert(self, value: int) -> AddressType | bool: try: val = Address.__eth_pydantic_validate__(value) except Exception: @@ -206,7 +206,7 @@ class TimestampConverter(ConverterAPI): No timezone required, but should be formatted to UTC. """ - def is_convertible(self, value: Union[str, datetime, timedelta]) -> bool: + def is_convertible(self, value: str | datetime | timedelta) -> bool: if not isinstance(value, (str, datetime, timedelta)): return False if isinstance(value, str): @@ -219,7 +219,7 @@ def is_convertible(self, value: Union[str, datetime, timedelta]) -> bool: return False return True - def convert(self, value: Union[str, datetime, timedelta]) -> int: + def convert(self, value: str | datetime | timedelta) -> int: if isinstance(value, str): return int(parse(value).replace(tzinfo=timezone.utc).timestamp()) elif isinstance(value, datetime): @@ -318,7 +318,7 @@ def is_type(self, value: Any, to_type: type) -> bool: """ return is_checksum_address(value) if to_type is AddressType else isinstance(value, to_type) - def convert(self, value: Any, to_type: Union[type, tuple, list]) -> Any: + def convert(self, value: Any, to_type: type | tuple | list) -> Any: """ Convert the given value to the given type. This method accesses all :class:`~ape.api.convert.ConverterAPI` instances known to @@ -432,7 +432,7 @@ def get_converter(self, name: str) -> ConverterAPI: def convert_method_args( self, - abi: Union["MethodABI", "ConstructorABI", "EventABI", "ConstructorABI"], + abi: "MethodABI | ConstructorABI | EventABI", arguments: Sequence[Any], ): input_types = [i.canonical_type for i in abi.inputs] @@ -452,14 +452,33 @@ def convert_method_kwargs(self, kwargs) -> dict: fields = TransactionAPI.__pydantic_fields__ def get_real_type(type_): - all_types = getattr(type_, "_typevar_types", []) - if not all_types or not isinstance(all_types, (list, tuple)): - return type_ - - # Filter out None - valid_types = [t for t in all_types if t is not None] + # Handle both old (Optional/Union) and new (|) syntax + from typing import get_args, get_origin, Union + from types import UnionType + + # Try old syntax first (Optional/Union) - uses _typevar_types + all_types = getattr(type_, "_typevar_types", None) + if all_types and isinstance(all_types, (list, tuple)) and len(all_types) > 0: + # Old syntax found, use it + pass + else: + # Try new syntax (| operator) - uses get_args/get_origin + origin = get_origin(type_) + if origin is not None and origin in (Union, UnionType): + args = get_args(type_) + if args: + all_types = list(args) + else: + # No args means it's not a union, return as-is + return type_ + else: + # Not a union type, return as-is + return type_ + + # Filter out None/NoneType + valid_types = [t for t in all_types if t is not type(None) and t is not None] if len(valid_types) == 1: - # This is something like Optional[int], + # This is something like Optional[int] or int | None, # however, if the user provides a value, # we want to convert to the non-optional type. return valid_types[0] diff --git a/src/ape/managers/networks.py b/src/ape/managers/networks.py index f0b72c9d49..23723bb381 100644 --- a/src/ape/managers/networks.py +++ b/src/ape/managers/networks.py @@ -3,7 +3,7 @@ from collections.abc import Collection, Iterator from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from evmchains import PUBLIC_CHAIN_META from pydantic import ValidationError @@ -39,17 +39,17 @@ class NodeProcessData(BaseModel): The network triple ``ecosystem:network:node``. """ - ipc_path: Optional[Path] = None + ipc_path: Path | None = None """ The IPC path this node process communicates on. """ - http_uri: Optional[str] = None + http_uri: str | None = None """ The HTTP URI this node process exposes. """ - ws_uri: Optional[str] = None + ws_uri: str | None = None """ The websockets URI this node process exposes. """ @@ -97,7 +97,7 @@ def __bool__(self) -> bool: """ return bool(self.nodes) - def __contains__(self, pid_or_provider: Union[int, "SubprocessProvider"]) -> bool: + def __contains__(self, pid_or_provider: "int | SubprocessProvider") -> bool: if isinstance(pid_or_provider, int): return pid_or_provider in self.nodes @@ -107,7 +107,7 @@ def __contains__(self, pid_or_provider: Union[int, "SubprocessProvider"]) -> boo return False - def get(self, pid: Union[int, str]) -> Optional[NodeProcessData]: + def get(self, pid: int | str) -> NodeProcessData | None: return self.nodes.get(int(pid)) def lookup_processes(self, provider: "SubprocessProvider") -> dict[int, NodeProcessData]: @@ -177,8 +177,8 @@ class NetworkManager(BaseManager, ExtraAttributesMixin): ... """ - _active_provider: Optional["ProviderAPI"] = None - _default_ecosystem_name: Optional[str] = None + _active_provider: "ProviderAPI | None" = None + _default_ecosystem_name: str | None = None # For adhoc adding custom networks, or incorporating some defined # in other projects' configs. @@ -192,7 +192,7 @@ def __repr__(self) -> str: return f"<{content}>" @property - def active_provider(self) -> Optional["ProviderAPI"]: + def active_provider(self) -> "ProviderAPI | None": """ The currently connected provider if one exists. Otherwise, returns ``None``. """ @@ -262,7 +262,7 @@ def get_running_node(self, pid: int) -> "SubprocessProvider": if not (data := self.running_nodes.get(pid)): raise NetworkError(f"No running node for pid '{pid}'.") - uri: Optional[Union[str, Path]] = None + uri: str | Path | None = None if ipc := data.ipc_path: if ipc.exists(): uri = ipc @@ -351,9 +351,9 @@ def get_request_headers( def fork( self, - provider_name: Optional[str] = None, - provider_settings: Optional[dict] = None, - block_number: Optional[int] = None, + provider_name: str | None = None, + provider_settings: dict | None = None, + block_number: int | None = None, ) -> ProviderContextManager: """ Fork the currently connected network. @@ -363,7 +363,7 @@ def fork( When ``None``, returns the default provider. provider_settings (dict, optional): Settings to apply to the provider. Defaults to ``None``. - block_number (Optional[int]): Optionally specify the block number you wish to fork. + block_number (int | None): Optionally specify the block number you wish to fork. Negative block numbers are relative to HEAD. Defaults to the configured fork block number or HEAD. @@ -543,7 +543,7 @@ def create_custom_provider( self, connection_str: str, provider_cls: type["ProviderAPI"] = EthereumNodeProvider, - provider_name: Optional[str] = None, + provider_name: str | None = None, ) -> "ProviderAPI": """ Create a custom connection to a URI using the EthereumNodeProvider provider. @@ -555,7 +555,7 @@ def create_custom_provider( when using HTTP. provider_cls (type[:class:`~ape.api.providers.ProviderAPI`]): Defaults to :class:`~ape_ethereum.providers.EthereumNodeProvider`. - provider_name (Optional[str]): The name of the provider. Defaults to best guess. + provider_name (str | None): The name of the provider. Defaults to best guess. Returns: :class:`~ape.api.providers.ProviderAPI`: The Geth provider @@ -631,9 +631,9 @@ def __getattr__(self, attr_name: str) -> "EcosystemAPI": def get_network_choices( self, - ecosystem_filter: Optional[Union[list[str], str]] = None, - network_filter: Optional[Union[list[str], str]] = None, - provider_filter: Optional[Union[list[str], str]] = None, + ecosystem_filter: list[str] | str | None = None, + network_filter: list[str] | str | None = None, + provider_filter: list[str] | str | None = None, ) -> Iterator[str]: """ The set of all possible network choices available as a "network selection" @@ -648,11 +648,11 @@ def get_network_choices( combinations. Args: - ecosystem_filter (Optional[Union[list[str], str]]): Get only the specified ecosystems. + ecosystem_filter (list[str] | str | None): Get only the specified ecosystems. Defaults to getting all ecosystems. - network_filter (Optional[Union[list[str], str]]): Get only the specified networks. + network_filter (list[str] | str | None): Get only the specified networks. Defaults to getting all networks in ecosystems. - provider_filter (Optional[Union[list[str], str]]): Get only the specified providers. + provider_filter (list[str] | str | None): Get only the specified providers. Defaults to getting all providers in networks. Returns: @@ -746,7 +746,7 @@ def get_ecosystem(self, ecosystem_name: str) -> "EcosystemAPI": raise EcosystemNotFoundError(ecosystem_name, options=self.ecosystem_names) - def _get_ecosystem_from_evmchains(self, ecosystem_name: str) -> Optional["EcosystemAPI"]: + def _get_ecosystem_from_evmchains(self, ecosystem_name: str) -> "EcosystemAPI | None": if ecosystem_name not in PUBLIC_CHAIN_META: return None @@ -769,8 +769,8 @@ def _get_ecosystem_from_evmchains(self, ecosystem_name: str) -> Optional["Ecosys def get_provider_from_choice( self, - network_choice: Optional[str] = None, - provider_settings: Optional[dict] = None, + network_choice: str | None = None, + provider_settings: dict | None = None, ) -> "ProviderAPI": """ Get a :class:`~ape.api.providers.ProviderAPI` from a network choice. @@ -851,8 +851,8 @@ def get_provider_from_choice( def parse_network_choice( self, - network_choice: Optional[str] = None, - provider_settings: Optional[dict] = None, + network_choice: str | None = None, + provider_settings: dict | None = None, disconnect_after: bool = False, disconnect_on_exit: bool = True, ) -> ProviderContextManager: @@ -939,9 +939,9 @@ def network_data(self) -> dict: def get_network_data( self, - ecosystem_filter: Optional[Collection[str]] = None, - network_filter: Optional[Collection[str]] = None, - provider_filter: Optional[Collection[str]] = None, + ecosystem_filter: Collection[str] | None = None, + network_filter: Collection[str] | None = None, + provider_filter: Collection[str] | None = None, ): data: dict = {"ecosystems": []} @@ -961,8 +961,8 @@ def get_network_data( def _get_ecosystem_data( self, ecosystem_name: str, - network_filter: Optional[Collection[str]] = None, - provider_filter: Optional[Collection[str]] = None, + network_filter: Collection[str] | None = None, + provider_filter: Collection[str] | None = None, ) -> dict: ecosystem = self[ecosystem_name] ecosystem_data: dict = {"name": str(ecosystem_name)} @@ -990,7 +990,7 @@ def _invalidate_cache(self): self._custom_networks = [] -def _validate_filter(arg: Optional[Union[list[str], str]], options: set[str]): +def _validate_filter(arg: list[str] | str | None, options: set[str]): filters = arg or [] if isinstance(filters, str): diff --git a/src/ape/managers/plugins.py b/src/ape/managers/plugins.py index ebf3a5a76f..e9d2ed9aa1 100644 --- a/src/ape/managers/plugins.py +++ b/src/ape/managers/plugins.py @@ -2,7 +2,7 @@ from functools import cached_property from importlib import import_module from itertools import chain -from typing import Any, Optional +from typing import Any from ape.exceptions import ApeAttributeError from ape.logging import logger @@ -143,7 +143,7 @@ def _register_plugins(self): self.__registered = True - def _validate_plugin(self, plugin_name: str, plugin_cls) -> Optional[tuple[str, tuple]]: + def _validate_plugin(self, plugin_name: str, plugin_cls) -> tuple[str, tuple] | None: if valid_impl(plugin_cls): return clean_plugin_name(plugin_name), plugin_cls else: diff --git a/src/ape/managers/project.py b/src/ape/managers/project.py index 3b90ded994..6e81d0ce5d 100644 --- a/src/ape/managers/project.py +++ b/src/ape/managers/project.py @@ -6,7 +6,7 @@ from functools import cached_property, singledispatchmethod from pathlib import Path from re import Pattern -from typing import Any, Optional, Union, cast +from typing import Any, cast from eth_typing import HexStr from eth_utils import to_hex @@ -63,7 +63,7 @@ class SourceManager(BaseManager): but with more functionality for active development. """ - _path_cache: Optional[list[Path]] = None + _path_cache: list[Path] | None = None # perf: calculating paths from source Ids can be expensive. _path_to_source_id: dict[Path, str] = {} @@ -72,7 +72,7 @@ def __init__( self, root_path: Path, get_contracts_path: Callable, - exclude_globs: Optional[set[Union[str, Pattern]]] = None, + exclude_globs: set[str | Pattern] | None = None, ): self.root_path = root_path self.get_contracts_path = get_contracts_path @@ -106,7 +106,7 @@ def __getitem__(self, source_id: str) -> Source: return src - def get(self, source_id: str) -> Optional[Source]: + def get(self, source_id: str) -> Source | None: """ Get a Source by source_id. @@ -121,7 +121,7 @@ def get(self, source_id: str) -> Optional[Source]: for path in self.paths: if self._get_source_id(path) == source_id: - text: Union[str, dict] + text: str | dict if path.is_file(): try: text = path.read_text(encoding="utf8") @@ -282,12 +282,12 @@ def is_excluded(self, path: Path, exclude_missing_compilers: bool = True) -> boo self._exclude_cache[source_id] = False return False - def lookup(self, path_id: Union[str, Path]) -> Optional[Path]: + def lookup(self, path_id: str | Path) -> Path | None: """ Look-up a path by given a sub-path or a source ID. Args: - path_id (Union[str, Path]): Either part of a path + path_id (str | Path): Either part of a path or a source ID. Returns: @@ -301,7 +301,7 @@ def lookup(self, path_id: Union[str, Path]) -> Optional[Path]: input_stem = input_path.stem input_extension = get_full_extension(input_path) or None - def find_in_dir(dir_path: Path, path: Path) -> Optional[Path]: + def find_in_dir(dir_path: Path, path: Path) -> Path | None: # Try exact match with or without extension possible_matches = [] contracts_folder = self.get_contracts_path() @@ -411,7 +411,7 @@ def __len__(self) -> int: def get( self, name: str, compile_missing: bool = True, check_for_changes: bool = True - ) -> Optional[ContractContainer]: + ) -> ContractContainer | None: """ Get a contract by name. @@ -479,11 +479,11 @@ def values(self) -> Iterator[ContractContainer]: # Deleted before yield. continue - def _compile_missing_contracts(self, paths: Iterable[Union[Path, str]]): + def _compile_missing_contracts(self, paths: Iterable[Path | str]): non_compiled_sources = self._get_needs_compile(paths) self._compile_contracts(non_compiled_sources) - def _get_needs_compile(self, paths: Iterable[Union[Path, str]]) -> Iterable[Path]: + def _get_needs_compile(self, paths: Iterable[Path | str]) -> Iterable[Path]: for path in paths: if self._detect_change(path): if isinstance(path, str): @@ -493,8 +493,8 @@ def _get_needs_compile(self, paths: Iterable[Union[Path, str]]) -> Iterable[Path def _compile_contracts( self, - paths: Iterable[Union[Path, str]], - excluded_compilers: Optional[list[str]] = None, + paths: Iterable[Path | str], + excluded_compilers: list[str] | None = None, ): if not ( new_types := { @@ -525,9 +525,9 @@ def _compile_all(self, use_cache: bool = True) -> Iterator[ContractContainer]: def _compile( self, - paths: Union[Path, str, Iterable[Union[Path, str]]], + paths: Path | str | Iterable[Path | str], use_cache: bool = True, - excluded_compilers: Optional[list[str]] = None, + excluded_compilers: list[str] | None = None, ) -> Iterator[ContractContainer]: path_ls = list([paths] if isinstance(paths, (Path, str)) else paths) if not path_ls: @@ -553,7 +553,7 @@ def _compile( if contract_type.source_id and contract_type.source_id in src_ids: yield ContractContainer(contract_type) - def _detect_change(self, path: Union[Path, str]) -> bool: + def _detect_change(self, path: Path | str) -> bool: if not (existing_types := (self.project.manifest.contract_types or {}).values()): return True # Nothing compiled yet. @@ -604,12 +604,12 @@ class Dependency(BaseManager, ExtraAttributesMixin): them from ``project.dependencies``. """ - def __init__(self, api: DependencyAPI, project: Optional["ProjectManager"] = None): + def __init__(self, api: DependencyAPI, project: "ProjectManager | None" = None): self.api = api # This is the base project using this dependency. self.base_project = project or self.local_project # When installed (and set, lazily), this is the dependency project. - self._installation: Optional[ProjectManager] = None + self._installation: "ProjectManager | None" = None self._tried_fetch = False @log_instead_of_fail(default="") @@ -745,7 +745,7 @@ def uri(self) -> str: def install( self, use_cache: bool = True, - config_override: Optional[dict] = None, + config_override: dict | None = None, recurse: bool = True, ) -> "ProjectManager": """ @@ -874,7 +874,7 @@ def uninstall(self): def compile( self, use_cache: bool = True, - config_override: Optional[dict] = None, + config_override: dict | None = None, allow_install: bool = False, ) -> dict[str, ContractContainer]: """ @@ -882,7 +882,7 @@ def compile( Args: use_cache (bool): Set to ``False`` to force a re-compile. - config_override (Optional[dict]): Optionally override the configuration, + config_override (dict | None): Optionally override the configuration, which may be needed for compiling. allow_install (bool): Set to ``True`` to allow installing. @@ -1179,8 +1179,8 @@ def __getitem__(self, version: str) -> "ProjectManager": raise KeyError(version) def get( # type: ignore - self, version: str, default: Optional["ProjectManager"] = None - ) -> Optional["ProjectManager"]: + self, version: str, default: "ProjectManager | None" = None + ) -> "ProjectManager | None": options = _version_to_options(version) for vers in options: if not dict.__contains__(self, vers): # type: ignore @@ -1205,7 +1205,7 @@ class DependencyManager(BaseManager): # Class-level cache _cache: dict[DependencyAPI, Dependency] = {} - def __init__(self, project: Optional["ProjectManager"] = None): + def __init__(self, project: "ProjectManager | None" = None): self.project = project or self.local_project @log_instead_of_fail(default="") @@ -1298,9 +1298,9 @@ def specified(self) -> Iterator[Dependency]: def get_project_dependencies( self, use_cache: bool = True, - config_override: Optional[dict] = None, - name: Optional[str] = None, - version: Optional[str] = None, + config_override: dict | None = None, + name: str | None = None, + version: str | None = None, allow_install: bool = True, strict: bool = False, recurse: bool = True, @@ -1311,9 +1311,9 @@ def get_project_dependencies( Args: use_cache (bool): Set to ``False`` to force-reinstall dependencies. Defaults to ``True``. Does not work with ``allow_install=False``. - config_override (Optional[dict]): Override shared configuration for each dependency. - name (Optional[str]): Optionally only get dependencies with a certain name. - version (Optional[str]): Optionally only get dependencies with certain version. + config_override (dict | None): Override shared configuration for each dependency. + name (str | None): Optionally only get dependencies with a certain name. + version (str | None): Optionally only get dependencies with certain version. allow_install (bool): Set to ``False`` to not allow installing uninstalled specified dependencies. strict (bool): ``True`` requires the dependency to either be installed or install properly. recurse (bool): Set to ``False`` to not recursively install dependencies of dependencies. @@ -1416,13 +1416,13 @@ def uri_map(self) -> dict[str, Url]: def get( self, name: str, version: str, allow_install: bool = True - ) -> Optional["ProjectManager"]: + ) -> "ProjectManager | None": if dependency := self._get(name, version, allow_install=allow_install, checked=set()): return dependency.project return None - def get_dependency_api(self, package_id: str, version: Optional[str] = None) -> DependencyAPI: + def get_dependency_api(self, package_id: str, version: str | None = None) -> DependencyAPI: """ Get a dependency API. If not given version and there are multiple, returns the latest. @@ -1448,8 +1448,8 @@ def get_dependency_api(self, package_id: str, version: Optional[str] = None) -> raise ProjectError(message) def _get_dependency_api_by_package_id( - self, package_id: str, version: Optional[str] = None, attr: str = "package_id" - ) -> Optional[DependencyAPI]: + self, package_id: str, version: str | None = None, attr: str = "package_id" + ) -> DependencyAPI | None: matching = [] # First, only look at local configured packages (to give priority). @@ -1478,8 +1478,8 @@ def _get( name: str, version: str, allow_install: bool = True, - checked: Optional[set] = None, - ) -> Optional[Dependency]: + checked: set | None = None, + ) -> Dependency | None: checked = checked or set() # Check already-installed first to prevent having to install anything. @@ -1678,7 +1678,7 @@ def decode_dependency(self, **item: Any) -> DependencyAPI: f"Keys={', '.join([x for x in item.keys()])}" ) - def add(self, dependency: Union[dict, DependencyAPI]) -> Dependency: + def add(self, dependency: dict | DependencyAPI) -> Dependency: """ Add the dependency API data. This sets up a dependency such that it can be fetched. @@ -1710,7 +1710,7 @@ def add(self, dependency: Union[dict, DependencyAPI]) -> Dependency: f"Failed to add dependency {api.name}@{api.version_id}: {err}" ) from err - def install(self, **dependency: Any) -> Union[Dependency, list[Dependency]]: + def install(self, **dependency: Any) -> Dependency | list[Dependency]: """ Install dependencies. @@ -1741,9 +1741,9 @@ def install(self, **dependency: Any) -> Union[Dependency, list[Dependency]]: def install_dependency( self, - dependency_data: Union[dict, DependencyAPI], + dependency_data: dict | DependencyAPI, use_cache: bool = True, - config_override: Optional[dict] = None, + config_override: dict | None = None, recurse: bool = True, ) -> Dependency: dependency = self.add(dependency_data) @@ -1766,7 +1766,7 @@ def unpack(self, base_path: Path, cache_name: str = ".cache"): dependency.unpack(cache_folder) -def _load_manifest(path: Union[Path, str]) -> PackageManifest: +def _load_manifest(path: Path | str) -> PackageManifest: path = Path(path) return ( PackageManifest.model_validate_json(path.read_text()) @@ -1814,16 +1814,16 @@ def _project(self) -> "LocalProject": @classmethod def from_manifest( cls, - manifest: Union[PackageManifest, Path, str], - config_override: Optional[dict] = None, + manifest: PackageManifest | Path | str, + config_override: dict | None = None, ) -> "Project": """ Create an Ape project using only a manifest. Args: - manifest (Union[PackageManifest, Path, str]): Either a manifest or a + manifest (PackageManifest | Path | str): Either a manifest or a path to a manifest file. - config_override (Optional[Dict]): Optionally provide a config override. + config_override (dict | None): Optionally provide a config override. Returns: :class:`~ape.managers.project.ProjectManifest` @@ -1834,7 +1834,7 @@ def from_manifest( @classmethod def from_python_library( - cls, package_name: str, config_override: Optional[dict] = None + cls, package_name: str, config_override: dict | None = None ) -> "LocalProject": """ Create an Ape project instance from an installed Python package. @@ -1861,7 +1861,7 @@ def from_python_library( @classmethod @contextmanager def create_temporary_project( - cls, config_override: Optional[dict] = None + cls, config_override: dict | None = None ) -> Iterator["LocalProject"]: cls._invalidate_project_dependent_caches() with create_tempdir() as path: @@ -1879,7 +1879,7 @@ class Project(ProjectManager): manifests or local source-paths. """ - def __init__(self, manifest: PackageManifest, config_override: Optional[dict] = None): + def __init__(self, manifest: PackageManifest, config_override: dict | None = None): self._manifest = manifest self._config_override = config_override or {} @@ -1985,10 +1985,10 @@ def temp_config(self, **config): yield self.reconfigure(**existing_overrides) - def get(self, name: str) -> Optional[ContractContainer]: + def get(self, name: str) -> ContractContainer | None: return self.contracts.get(name) - def unpack(self, destination: Path, config_override: Optional[dict] = None) -> "LocalProject": + def unpack(self, destination: Path, config_override: dict | None = None) -> "LocalProject": """ Unpack the project to a location using the information from the manifest. Converts a manifest-based project @@ -2129,7 +2129,7 @@ def sources(self) -> dict[str, Source]: return self.manifest.sources or {} def load_contracts( - self, *source_ids: Union[str, Path], use_cache: bool = True + self, *source_ids: str | Path, use_cache: bool = True ) -> dict[str, ContractContainer]: result = { ct.name: ct @@ -2392,9 +2392,9 @@ class LocalProject(Project): def __init__( self, - path: Union[Path, str], - manifest_path: Optional[Path] = None, - config_override: Optional[dict] = None, + path: Path | str, + manifest_path: Path | None = None, + config_override: dict | None = None, ) -> None: self._session_source_change_check: set[str] = set() @@ -2596,7 +2596,7 @@ def project_api(self) -> ProjectAPI: # If we get here, there are more than 1 project types we should use. return MultiProject(apis=valid_apis, path=self._base_path) - def _get_ape_project_api(self) -> Optional[ApeProject]: + def _get_ape_project_api(self) -> ApeProject | None: if instance := ApeProject.attempt_validate(path=self._base_path): return cast(ApeProject, instance) @@ -2647,7 +2647,7 @@ def deployments(self) -> DeploymentManager: return DeploymentManager(self) @property - def exclusions(self) -> set[Union[str, Pattern]]: + def exclusions(self) -> set[str | Pattern]: """ Source-file exclusion glob patterns. """ @@ -2727,7 +2727,7 @@ def isolate_in_tempdir(self, **config_override) -> Iterator["LocalProject"]: project.manifest.sources = sources yield project - def unpack(self, destination: Path, config_override: Optional[dict] = None) -> "LocalProject": + def unpack(self, destination: Path, config_override: dict | None = None) -> "LocalProject": config_override = {**self._config_override, **(config_override or {})} def copytree(src, dst): @@ -2817,9 +2817,9 @@ def update_manifest(self, **kwargs): def load_contracts( self, - *source_ids: Union[str, Path], + *source_ids: str | Path, use_cache: bool = True, - excluded_compilers: Optional[list[str]] = None, + excluded_compilers: list[str] | None = None, ) -> dict[str, ContractContainer]: paths: Iterable[Path] starting: dict[str, ContractContainer] = {} @@ -2879,7 +2879,7 @@ def clean(self): self.sources._path_cache = None self._clear_cached_config() - def chdir(self, path: Optional[Path] = None): + def chdir(self, path: Path | None = None): """ Change the local project to the new path. @@ -2964,7 +2964,7 @@ def _clear_cached_config(self): if "config" in self.__dict__: del self.__dict__["config"] - def _create_contract_source(self, contract_type: ContractType) -> Optional[ContractSource]: + def _create_contract_source(self, contract_type: ContractType) -> ContractSource | None: if not (source_id := contract_type.source_id): return None @@ -3003,7 +3003,7 @@ def _update_contract_types(self, contract_types: dict[str, ContractType]): def _find_directory_with_extension( path: Path, extensions: set[str], recurse: bool = True -) -> Optional[Path]: +) -> Path | None: if not path.is_dir(): return None diff --git a/src/ape/managers/query.py b/src/ape/managers/query.py index 86cd971b2c..f25ccbc88e 100644 --- a/src/ape/managers/query.py +++ b/src/ape/managers/query.py @@ -3,7 +3,6 @@ from collections.abc import Iterator from functools import cached_property, singledispatchmethod from itertools import tee -from typing import Optional from ape.api.query import ( AccountTransactionQuery, @@ -32,11 +31,11 @@ def __init__(self): self.supports_contract_creation = None @singledispatchmethod - def estimate_query(self, query: QueryType) -> Optional[int]: # type: ignore + def estimate_query(self, query: QueryType) -> int | None: # type: ignore return None # can't handle this query @estimate_query.register - def estimate_block_query(self, query: BlockQuery) -> Optional[int]: + def estimate_block_query(self, query: BlockQuery) -> int | None: # NOTE: Very loose estimate of 100ms per block return (1 + query.stop_block - query.start_block) * 100 @@ -135,12 +134,12 @@ def _suggest_engines(self, engine_selection): def query( self, query: QueryType, - engine_to_use: Optional[str] = None, + engine_to_use: str | None = None, ) -> Iterator[BaseInterfaceModel]: """ Args: query (``QueryType``): The type of query to execute - engine_to_use (Optional[str]): Short-circuit selection logic using + engine_to_use (str | None): Short-circuit selection logic using a specific engine. Defaults is set by performance-based selection logic. Raises: diff --git a/src/ape/plugins/_utils.py b/src/ape/plugins/_utils.py index db2d19cc1a..69317b3199 100644 --- a/src/ape/plugins/_utils.py +++ b/src/ape/plugins/_utils.py @@ -4,7 +4,7 @@ from enum import Enum from functools import cached_property from shutil import which -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse import click @@ -223,7 +223,7 @@ def from_package_names( cls, packages: Iterable[str], include_available: bool = True, - trusted_list: Optional[Iterable] = None, + trusted_list: Iterable | None = None, ) -> "PluginMetadataList": PluginMetadataList.model_rebuild() core = PluginGroup(plugin_type=PluginType.CORE) @@ -260,7 +260,7 @@ def __str__(self) -> str: def to_str( self, - include: Optional[Sequence[PluginType]] = None, + include: Sequence[PluginType] | None = None, include_version: bool = True, output_format: OutputFormat = OutputFormat.DEFAULT, ) -> str: @@ -274,7 +274,7 @@ def all_plugins(self) -> Iterator["PluginMetadata"]: yield from self.installed.plugins.values() yield from self.third_party.plugins.values() - def get_plugin(self, name: str, check_available: bool = True) -> Optional["PluginMetadata"]: + def get_plugin(self, name: str, check_available: bool = True) -> "PluginMetadata | None": name = name if name.startswith("ape_") else f"ape_{name}" if name in self.core.plugins: return self.core.plugins[name] @@ -301,7 +301,7 @@ class PluginMetadata(BaseInterfaceModel): name: str """The name of the plugin, such as ``trezor``.""" - version: Optional[str] = None + version: str | None = None """The version requested, if there is one.""" pip_command: list[str] = PIP_COMMAND @@ -366,7 +366,7 @@ def module_name(self) -> str: return f"ape_{self.name.replace('-', '_')}" @cached_property - def current_version(self) -> Optional[str]: + def current_version(self) -> str | None: """ The version currently installed if there is one. """ @@ -473,7 +473,7 @@ def check_installed(self, use_cache: bool = True) -> bool: return any(n == self.package_name for n in get_plugin_dists()) - def check_trusted(self, use_web: bool = True, trusted_list: Optional[Iterable] = None) -> bool: + def check_trusted(self, use_web: bool = True, trusted_list: Iterable | None = None) -> bool: if use_web: return self.is_available @@ -485,8 +485,8 @@ def check_trusted(self, use_web: bool = True, trusted_list: Optional[Iterable] = def prepare_package_manager_args( self, verb: str, - python_location: Optional[str] = None, - extra_args: Optional[Iterable[str]] = None, + python_location: str | None = None, + extra_args: Iterable[str] | None = None, ) -> list[str]: """ Build command arguments for pip or uv package managers. @@ -514,8 +514,8 @@ def _prepare_install( self, upgrade: bool = False, skip_confirmation: bool = False, - python_location: Optional[str] = None, - ) -> Optional[dict[str, Any]]: + python_location: str | None = None, + ) -> dict[str, Any] | None: # NOTE: Internal and only meant to be called by the CLI. if self.in_core: logger.error(f"Cannot install core 'ape' plugin '{self.name}'.") @@ -568,7 +568,7 @@ def _prepare_install( ) return None - def _get_uninstall_args(self, python_location: Optional[str]) -> list[str]: + def _get_uninstall_args(self, python_location: str | None) -> list[str]: arguments = self.prepare_package_manager_args( verb="uninstall", python_location=python_location ) @@ -640,7 +640,7 @@ def _log_modify_failed(self, verb: str): logger.error(f"Failed to {verb} plugin '{self._plugin}.") -def _split_name_and_version(value: str) -> tuple[str, Optional[str]]: +def _split_name_and_version(value: str) -> tuple[str, str | None]: if "@" in value: parts = [x for x in value.split("@") if x] return parts[0], "@".join(parts[1:]) @@ -692,9 +692,9 @@ def plugin_names(self) -> list[str]: def to_str( self, - max_length: Optional[int] = None, + max_length: int | None = None, include_version: bool = True, - output_format: Optional[OutputFormat] = OutputFormat.DEFAULT, + output_format: OutputFormat | None = OutputFormat.DEFAULT, ) -> str: output_format = output_format or OutputFormat.DEFAULT if output_format in (OutputFormat.DEFAULT, OutputFormat.PREFIXED): @@ -711,7 +711,7 @@ def to_str( def _get_default_formatted_str( self, - max_length: Optional[int] = None, + max_length: int | None = None, include_version: bool = True, include_prefix: bool = False, ) -> str: @@ -736,7 +736,7 @@ def _get_default_formatted_str( def _get_freeze_formatted_str( self, - max_length: Optional[int] = None, + max_length: int | None = None, include_version: bool = True, include_prefix: bool = False, ) -> str: @@ -767,9 +767,9 @@ class ApePluginsRepr: def __init__( self, metadata: PluginMetadataList, - include: Optional[Sequence[PluginType]] = None, + include: Sequence[PluginType] | None = None, include_version: bool = True, - output_format: Optional[OutputFormat] = None, + output_format: OutputFormat | None = None, ): self.include = include or (PluginType.INSTALLED, PluginType.THIRD_PARTY) self.metadata = metadata diff --git a/src/ape/pytest/config.py b/src/ape/pytest/config.py index f27bc37449..d6bb32e6b2 100644 --- a/src/ape/pytest/config.py +++ b/src/ape/pytest/config.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from ape.utils.basemodel import ManagerAccessMixin @@ -78,11 +78,11 @@ def track_coverage(self) -> bool: return self.pytest_config.getoption("--coverage") or self.ape_test_config.coverage.track @property - def xml_coverage(self) -> Union[bool, dict]: + def xml_coverage(self) -> bool | dict: return self.ape_test_config.coverage.reports.xml @property - def html_coverage(self) -> Union[bool, dict]: + def html_coverage(self) -> bool | dict: return self.ape_test_config.coverage.reports.html @cached_property @@ -114,7 +114,7 @@ def gas_exclusions(self) -> list["ContractFunctionPath"]: def coverage_exclusions(self) -> list["ContractFunctionPath"]: return _get_config_exclusions(self.ape_test_config.coverage) - def get_pytest_plugin(self, name: str) -> Optional[Any]: + def get_pytest_plugin(self, name: str) -> Any | None: if self.pytest_config.pluginmanager.has_plugin(name): return self.pytest_config.pluginmanager.get_plugin(name) diff --git a/src/ape/pytest/contextmanagers.py b/src/ape/pytest/contextmanagers.py index 905d1baebb..afc858f927 100644 --- a/src/ape/pytest/contextmanagers.py +++ b/src/ape/pytest/contextmanagers.py @@ -1,13 +1,13 @@ import re from re import Pattern -from typing import Optional, Union +from typing import TYPE_CHECKING from ethpm_types.abi import ErrorABI from ape.exceptions import ContractLogicError, CustomError, TransactionError from ape.utils.basemodel import ManagerAccessMixin -_RevertMessage = Union[str, re.Pattern] +_RevertMessage = str | re.Pattern class RevertInfo: @@ -24,14 +24,14 @@ class RevertInfo: class RevertsContextManager(ManagerAccessMixin): def __init__( self, - expected_message: Optional[Union[_RevertMessage, type[CustomError], ErrorABI]] = None, - dev_message: Optional[_RevertMessage] = None, + expected_message: _RevertMessage | type[CustomError] | ErrorABI | None = None, + dev_message: _RevertMessage | None = None, **error_inputs, ): self.expected_message = expected_message self.dev_message = dev_message self.error_inputs = error_inputs - self.revert_info: Optional[RevertInfo] = None + self.revert_info: RevertInfo | None = None def _check_dev_message(self, exception: ContractLogicError): """ @@ -103,7 +103,7 @@ def _check_expected_message(self, exception: ContractLogicError): raise AssertionError(f"{assertion_error_prefix} but got '{actual}'.") - def _check_custom_error(self, exception: Union[CustomError]): + def _check_custom_error(self, exception: CustomError): # perf: avoid loading from contracts namespace until needed. from ape.contracts import ContractInstance diff --git a/src/ape/pytest/coverage.py b/src/ape/pytest/coverage.py index cf00cc472b..9c65baa2f6 100644 --- a/src/ape/pytest/coverage.py +++ b/src/ape/pytest/coverage.py @@ -1,6 +1,6 @@ from collections.abc import Iterable from pathlib import Path -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable import click @@ -24,13 +24,13 @@ class CoverageData(ManagerAccessMixin): def __init__( self, project: "ProjectManager", - sources: Union[Iterable["ContractSource"], Callable[[], Iterable["ContractSource"]]], + sources: Iterable["ContractSource"] | Callable[[], Iterable["ContractSource"]], ): self.project = project - self._sources: Union[Iterable[ContractSource], Callable[[], Iterable[ContractSource]]] = ( + self._sources: Iterable[ContractSource] | Callable[[], Iterable[ContractSource]] = ( sources ) - self._report: Optional[CoverageReport] = None + self._report: CoverageReport | None = None @property def sources(self) -> list["ContractSource"]: @@ -144,8 +144,8 @@ class CoverageTracker(ManagerAccessMixin): def __init__( self, config_wrapper: "ConfigWrapper", - project: Optional["ProjectManager"] = None, - output_path: Optional[Path] = None, + project: "ProjectManager | None" = None, + output_path: Path | None = None, ): self.config_wrapper = config_wrapper self._project = project or self.local_project @@ -159,10 +159,10 @@ def __init__( self._output_path = Path.cwd() # Data gets initialized lazily (if coverage is needed). - self._data: Optional[CoverageData] = None + self._data: CoverageData | None = None @property - def data(self) -> Optional[CoverageData]: + def data(self) -> CoverageData | None: if not self.enabled: return None @@ -188,8 +188,8 @@ def reset(self): def cover( self, traceback: "SourceTraceback", - contract: Optional[str] = None, - function: Optional[str] = None, + contract: str | None = None, + function: str | None = None, ): """ Track the coverage from the given source traceback. @@ -208,9 +208,9 @@ def cover( if not self.data: return - last_path: Optional[Path] = None + last_path: Path | None = None last_pcs: set[int] = set() - last_call: Optional[str] = None + last_call: str | None = None main_fn = None if (contract and not function) or (function and not contract): @@ -265,9 +265,9 @@ def cover( def _cover( self, control_flow: "ControlFlow", - last_path: Optional[Path] = None, - last_pcs: Optional[set[int]] = None, - last_call: Optional[str] = None, + last_path: Path | None = None, + last_pcs: set[int] | None = None, + last_call: str | None = None, ) -> tuple[set[int], list[str]]: if not self.data or control_flow.source_path is None: return set(), [] diff --git a/src/ape/pytest/fixtures.py b/src/ape/pytest/fixtures.py index e8ef000e63..fea57c0a0d 100644 --- a/src/ape/pytest/fixtures.py +++ b/src/ape/pytest/fixtures.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from fnmatch import fnmatch from functools import cached_property, singledispatchmethod -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar import pytest from eth_utils import to_hex @@ -105,10 +105,10 @@ def cache_fixtures(self, item) -> "FixtureMap": return fixture_map - def get_fixture_scope(self, fixture_name: str) -> Optional[Scope]: + def get_fixture_scope(self, fixture_name: str) -> Scope | None: return self._fixture_name_to_info.get(fixture_name, {}).get("scope") - def is_stateful(self, name: str) -> Optional[bool]: + def is_stateful(self, name: str) -> bool | None: if name in self._stateful_fixtures_cache: # Used `@ape.fixture(chain_isolation=) # Or we already calculated. @@ -159,7 +159,7 @@ def add_fixture_info(self, name: str, **info): **info, } - def _get_cached_fixtures(self, nodeid: str) -> Optional["FixtureMap"]: + def _get_cached_fixtures(self, nodeid: str) -> "FixtureMap | None": return self._nodeid_to_fixture_map.get(nodeid) def needs_rebase(self, new_fixtures: list[str], snapshot: "Snapshot") -> bool: @@ -203,7 +203,7 @@ def rebase(self, scope: Scope, fixtures: "FixtureMap"): if rebase.return_scope is not None: log = f"{log} scope={rebase.return_scope}" - def _get_rebase(self, scope: Scope) -> Optional[FixtureRebase]: + def _get_rebase(self, scope: Scope) -> FixtureRebase | None: # Check for fixtures that are now invalid. For example, imagine a session # fixture comes into play after the module snapshot has been set. # Once we restore the module's state and move to the next module, @@ -239,7 +239,7 @@ def _get_rebase(self, scope: Scope) -> Optional[FixtureRebase]: class FixtureMap(dict[Scope, list[str]]): def __init__(self, item): self._item = item - self._parametrized_names: Optional[list[str]] = None + self._parametrized_names: list[str] | None = None super().__init__( { Scope.SESSION: [], @@ -506,7 +506,7 @@ class Snapshot: scope: Scope """Corresponds to fixture scope.""" - identifier: Optional["SnapshotID"] = None + identifier: "SnapshotID | None" = None """Snapshot ID taken before the peer-fixtures in the same scope.""" fixtures: list = field(default_factory=list) @@ -532,7 +532,7 @@ def __init__(self): } ) - def get_snapshot_id(self, scope: Scope) -> Optional["SnapshotID"]: + def get_snapshot_id(self, scope: Scope) -> "SnapshotID | None": return self[scope].identifier def set_snapshot_id(self, scope: Scope, snapshot_id: "SnapshotID"): @@ -557,7 +557,7 @@ def __init__( self, config_wrapper: "ConfigWrapper", receipt_capture: "ReceiptCapture", - chain_snapshots: Optional[dict] = None, + chain_snapshots: dict | None = None, ): self.config_wrapper = config_wrapper self.receipt_capture = receipt_capture @@ -620,7 +620,7 @@ def set_snapshot(self, scope: Scope): self.snapshots.set_snapshot_id(scope, snapshot_id) @allow_disconnected - def take_snapshot(self) -> Optional["SnapshotID"]: + def take_snapshot(self) -> "SnapshotID | None": try: return self.chain_manager.snapshot() except NotImplementedError: @@ -746,11 +746,11 @@ def clear(self): self.enter_blocks = [] @allow_disconnected - def _get_block_number(self) -> Optional[int]: + def _get_block_number(self) -> int | None: return self.provider.get_block("latest").number def _exclude_from_gas_report( - self, contract_name: str, method_name: Optional[str] = None + self, contract_name: str, method_name: str | None = None ) -> bool: """ Helper method to determine if a certain contract / method combination should be @@ -769,7 +769,7 @@ def _exclude_from_gas_report( return False -def fixture(chain_isolation: Optional[bool], **kwargs): +def fixture(chain_isolation: bool | None, **kwargs): """ A thin-wrapper around ``@pytest.fixture`` with extra capabilities. Set ``chain_isolation`` to ``False`` to signal to Ape that this fixture's diff --git a/src/ape/pytest/gas.py b/src/ape/pytest/gas.py index f9bbd8fdcb..dd8750a271 100644 --- a/src/ape/pytest/gas.py +++ b/src/ape/pytest/gas.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from evm_trace.gas import merge_reports @@ -23,7 +23,7 @@ class GasTracker(ManagerAccessMixin): def __init__(self, config_wrapper: "ConfigWrapper"): self.config_wrapper = config_wrapper - self.session_gas_report: Optional[GasReport] = None + self.session_gas_report: GasReport | None = None @property def enabled(self) -> bool: diff --git a/src/ape/pytest/runners.py b/src/ape/pytest/runners.py index c3bb981a96..7e5f3e0796 100644 --- a/src/ape/pytest/runners.py +++ b/src/ape/pytest/runners.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import click import pytest @@ -28,7 +28,7 @@ def __init__( receipt_capture: "ReceiptCapture", gas_tracker: "GasTracker", coverage_tracker: "CoverageTracker", - fixture_manager: Optional["FixtureManager"] = None, + fixture_manager: "FixtureManager | None" = None, ): self.config_wrapper = config_wrapper self.isolation_manager = isolation_manager @@ -56,7 +56,7 @@ def _provider_context(self) -> "ProviderContextManager": return self.network_manager.parse_network_choice(self.config_wrapper.network) @property - def _coverage_report(self) -> Optional["CoverageReport"]: + def _coverage_report(self) -> "CoverageReport | None": return self.coverage_tracker.data.report if self.coverage_tracker.data else None def pytest_exception_interact(self, report, call): diff --git a/src/ape/types/address.py b/src/ape/types/address.py index 785c0d3680..7a782d1490 100644 --- a/src/ape/types/address.py +++ b/src/ape/types/address.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Annotated, Any, Optional, Union +from typing import TYPE_CHECKING, Annotated, Any from eth_pydantic_types import Address as _Address from eth_pydantic_types import HexBytes20, HexStr20 @@ -10,7 +10,7 @@ from pydantic_core.core_schema import ValidationInfo -RawAddress = Union[str, int, HexStr20, HexBytes20] +RawAddress = str | int | HexStr20 | HexBytes20 """ A raw data-type representation of an address. """ @@ -26,7 +26,7 @@ class _AddressValidator(_Address, ManagerAccessMixin): """ @classmethod - def __eth_pydantic_validate__(cls, value: Any, info: Optional["ValidationInfo"] = None) -> str: + def __eth_pydantic_validate__(cls, value: Any, info: "ValidationInfo | None" = None) -> str: if type(value) in (list, tuple): return cls.conversion_manager.convert(value, list[AddressType]) diff --git a/src/ape/types/basic.py b/src/ape/types/basic.py index b5e54c815a..a18060b43b 100644 --- a/src/ape/types/basic.py +++ b/src/ape/types/basic.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Iterator, Sequence -from typing import Annotated, TypeVar, Union, overload +from typing import Annotated, TypeVar, overload from pydantic import BeforeValidator @@ -26,7 +26,7 @@ def _hex_int_validator(value, info): class _LazySequence(Sequence[_T]): - def __init__(self, generator: Union[Iterator[_T], Callable[[], Iterator[_T]]]): + def __init__(self, generator: Iterator[_T] | Callable[[], Iterator[_T]]): self._generator = generator self.cache: list = [] @@ -36,7 +36,7 @@ def __getitem__(self, index: int) -> _T: ... @overload def __getitem__(self, index: slice) -> Sequence[_T]: ... - def __getitem__(self, index: Union[int, slice]) -> Union[_T, Sequence[_T]]: + def __getitem__(self, index: int | slice) -> _T | Sequence[_T]: if isinstance(index, int): while len(self.cache) <= index: # Catch up the cache. diff --git a/src/ape/types/coverage.py b/src/ape/types/coverage.py index 614e67f579..fafe6d702d 100644 --- a/src/ape/types/coverage.py +++ b/src/ape/types/coverage.py @@ -2,7 +2,7 @@ from datetime import datetime from html.parser import HTMLParser from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from xml.dom.minidom import getDOMImplementation from xml.etree.ElementTree import Element, SubElement, tostring @@ -131,7 +131,7 @@ class CoverageStatement(BaseModel): type of check. """ - location: Optional[SourceLocation] = None + location: SourceLocation | None = None """ The location of the item (line, column, endline, endcolumn). If multiple PCs share an exact location, it is only tracked as one. @@ -147,7 +147,7 @@ class CoverageStatement(BaseModel): The times this node was hit. """ - tag: Optional[str] = None + tag: str | None = None """ An additional tag to mark this statement with. This is useful if the location field is empty. @@ -226,7 +226,7 @@ def model_dump(self, *args, **kwargs) -> dict: return attribs def profile_statement( - self, pc: int, location: Optional[SourceLocation] = None, tag: Optional[str] = None + self, pc: int, location: SourceLocation | None = None, tag: str | None = None ): """ Initialize a statement in the coverage profile with a hit count starting at zero. @@ -359,7 +359,7 @@ def include(self, name: str, full_name: str) -> FunctionCoverage: self.functions.append(func_cov) return func_cov - def get_function(self, full_name: str) -> Optional[FunctionCoverage]: + def get_function(self, full_name: str) -> FunctionCoverage | None: for func in self.functions: if func.full_name == full_name: return func @@ -1014,7 +1014,7 @@ def model_dump(self, *args, **kwargs) -> dict: return attribs - def get_source_coverage(self, source_id: str) -> Optional[ContractSourceCoverage]: + def get_source_coverage(self, source_id: str) -> ContractSourceCoverage | None: for project in self.projects: for src in project.sources: if src.source_id == source_id: diff --git a/src/ape/types/events.py b/src/ape/types/events.py index 63ca107261..8309d0ef5c 100644 --- a/src/ape/types/events.py +++ b/src/ape/types/events.py @@ -1,6 +1,6 @@ from collections.abc import Iterable, Iterator, Sequence from functools import cached_property -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from eth_pydantic_types import HexBytes, HexStr from eth_utils import encode_hex, is_hex, keccak, to_hex @@ -20,7 +20,7 @@ from ape.contracts import ContractEvent -TopicFilter = Sequence[Union[Optional[HexStr], Sequence[Optional[HexStr]]]] +TopicFilter = Sequence[HexStr | None | Sequence[HexStr | None]] class LogFilter(BaseModel): @@ -28,7 +28,7 @@ class LogFilter(BaseModel): events: list[EventABI] = [] topic_filter: TopicFilter = [] start_block: int = 0 - stop_block: Optional[int] = None # Use block height + stop_block: int | None = None # Use block height selectors: dict[str, EventABI] = {} @model_validator(mode="before") @@ -66,9 +66,9 @@ def model_dump(self, *args, **kwargs): @classmethod def from_event( cls, - event: Union[EventABI, "ContractEvent"], - search_topics: Optional[dict[str, Any]] = None, - addresses: Optional[list[AddressType]] = None, + event: "EventABI | ContractEvent", + search_topics: dict[str, Any] | None = None, + addresses: list[AddressType] | None = None, start_block=None, stop_block=None, ): @@ -168,7 +168,7 @@ def __init__(self, *args, **kwargs): log_index: HexInt """The index of the log on the transaction.""" - transaction_index: Optional[HexInt] = None + transaction_index: HexInt | None = None """ The index of the transaction's position when the log was created. Is `None` when from the pending block. @@ -290,7 +290,7 @@ def __eq__(self, other: Any) -> bool: # call __eq__ on parent class return super().__eq__(other) - def get(self, item: str, default: Optional[Any] = None) -> Any: + def get(self, item: str, default: Any | None = None) -> Any: return self.event_arguments.get(item, default) diff --git a/src/ape/types/gas.py b/src/ape/types/gas.py index f7c87244df..be4c323fd7 100644 --- a/src/ape/types/gas.py +++ b/src/ape/types/gas.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Literal from pydantic import BaseModel, field_validator @@ -22,7 +22,7 @@ def validate_multiplier(cls, value): return value -GasLimit = Union[Literal["auto", "max"], int, str, AutoGasLimit] +GasLimit = Literal["auto", "max"] | int | str | AutoGasLimit """ A value you can give to Ape for handling gas-limit calculations. ``"auto"`` refers to automatically figuring out the gas, diff --git a/src/ape/types/private_mempool.py b/src/ape/types/private_mempool.py index 41567a4095..00ab9ef1b9 100644 --- a/src/ape/types/private_mempool.py +++ b/src/ape/types/private_mempool.py @@ -5,7 +5,6 @@ from collections.abc import Iterator from enum import Enum -from typing import Optional, Union from eth_pydantic_types.hex import HexBytes, HexBytes32, HexInt from ethpm_types.abi import EventABI @@ -90,12 +89,12 @@ class Privacy(BaseModel): Preferences on what data should be shared about the bundle and its transactions """ - hints: Optional[list[PrivacyHint]] = None + hints: list[PrivacyHint] | None = None """ Hints on what data should be shared about the bundle and its transactions. """ - builders: Optional[list[str]] = None + builders: list[str] | None = None """ Names of the builders that should be allowed to see the bundle/transaction. """ @@ -111,7 +110,7 @@ class Inclusion(BaseModel): The first block the bundle is valid for. """ - max_block: Union[HexInt, None] = Field(None, alias="maxBlock") + max_block: HexInt | None = Field(None, alias="maxBlock") """ The last block the bundle is valid for. """ @@ -179,13 +178,13 @@ class Validity(BaseModel): Requirements for the bundle to be included in the block. """ - refund: Union[list[Refund], None] = None + refund: list[Refund] | None = None """ Specifies the minimum percent of a given bundle's earnings to redistribute for it to be included in a builder's block. """ - refund_config: Optional[list[RefundConfig]] = Field(None, alias="refundConfig") + refund_config: list[RefundConfig] | None = Field(None, alias="refundConfig") """ Specifies what addresses should receive what percent of the overall refund for this bundle, if it is enveloped by another bundle (e.g. a searcher backrun). @@ -207,17 +206,17 @@ class Bundle(BaseModel): Data used by block builders to check if the bundle should be considered for inclusion. """ - body: list[Union[BundleHashItem, BundleTxItem, BundleNestedItem]] + body: list[BundleHashItem | BundleTxItem | BundleNestedItem] """ The transactions to include in the bundle. """ - validity: Optional[Validity] = None + validity: Validity | None = None """ Requirements for the bundle to be included in the block. """ - privacy: Optional[Privacy] = None + privacy: Privacy | None = None """ Preferences on what data should be shared about the bundle and its transactions """ @@ -226,11 +225,11 @@ class Bundle(BaseModel): def build_for_block( cls, block: HexInt, - max_block: Optional[HexInt] = None, - version: Optional[ProtocolVersion] = None, - body: Optional[list[Union[BundleHashItem, BundleTxItem, BundleNestedItem]]] = None, - validity: Optional[Validity] = None, - privacy: Optional[Privacy] = None, + max_block: HexInt | None = None, + version: ProtocolVersion | None = None, + body: list[BundleHashItem | BundleTxItem | BundleNestedItem] | None = None, + validity: Validity | None = None, + privacy: Privacy | None = None, ) -> "Bundle": return cls( version=version or ProtocolVersion.V0_1, @@ -257,12 +256,12 @@ class SimBundleLogs(BaseModel): Logs returned by `mev_simBundle`. """ - tx_logs: Optional[list[dict]] = Field(None, alias="txLogs") + tx_logs: list[dict] | None = Field(None, alias="txLogs") """ Logs for transactions in bundle. """ - bundle_logs: Optional[list["SimBundleLogs"]] = Field(None, alias="bundleLogs") + bundle_logs: list["SimBundleLogs"] | None = Field(None, alias="bundleLogs") """ Logs for bundles in bundle. """ @@ -278,12 +277,12 @@ class SimulationReport(BaseModel): Whether the simulation was successful. """ - error: Optional[str] = None + error: str | None = None """ Error message if the simulation failed. """ - state_block: Optional[HexInt] = Field(None, alias="stateBlock") + state_block: HexInt | None = Field(None, alias="stateBlock") """ The block number of the simulated block. """ @@ -298,27 +297,27 @@ class SimulationReport(BaseModel): The profit of the simulated block. """ - refundable_value: Optional[HexInt] = Field(None, alias="refundableValue") + refundable_value: HexInt | None = Field(None, alias="refundableValue") """ The refundable value of the simulated block. """ - gas_used: Optional[HexInt] = Field(None, alias="gasUsed") + gas_used: HexInt | None = Field(None, alias="gasUsed") """ The gas used by the simulated block. """ - logs: Optional[list[SimBundleLogs]] = None + logs: list[SimBundleLogs] | None = None """ Logs returned by `mev_simBundle`. """ - exec_error: Optional[str] = Field(None, alias="execError") + exec_error: str | None = Field(None, alias="execError") """ Error message if the bundle execution failed. """ - revert: Optional[HexBytes] = None + revert: HexBytes | None = None """ Contains the return data if the transaction reverted """ diff --git a/src/ape/types/signatures.py b/src/ape/types/signatures.py index b7efe15687..4f22c2c103 100644 --- a/src/ape/types/signatures.py +++ b/src/ape/types/signatures.py @@ -1,5 +1,5 @@ from collections.abc import Iterator -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from eth_account import Account from eth_account.messages import SignableMessage @@ -41,7 +41,7 @@ def signable_message_repr(msg) -> str: SignableMessage.__repr__ = signable_message_repr # type: ignore[method-assign] -def _bytes_to_human_str(bytes_value: bytes) -> Optional[str]: +def _bytes_to_human_str(bytes_value: bytes) -> str | None: try: # Try as text return bytes_value.decode("utf8") @@ -84,7 +84,7 @@ class _Signature: The signature proof point (``s``) in an ECDSA signature. """ - def __iter__(self) -> Iterator[Union[int, bytes]]: + def __iter__(self) -> Iterator[int | bytes]: # NOTE: Allows tuple destructuring yield self.v yield self.r diff --git a/src/ape/types/trace.py b/src/ape/types/trace.py index 1ac708edf9..517f13d4ae 100644 --- a/src/ape/types/trace.py +++ b/src/ape/types/trace.py @@ -1,7 +1,7 @@ from collections.abc import Iterator from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from eth_pydantic_types import HexBytes from ethpm_types import ASTNode, BaseModel @@ -45,7 +45,7 @@ class ControlFlow(BaseModel): The defining closure, such as a function or module, of the code sequence. """ - source_path: Optional[Path] = None + source_path: Path | None = None """ The path to the local contract file. Only exists when is from a local contract. @@ -99,7 +99,7 @@ def source_statements(self) -> list[SourceStatement]: return [x for x in self.statements if isinstance(x, SourceStatement)] @property - def begin_lineno(self) -> Optional[int]: + def begin_lineno(self) -> int | None: """ The first line number in the sequence. """ @@ -107,7 +107,7 @@ def begin_lineno(self) -> Optional[int]: return stmts[0].begin_lineno if stmts else None @property - def ws_begin_lineno(self) -> Optional[int]: + def ws_begin_lineno(self) -> int | None: """ The first line number in the sequence, including whitespace. """ @@ -146,7 +146,7 @@ def source_header(self) -> str: return result.strip() @property - def end_lineno(self) -> Optional[int]: + def end_lineno(self) -> int | None: """ The last line number. """ @@ -164,8 +164,8 @@ def pcs(self) -> set[int]: def extend( self, location: "SourceLocation", - pcs: Optional[set[int]] = None, - ws_start: Optional[int] = None, + pcs: set[int] | None = None, + ws_start: int | None = None, ): """ Extend this node's content with other content that follows it directly. @@ -176,8 +176,8 @@ def extend( Args: location (SourceLocation): The location of the content, in the form (lineno, col_offset, end_lineno, end_coloffset). - pcs (Optional[set[int]]): The PC values of the statements. - ws_start (Optional[int]): Optionally provide a white-space starting point + pcs (set[int] | None): The PC values of the statements. + ws_start (int | None): Optionally provide a white-space starting point to back-fill. """ @@ -254,7 +254,7 @@ def format(self, use_arrow: bool = True) -> str: return content @property - def next_statement(self) -> Optional[SourceStatement]: + def next_statement(self) -> SourceStatement | None: """ Returns the next statement that _would_ execute if the program were to progress to the next line. @@ -296,7 +296,7 @@ class SourceTraceback(RootModel[list[ControlFlow]]): """ @classmethod - def create(cls, contract_source: ContractSource, trace: "TraceAPI", data: Union[HexBytes, str]): + def create(cls, contract_source: ContractSource, trace: "TraceAPI", data: HexBytes | str): # Use the trace as a 'ManagerAccessMixin'. compilers = trace.compiler_manager source_id = contract_source.source_id @@ -333,7 +333,7 @@ def __setitem__(self, key, value): return self.root.__setitem__(key, value) @property - def revert_type(self) -> Optional[str]: + def revert_type(self) -> str | None: """ The revert type, such as a builtin-error code or a user dev-message, if there is one. @@ -356,7 +356,7 @@ def extend(self, __iterable) -> None: self.root.extend(__iterable.root) @property - def last(self) -> Optional[ControlFlow]: + def last(self) -> ControlFlow | None: """ The last control flow in the traceback, if there is one. """ @@ -445,8 +445,8 @@ def add_jump( location: "SourceLocation", function: Function, depth: int, - pcs: Optional[set[int]] = None, - source_path: Optional[Path] = None, + pcs: set[int] | None = None, + source_path: Path | None = None, ): """ Add an execution sequence from a jump. @@ -454,10 +454,10 @@ def add_jump( Args: location (``SourceLocation``): The location to add. function (``Function``): The function executing. - source_path (Optional[``Path``]): The path of the source file. + source_path (Path | None): The path of the source file. depth (int): The depth of the function call in the call tree. - pcs (Optional[set[int]]): The program counter values. - source_path (Optional[``Path``]): The path of the source file. + pcs (set[int] | None): The program counter values. + source_path (Path | None): The path of the source file. """ asts = function.get_content_asts(location) @@ -470,13 +470,13 @@ def add_jump( ControlFlow.model_rebuild() self._add(asts, content, pcs, function, depth, source_path=source_path) - def extend_last(self, location: "SourceLocation", pcs: Optional[set[int]] = None): + def extend_last(self, location: "SourceLocation", pcs: set[int] | None = None): """ Extend the last node with more content. Args: location (``SourceLocation``): The location of the new content. - pcs (Optional[set[int]]): The PC values to add on. + pcs (set[int] | None): The PC values to add on. """ if not self.last: @@ -496,9 +496,9 @@ def add_builtin_jump( self, name: str, _type: str, - full_name: Optional[str] = None, - source_path: Optional[Path] = None, - pcs: Optional[set[int]] = None, + full_name: str | None = None, + source_path: Path | None = None, + pcs: set[int] | None = None, ): """ A convenience method for appending a control flow that happened @@ -508,9 +508,9 @@ def add_builtin_jump( Args: name (str): The name of the compiler built-in. _type (str): A str describing the type of check. - full_name (Optional[str]): A full-name ID. - source_path (Optional[Path]): The source file related, if there is one. - pcs (Optional[set[int]]): Program counter values mapping to this check. + full_name (str | None): A full-name ID. + source_path (Path | None): The source file related, if there is one. + pcs (set[int] | None): Program counter values mapping to this check. """ pcs = pcs or set() closure = Closure(name=name, full_name=full_name or name) @@ -528,7 +528,7 @@ def _add( pcs: set[int], function: Function, depth: int, - source_path: Optional[Path] = None, + source_path: Path | None = None, ): statement = SourceStatement(asts=asts, content=content, pcs=pcs) exec_sequence = ControlFlow( @@ -544,7 +544,7 @@ class ContractFunctionPath: """ contract_name: str - method_name: Optional[str] = None + method_name: str | None = None @classmethod def from_str(cls, value: str) -> "ContractFunctionPath": diff --git a/src/ape/types/units.py b/src/ape/types/units.py index 22a2c1b1e1..e4c6854dfc 100644 --- a/src/ape/types/units.py +++ b/src/ape/types/units.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from pydantic_core.core_schema import ( CoreSchema, @@ -52,7 +52,7 @@ def __get_pydantic_core_schema__(cls, value, handler=None) -> CoreSchema: ) @staticmethod - def _validate(value: Any, info: Optional[ValidationInfo] = None) -> "CurrencyValueComparable": + def _validate(value: Any, info: ValidationInfo | None = None) -> "CurrencyValueComparable": # NOTE: For some reason, for this to work, it has to happen # in an "after" validator, or else it always only `int` type on the model. if value is None: diff --git a/src/ape/types/vm.py b/src/ape/types/vm.py index 3b7181e663..ba71d1f884 100644 --- a/src/ape/types/vm.py +++ b/src/ape/types/vm.py @@ -1,15 +1,15 @@ -from typing import Any, Literal, Union +from typing import Any, Literal from eth_typing import HexStr from hexbytes import HexBytes -BlockID = Union[int, HexStr, HexBytes, Literal["earliest", "latest", "pending"]] +BlockID = int | HexStr | HexBytes | Literal["earliest", "latest", "pending"] """ An ID that can match a block, such as the literals ``"earliest"``, ``"latest"``, or ``"pending"`` as well as a block number or hash (HexBytes). """ -ContractCode = Union[str, bytes, HexBytes] +ContractCode = str | bytes | HexBytes """ A type that represents contract code, which can be represented in string, bytes, or HexBytes. """ diff --git a/src/ape/utils/_github.py b/src/ape/utils/_github.py index a9f291e985..e5cd826f6c 100644 --- a/src/ape/utils/_github.py +++ b/src/ape/utils/_github.py @@ -6,7 +6,7 @@ from collections.abc import Iterator from io import BytesIO from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from requests import HTTPError, Session from requests.adapters import HTTPAdapter @@ -26,7 +26,7 @@ def git(self) -> str: raise ProjectError("`git` not installed.") - def clone(self, url: str, target_path: Optional[Path] = None, branch: Optional[str] = None): + def clone(self, url: str, target_path: Path | None = None, branch: str | None = None): command = [self.git, "-c", "advice.detachedHead=false", "clone", url] if target_path: @@ -68,7 +68,7 @@ class _GithubClient: FRAMEWORK_NAME = "ape" _repo_cache: dict[str, dict] = {} - def __init__(self, session: Optional[Session] = None): + def __init__(self, session: Session | None = None): if session: # NOTE: Mostly allowed for testing purposes. self.__session = session @@ -171,8 +171,8 @@ def clone_repo( self, org_name: str, repo_name: str, - target_path: Union[str, Path], - branch: Optional[str] = None, + target_path: str | Path, + branch: str | None = None, scheme: str = "https", ): repo = self.get_repo(org_name, repo_name) @@ -196,7 +196,7 @@ def clone_repo( self.git.clone(url, branch=branch, target_path=target_path) def download_package( - self, org_name: str, repo_name: str, version: str, target_path: Union[Path, str] + self, org_name: str, repo_name: str, version: str, target_path: Path | str ): target_path = Path(target_path) # Handles str if not target_path or not target_path.is_dir(): @@ -209,7 +209,7 @@ def download_package( ) # Use temporary path to isolate a package when unzipping - with tempfile.TemporaryDirectory() as tmp: + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmp: temp_path = Path(tmp) with zipfile.ZipFile(BytesIO(release_content)) as zf: zf.extractall(temp_path) @@ -223,7 +223,7 @@ def download_package( for source_file in package_path.iterdir(): shutil.move(str(source_file), str(target_path)) - def _get(self, url: str, params: Optional[dict] = None) -> Any: + def _get(self, url: str, params: dict | None = None) -> Any: return self._request("GET", url, params=params) def _request(self, method: str, url: str, **kwargs) -> Any: diff --git a/src/ape/utils/abi.py b/src/ape/utils/abi.py index 4230f1fcc4..b6871fc265 100644 --- a/src/ape/utils/abi.py +++ b/src/ape/utils/abi.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from dataclasses import make_dataclass from enum import Enum -from typing import Any, Optional, Union +from typing import Any from eth_abi import grammar from eth_abi.abi import decode @@ -53,12 +53,12 @@ def read_data_from_stream(self, stream): ) -def is_array(abi_type: Union[str, ABIType]) -> bool: +def is_array(abi_type: str | ABIType) -> bool: """ Returns ``True`` if the given type is a probably an array. Args: - abi_type (Union[str, ABIType]): The type to check. + abi_type (str | ABIType): The type to check. Returns: bool @@ -90,7 +90,7 @@ class StructParser: A utility class responsible for parsing structs out of values. """ - def __init__(self, method_abi: Union[ConstructorABI, MethodABI, EventABI]): + def __init__(self, method_abi: ConstructorABI | MethodABI | EventABI): self.abi = method_abi @property @@ -103,19 +103,19 @@ def default_name(self) -> str: name = self.abi.name if isinstance(self.abi, MethodABI) else "constructor" return f"{name}_return" - def encode_input(self, values: Union[list, tuple, dict]) -> Any: + def encode_input(self, values: list | tuple | dict) -> Any: """ Convert dicts and other objects to struct inputs. Args: - values (Union[list, tuple]): A list of input values. + values (list | tuple | dict): A list of input values. Returns: Any: The same input values only decoded into structs when applicable. """ return [self._encode(ipt, v) for ipt, v in zip(self.abi.inputs, values)] - def decode_input(self, values: Union[Sequence, dict[str, Any]]) -> Any: + def decode_input(self, values: Sequence | dict[str, Any]) -> Any: return ( self._decode(self.abi.inputs, values) if isinstance(self.abi, (EventABI, MethodABI)) @@ -159,14 +159,14 @@ def _encode(self, _type: ABIType, value: Any): return value - def decode_output(self, values: Union[list, tuple]) -> Any: + def decode_output(self, values: list | tuple) -> Any: """ Parse a list of output types and values into structs. Values are only altered when they are a struct. This method also handles structs within structs as well as arrays of structs. Args: - values (Union[list, tuple]): A list of output values. + values (list | tuple): A list of output values. Returns: Any: The same input values only decoded into structs when applicable. @@ -176,8 +176,8 @@ def decode_output(self, values: Union[list, tuple]) -> Any: def _decode( self, - _types: Union[Sequence[ABIType]], - values: Union[Sequence, dict[str, Any]], + _types: Sequence[ABIType], + values: Sequence | dict[str, Any], ): if is_struct(_types): return self._create_struct(_types[0], values) @@ -251,7 +251,7 @@ def _decode( return return_values - def _create_struct(self, out_abi: ABIType, out_value: Any) -> Optional[Any]: + def _create_struct(self, out_abi: ABIType, out_value: Any) -> Any | None: if not out_abi.components or not out_value[0]: # Likely an empty tuple or not a struct. return None @@ -284,7 +284,7 @@ def _parse_components(self, components: list[ABIType], values) -> list: return parsed_values -def is_struct(outputs: Union[ABIType, Sequence[ABIType]]) -> bool: +def is_struct(outputs: ABIType | Sequence[ABIType]) -> bool: """ Returns ``True`` if the given output is a struct. """ @@ -435,7 +435,7 @@ def reduce(struct) -> tuple: return struct_def(*output_values) -def is_dynamic_sized_type(abi_type: Union[ABIType, str]) -> bool: +def is_dynamic_sized_type(abi_type: ABIType | str) -> bool: parsed = grammar.parse(str(abi_type)) return parsed.is_dynamic @@ -455,7 +455,7 @@ def event_name(self): return self.abi.name def decode( - self, topics: list[str], data: Union[str, bytes], use_hex_on_fail: bool = False + self, topics: list[str], data: str | bytes, use_hex_on_fail: bool = False ) -> dict: decoded = {} for abi, topic_value in zip(self.topic_abi_types, topics[1:]): @@ -545,7 +545,7 @@ def _enrich_natspec(natspec: str) -> str: return re.sub(NATSPEC_KEY_PATTERN, replacement, natspec) -def encode_topics(abi: EventABI, topics: Optional[dict[str, Any]] = None) -> list[HexStr]: +def encode_topics(abi: EventABI, topics: dict[str, Any] | None = None) -> list[HexStr]: """ Encode the given topics using the given ABI. Useful for searching logs. diff --git a/src/ape/utils/basemodel.py b/src/ape/utils/basemodel.py index 923960d61f..f8d70af199 100644 --- a/src/ape/utils/basemodel.py +++ b/src/ape/utils/basemodel.py @@ -9,7 +9,7 @@ from importlib import import_module from pathlib import Path from sys import getrecursionlimit -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar from ethpm_types import BaseModel as EthpmTypesBaseModel from pydantic import BaseModel as RootBaseModel @@ -145,7 +145,7 @@ def my_function(self): accounts = self.account_manager # And so on! """ - _test_runner: ClassVar[Optional["PytestApeRunner"]] = None + _test_runner: ClassVar["PytestApeRunner | None"] = None @manager_access def account_manager(cls) -> "AccountManager": @@ -259,7 +259,7 @@ class BaseInterface(ManagerAccessMixin, ABC): """ -def _get_alt(name: str) -> Optional[str]: +def _get_alt(name: str) -> str | None: alt = None if ("-" not in name and "_" not in name) or ("-" in name and "_" in name): alt = None @@ -321,7 +321,7 @@ class ExtraModelAttributes(EthpmTypesBaseModel): we can show a more accurate exception message. """ - attributes: Union[Any, Callable[[], Any], Callable[[str], Any]] + attributes: Any | Callable[[], Any] | Callable[[str], Any] """The attributes. The following types are supported: 1. A model or dictionary to lookup attributes. @@ -337,7 +337,7 @@ class ExtraModelAttributes(EthpmTypesBaseModel): include_getitem: bool = False """Whether to use these in ``__getitem__``.""" - additional_error_message: Optional[str] = None + additional_error_message: str | None = None """ An additional error message to include at the end of the normal IndexError message. @@ -365,7 +365,7 @@ def __contains__(self, name: Any) -> bool: return False - def get(self, name: str) -> Optional[Any]: + def get(self, name: str) -> Any | None: """ Get an attribute. @@ -387,7 +387,7 @@ def get(self, name: str) -> Optional[Any]: return None - def _get(self, name: str) -> Optional[Any]: + def _get(self, name: str) -> Any | None: attrs = self._attrs() return attrs.get(name) if hasattr(attrs, "get") else getattr(attrs, name, None) @@ -415,9 +415,9 @@ class BaseModel(EthpmTypesBaseModel): def model_copy( self: "Model", *, - update: Optional[Mapping[str, Any]] = None, + update: Mapping[str, Any] | None = None, deep: bool = False, - cache_clear: Optional[Sequence[str]] = None, + cache_clear: Sequence[str] | None = None, ) -> "Model": result = super().model_copy(update=update, deep=deep) @@ -637,7 +637,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._path = path - def model_read_file(self, path: Optional[Path] = None) -> dict: + def model_read_file(self, path: Path | None = None) -> dict: """ Get the file's raw data. This is different from ``model_dump()`` because it reads directly from the file without validation. @@ -656,12 +656,12 @@ def _model_read_file(cls, path: Path) -> dict: return {} - def model_dump_file(self, path: Optional[Path] = None, **kwargs): + def model_dump_file(self, path: Path | None = None, **kwargs): """ Save this model to disk. Args: - path (Optional[Path]): Optionally provide the path now + path (Path | None): Optionally provide the path now if one wasn't declared at init time. If given a directory, saves the file in that dir with the name of class with a .json suffix. @@ -679,7 +679,7 @@ def model_validate_file(cls, path: Path, **kwargs): Validate a file. Args: - path (Optional[Path]): Optionally provide the path now + path (Path): Optionally provide the path now if one wasn't declared at init time. **kwargs: Extra kwargs to pass to ``.model_validate_json()``. """ @@ -688,7 +688,7 @@ def model_validate_file(cls, path: Path, **kwargs): model._path = path return model - def _get_path(self, path: Optional[Path] = None) -> Path: + def _get_path(self, path: Path | None = None) -> Path: if save_path := (path or self._path): return save_path diff --git a/src/ape/utils/misc.py b/src/ape/utils/misc.py index 6160e7d0b1..ddc7fef716 100644 --- a/src/ape/utils/misc.py +++ b/src/ape/utils/misc.py @@ -18,7 +18,7 @@ from importlib.metadata import PackageNotFoundError, distributions from importlib.metadata import version as version_metadata from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast import yaml from eth_keys import keys # type: ignore @@ -71,7 +71,7 @@ @functools.cache -def _get_distributions(pkg_name: Optional[str] = None) -> list: +def _get_distributions(pkg_name: str | None = None) -> list: """ Get a mapping of top-level packages to their distributions. """ @@ -87,7 +87,7 @@ def _get_distributions(pkg_name: Optional[str] = None) -> list: return distros -def pragma_str_to_specifier_set(pragma_str: str) -> Optional[SpecifierSet]: +def pragma_str_to_specifier_set(pragma_str: str) -> SpecifierSet | None: """ Convert the given pragma str to a ``packaging.version.SpecifierSet`` if possible. @@ -96,7 +96,7 @@ def pragma_str_to_specifier_set(pragma_str: str) -> Optional[SpecifierSet]: pragma_str (str): The str to convert. Returns: - ``Optional[packaging.version.SpecifierSet]`` + SpecifierSet | None """ pragma_parts = iter([x.strip(" ,") for x in pragma_str.split(" ")]) @@ -266,7 +266,7 @@ def gas_estimation_error_message(tx_error: Exception) -> str: ) -def extract_nested_value(root: Mapping, *args: str) -> Optional[dict]: +def extract_nested_value(root: Mapping, *args: str) -> dict | None: """ Dig through a nested ``dict`` using the given keys and return the last-found object. @@ -472,7 +472,7 @@ def _dict_overlay(mapping: dict[str, Any], overlay: dict[str, Any], depth: int = return mapping -def log_instead_of_fail(default: Optional[Any] = None): +def log_instead_of_fail(default: Any | None = None): """ A decorator for logging errors instead of raising. This is useful for methods like __repr__ which shouldn't fail. @@ -500,7 +500,7 @@ def wrapped(*args, **kwargs): _MOD_T = TypeVar("_MOD_T") -def as_our_module(cls_or_def: _MOD_T, doc_str: Optional[str] = None) -> Optional[_MOD_T]: +def as_our_module(cls_or_def: _MOD_T, doc_str: str | None = None) -> _MOD_T | None: """ Ape sometimes reclaims definitions from other packages, such as class:`~ape.types.signatures.SignableMessage`). When doing so, the doc str diff --git a/src/ape/utils/os.py b/src/ape/utils/os.py index adc94ea2a7..01ffa191dd 100644 --- a/src/ape/utils/os.py +++ b/src/ape/utils/os.py @@ -12,7 +12,7 @@ from pathlib import Path from re import Pattern from tempfile import TemporaryDirectory, gettempdir -from typing import Any, Optional, Union +from typing import Any # TODO: This method is no longer needed since the dropping of 3.9 @@ -64,7 +64,7 @@ def get_relative_path(target: Path, anchor: Path) -> Path: def get_all_files_in_directory( - path: Path, pattern: Optional[Union[Pattern, str]] = None, max_files: Optional[int] = None + path: Path, pattern: Pattern | str | None = None, max_files: int | None = None ) -> list[Path]: """ Returns all the files in a directory structure (recursive). @@ -79,9 +79,9 @@ def get_all_files_in_directory( Args: path (pathlib.Path): A directory containing files of interest. - pattern (Optional[Union[Pattern, str]]): Optionally provide a regex + pattern (Pattern | str | None): Optionally provide a regex pattern to match. - max_files (Optional[int]): Optionally set a max file count. This is useful + max_files (int | None): Optionally set a max file count. This is useful because huge file structures will be very slow. Returns: @@ -92,7 +92,7 @@ def get_all_files_in_directory( elif not path.is_dir(): return [] - pattern_obj: Optional[Pattern] = None + pattern_obj: Pattern | None = None if isinstance(pattern, str): pattern_obj = re.compile(pattern) elif pattern is not None: @@ -132,7 +132,7 @@ class use_temp_sys_path: a user's sys paths without permanently modifying it. """ - def __init__(self, path: Path, exclude: Optional[list[Path]] = None): + def __init__(self, path: Path, exclude: list[Path] | None = None): self.temp_path = str(path) self.exclude = [str(p) for p in exclude or []] @@ -156,7 +156,7 @@ def __exit__(self, *exc): sys.path.append(path) -def get_full_extension(path: Union[Path, str]) -> str: +def get_full_extension(path: Path | str) -> str: """ For a path like ``Path("Contract.t.sol")``, returns ``.t.sol``, unlike the regular Path @@ -187,13 +187,13 @@ def get_full_extension(path: Union[Path, str]) -> str: @contextmanager -def create_tempdir(name: Optional[str] = None) -> Iterator[Path]: +def create_tempdir(name: str | None = None) -> Iterator[Path]: """ Create a temporary directory. Differs from ``TemporaryDirectory()`` context-call alone because it automatically resolves the path. Args: - name (Optional[str]): Optional provide a name of the directory. + name (str | None): Optional provide a name of the directory. Else, defaults to root of ``tempfile.TemporaryDirectory()`` (resolved). @@ -214,7 +214,7 @@ def create_tempdir(name: Optional[str] = None) -> Iterator[Path]: def run_in_tempdir( fn: Callable[[Path], Any], - name: Optional[str] = None, + name: str | None = None, ): """ Run the given function in a temporary directory with its path @@ -223,7 +223,7 @@ def run_in_tempdir( Args: fn (Callable): A function that takes a path. It gets called with the resolved path to the temporary directory. - name (Optional[str]): Optionally name the temporary directory. + name (str | None): Optionally name the temporary directory. Returns: Any: The result of the function call. @@ -247,7 +247,7 @@ def in_tempdir(path: Path) -> bool: return normalized_path.startswith(temp_dir) -def path_match(path: Union[str, Path], *exclusions: str) -> bool: +def path_match(path: str | Path, *exclusions: str) -> bool: """ A better glob-matching function. For example: @@ -325,13 +325,13 @@ def get_package_path(package_name: str) -> Path: return package_path -def extract_archive(archive_file: Path, destination: Optional[Path] = None): +def extract_archive(archive_file: Path, destination: Path | None = None): """ Extract an archive file. Supports ``.zip`` or ``.tar.gz``. Args: archive_file (Path): The file-path to the archive. - destination (Optional[Path]): Optionally provide a destination. + destination (Path | None): Optionally provide a destination. Defaults to the parent directory of the archive file. """ destination = destination or archive_file.parent @@ -477,9 +477,9 @@ def __init__( self, original_path: Path, new_path: Path, - chdir: Optional[Callable[[Path], None]] = None, - on_push: Optional[Callable[[Path], dict]] = None, - on_pop: Optional[Callable[[dict], None]] = None, + chdir: Callable[[Path], None] | None = None, + on_push: Callable[[Path], dict] | None = None, + on_pop: Callable[[dict], None] | None = None, ): self.original_path = original_path self.new_path = new_path diff --git a/src/ape/utils/rpc.py b/src/ape/utils/rpc.py index 1535fa3b90..372fa2879d 100644 --- a/src/ape/utils/rpc.py +++ b/src/ape/utils/rpc.py @@ -1,7 +1,6 @@ import time from collections.abc import Callable from random import randint -from typing import Optional import requests from requests.models import CaseInsensitiveDict @@ -27,7 +26,7 @@ def allow_disconnected(fn: Callable): @allow_disconnected - def try_snapshot(self) -> Optional[SnapshotID]: + def try_snapshot(self) -> "SnapshotID | None": return self.chain.snapshot() """ @@ -102,7 +101,7 @@ def request_with_retry( max_retry_delay: int = 30_000, max_retries: int = 10, retry_jitter: int = 250, - is_rate_limit: Optional[Callable[[Exception], bool]] = None, + is_rate_limit: Callable[[Exception], bool] | None = None, ): """ Make a request with 429/rate-limit retry logic. diff --git a/src/ape/utils/trace.py b/src/ape/utils/trace.py index 34e177ba88..6506ff32cd 100644 --- a/src/ape/utils/trace.py +++ b/src/ape/utils/trace.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from fnmatch import fnmatch from statistics import mean, median -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from eth_utils import is_0x_prefixed, to_hex from rich.box import SIMPLE @@ -20,8 +20,8 @@ def prettify_function( method: str, calldata: Any, - contract: Optional[str] = None, - returndata: Optional[Any] = None, + contract: str | None = None, + returndata: Any | None = None, stylize: bool = False, is_create: bool = False, depth: int = 0, @@ -92,7 +92,7 @@ def prettify_inputs(inputs: Any, stylize: bool = False) -> str: return f"({inputs})" -def _get_outputs_str(outputs: Any, stylize: bool = False, depth: int = 0) -> Optional[str]: +def _get_outputs_str(outputs: Any, stylize: bool = False, depth: int = 0) -> str | None: if outputs in ["0x", None, (), [], {}]: return None @@ -111,7 +111,7 @@ def _get_outputs_str(outputs: Any, stylize: bool = False, depth: int = 0) -> Opt def prettify_list( - ls: Union[list, tuple], + ls: list | tuple, depth: int = 0, indent: int = 2, wrap_threshold: int = DEFAULT_WRAP_THRESHOLD, @@ -170,7 +170,7 @@ def prettify_list( def prettify_dict( dictionary: dict, - color: Optional[str] = None, + color: str | None = None, indent: int = 2, wrap_threshold: int = DEFAULT_WRAP_THRESHOLD, ) -> str: @@ -179,7 +179,7 @@ def prettify_dict( Args: dictionary (dict): The dictionary to prettify. - color (Optional[str]): The color to use for pretty printing. + color (str | None): The color to use for pretty printing. Returns: str @@ -211,7 +211,7 @@ def prettify_dict( return f"{kv_str})" -def _list_to_multiline_str(value: Union[list, tuple], depth: int = 0, indent: int = 2) -> str: +def _list_to_multiline_str(value: list | tuple, depth: int = 0, indent: int = 2) -> str: spacing = indent * " " ls_spacing = spacing * (depth + 1) joined = ",\n".join([f"{ls_spacing}{v}" for v in value]) diff --git a/src/ape_accounts/_cli.py b/src/ape_accounts/_cli.py index 405f4aa6ee..3ca1dbea9e 100644 --- a/src/ape_accounts/_cli.py +++ b/src/ape_accounts/_cli.py @@ -1,6 +1,6 @@ import json from importlib import import_module -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import click from cchecksum import to_checksum_address @@ -139,7 +139,7 @@ def generate(cli_ctx, alias, hide_mnemonic, word_count, custom_hd_path): ) @non_existing_alias_argument() def _import(cli_ctx, alias, import_from_mnemonic, custom_hd_path): - account: Optional[KeyfileAccount] = None + account: KeyfileAccount | None = None def ask_for_passphrase(): return click.prompt( diff --git a/src/ape_accounts/accounts.py b/src/ape_accounts/accounts.py index a0d2b184ca..c4038ebce6 100644 --- a/src/ape_accounts/accounts.py +++ b/src/ape_accounts/accounts.py @@ -4,7 +4,7 @@ from functools import cached_property from os import environ from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import click from eip712.messages import EIP712Message, EIP712Type @@ -84,9 +84,9 @@ def public_key(self) -> "HexBytes": def sign_authorization( self, address: AddressType, - chain_id: Optional[int] = None, - nonce: Optional[int] = None, - ) -> Optional[MessageSignature]: + chain_id: int | None = None, + nonce: int | None = None, + ) -> MessageSignature | None: if chain_id is None: chain_id = self.provider.chain_id @@ -104,7 +104,7 @@ def sign_authorization( s=to_bytes(signed_authorization.s), ) - def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]: + def sign_message(self, msg: Any, **signer_options) -> MessageSignature | None: # Convert str and int to SignableMessage if needed if isinstance(msg, str): msg = encode_defunct(text=msg) @@ -128,7 +128,7 @@ def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature] def sign_transaction( self, txn: "TransactionAPI", **signer_options - ) -> Optional["TransactionAPI"]: + ) -> "TransactionAPI | None": # Signs any transaction that's given to it. # NOTE: Using JSON mode, as only primitive types can be signed. tx_data = txn.model_dump(mode="json", by_alias=True, exclude={"sender"}) @@ -167,7 +167,7 @@ def sign_raw_msghash(self, msghash: HexBytes) -> MessageSignature: s=to_bytes(signed_msg.s), ) - def set_delegate(self, contract: Union[BaseAddress, AddressType, str], **txn_kwargs): + def set_delegate(self, contract: BaseAddress | AddressType | str, **txn_kwargs): contract_address = self.conversion_manager.convert(contract, AddressType) sig = self.sign_authorization(contract_address, nonce=self.nonce + 1) auth = Authorization.from_signature( @@ -212,7 +212,7 @@ class KeyfileAccount(AccountAPI): keyfile_path: Path locked: bool = True __autosign: bool = False - __cached_signer: Optional[ApeSigner] = None + __cached_signer: ApeSigner | None = None @log_instead_of_fail(default="") def __repr__(self) -> str: @@ -253,7 +253,7 @@ def __signer(self) -> ApeSigner: return signer @property - def public_key(self) -> Optional[HexBytes]: + def public_key(self) -> HexBytes | None: keyfile_data = self.keyfile if "public_key" in keyfile_data: return HexBytes(bytes.fromhex(keyfile_data["public_key"])) @@ -267,7 +267,7 @@ def public_key(self) -> Optional[HexBytes]: return public_key - def unlock(self, passphrase: Optional[str] = None): + def unlock(self, passphrase: str | None = None): if not passphrase: # Check if environment variable is available env_variable = f"APE_ACCOUNTS_{self.alias}_PASSPHRASE" @@ -313,9 +313,9 @@ def delete(self): def sign_authorization( self, address: AddressType, - chain_id: Optional[int] = None, - nonce: Optional[int] = None, - ) -> Optional[MessageSignature]: + chain_id: int | None = None, + nonce: int | None = None, + ) -> MessageSignature | None: if chain_id is None: chain_id = self.provider.chain_id @@ -325,7 +325,7 @@ def sign_authorization( return self.__signer.sign_authorization(address=address, chain_id=chain_id, nonce=nonce) - def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]: + def sign_message(self, msg: Any, **signer_options) -> MessageSignature | None: display_msg, msg = _get_signing_message_with_display(msg) if display_msg is None: logger.warning("Unsupported message type, (type=%r, msg=%r)", type(msg), msg) @@ -338,13 +338,13 @@ def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature] def sign_transaction( self, txn: "TransactionAPI", **signer_options - ) -> Optional["TransactionAPI"]: + ) -> "TransactionAPI | None": if not (self.__autosign or click.confirm(f"{txn}\n\nSign: ")): return None return self.__signer.sign_transaction(txn, **signer_options) - def sign_raw_msghash(self, msghash: HexBytes) -> Optional[MessageSignature]: + def sign_raw_msghash(self, msghash: HexBytes) -> MessageSignature | None: logger.warning( "Signing a raw hash directly is a dangerous action which could risk " "substantial losses! Only confirm if you are 100% sure of the origin!" @@ -360,13 +360,13 @@ def sign_raw_msghash(self, msghash: HexBytes) -> Optional[MessageSignature]: warnings.simplefilter("ignore") return self.__signer.sign_raw_msghash(msghash) - def set_autosign(self, enabled: bool, passphrase: Optional[str] = None): + def set_autosign(self, enabled: bool, passphrase: str | None = None): """ Allow this account to automatically sign messages and transactions. Args: enabled (bool): ``True`` to enable, ``False`` to disable. - passphrase (Optional[str]): Optionally provide the passphrase. + passphrase (str | None): Optionally provide the passphrase. If not provided, you will be prompted to enter it. """ if enabled: @@ -379,7 +379,7 @@ def set_autosign(self, enabled: bool, passphrase: Optional[str] = None): self.locked = True self.__cached_signer = None - def _prompt_for_passphrase(self, message: Optional[str] = None, **kwargs) -> str: + def _prompt_for_passphrase(self, message: str | None = None, **kwargs) -> str: message = message or f"Enter passphrase to unlock '{self.alias}'" return click.prompt( message, @@ -393,7 +393,7 @@ def __decrypt_keyfile(self, passphrase: str) -> bytes: except ValueError as err: raise InvalidPasswordError() from err - def set_delegate(self, contract: Union[BaseAddress, AddressType, str], **txn_kwargs): + def set_delegate(self, contract: BaseAddress | AddressType | str, **txn_kwargs): return self.__signer.set_delegate(contract, **txn_kwargs) def remove_delegate(self, **txn_kwargs): @@ -488,7 +488,7 @@ def import_account_from_private_key( # Abstracted to make testing easier. -def _get_signing_message_with_display(msg) -> tuple[Optional[str], Any]: +def _get_signing_message_with_display(msg) -> tuple[str | None, Any]: display_msg = None if isinstance(msg, str): diff --git a/src/ape_cache/query.py b/src/ape_cache/query.py index 6f7e0ac036..d47052d0bc 100644 --- a/src/ape_cache/query.py +++ b/src/ape_cache/query.py @@ -1,7 +1,7 @@ from collections.abc import Iterator from functools import singledispatchmethod from pathlib import Path -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import create_engine, func from sqlalchemy.engine import CursorResult @@ -124,7 +124,7 @@ def database_connection(self): or if the database has not been initialized. Returns: - Optional[`sqlalchemy.engine.Connection`] + `sqlalchemy.engine.Connection` | None """ if self.provider.network.is_local: return None @@ -208,7 +208,7 @@ def _contract_events_estimate_query_clause(self, query: ContractEventQuery) -> S ) @singledispatchmethod - def _compute_estimate(self, query: QueryType, result: CursorResult) -> Optional[int]: + def _compute_estimate(self, query: QueryType, result: CursorResult) -> int | None: """ A singledispatchemethod that computes the time a query will take to perform from the caching database @@ -221,7 +221,7 @@ def _compute_estimate_block_query( self, query: BlockQuery, result: CursorResult, - ) -> Optional[int]: + ) -> int | None: if result.scalar() == (1 + query.stop_block - query.start_block) // query.step: # NOTE: Assume 200 msec to get data from database return 200 @@ -235,7 +235,7 @@ def _compute_estimate_block_transaction_query( self, query: BlockTransactionQuery, result: CursorResult, - ) -> Optional[int]: + ) -> int | None: # TODO: Update `transactions` table schema so this query functions properly # Uncomment below after https://github.com/ApeWorX/ape/issues/994 # if result.scalar() > 0: # type: ignore @@ -250,7 +250,7 @@ def _compute_estimate_contract_events_query( self, query: ContractEventQuery, result: CursorResult, - ) -> Optional[int]: + ) -> int | None: if result.scalar() == (query.stop_block - query.start_block) // query.step: # NOTE: Assume 200 msec to get data from database return 200 @@ -259,7 +259,7 @@ def _compute_estimate_contract_events_query( # TODO: Allow partial queries return None - def estimate_query(self, query: QueryType) -> Optional[int]: + def estimate_query(self, query: QueryType) -> int | None: """ Method called by the client to return a query time estimate. @@ -268,7 +268,7 @@ def estimate_query(self, query: QueryType) -> Optional[int]: check of the number of rows that match the clause. Returns: - Optional[int] + int | None """ # NOTE: Because of Python shortcircuiting, the first time `database_connection` is missing @@ -404,7 +404,7 @@ def _cache_update_events_clause(self, query: ContractEventQuery) -> Insert: @singledispatchmethod def _get_cache_data( self, query: QueryType, result: Iterator[BaseInterfaceModel] - ) -> Optional[list[dict[str, Any]]]: + ) -> list[dict[str, Any]] | None: raise QueryEngineError( """ Not a compatible QueryType. For more details see our docs @@ -415,13 +415,13 @@ def _get_cache_data( @_get_cache_data.register def _get_block_cache_data( self, query: BlockQuery, result: Iterator[BaseInterfaceModel] - ) -> Optional[list[dict[str, Any]]]: + ) -> list[dict[str, Any]] | None: return [m.model_dump(mode="json", by_alias=False) for m in result] @_get_cache_data.register def _get_block_txns_data( self, query: BlockTransactionQuery, result: Iterator[BaseInterfaceModel] - ) -> Optional[list[dict[str, Any]]]: + ) -> list[dict[str, Any]] | None: new_result = [] table_columns = [c.key for c in Transactions.__table__.columns] # type: ignore txns: list[TransactionAPI] = cast(list[TransactionAPI], result) @@ -452,7 +452,7 @@ def _get_block_txns_data( @_get_cache_data.register def _get_cache_events_data( self, query: ContractEventQuery, result: Iterator[BaseInterfaceModel] - ) -> Optional[list[dict[str, Any]]]: + ) -> list[dict[str, Any]] | None: return [m.model_dump(mode="json", by_alias=False) for m in result] def update_cache(self, query: QueryType, result: Iterator[BaseInterfaceModel]): diff --git a/src/ape_compile/config.py b/src/ape_compile/config.py index ef0c749f5f..6f12ca01ce 100644 --- a/src/ape_compile/config.py +++ b/src/ape_compile/config.py @@ -1,6 +1,5 @@ import re from re import Pattern -from typing import Union from pydantic import field_serializer, field_validator from pydantic_settings import SettingsConfigDict @@ -28,7 +27,7 @@ class Config(PluginConfig): Configure general compiler settings. """ - exclude: set[Union[str, Pattern]] = set() + exclude: set[str | Pattern] = set() """ Source exclusion globs or regex patterns across all file types. To use regex, start your values with ``r"`` and they'll be turned diff --git a/src/ape_console/_cli.py b/src/ape_console/_cli.py index 964c85262a..f2132b11e1 100644 --- a/src/ape_console/_cli.py +++ b/src/ape_console/_cli.py @@ -9,7 +9,7 @@ from os import environ from pathlib import Path from types import ModuleType -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, cast import click @@ -131,7 +131,7 @@ def _local_extras(self) -> dict: def _global_extras(self) -> dict: return self._load_extras_file(self._global_path) - def get(self, key: str, default: Optional[Any] = None): + def get(self, key: str, default: Any | None = None): try: return self.__getitem__(key) except KeyError: @@ -167,11 +167,11 @@ def _load_extras_file(self, extras_file: Path) -> dict: def console( - project: Optional[Union["ProjectManager", Path]] = None, + project: "ProjectManager | Path | None" = None, verbose: bool = False, - extra_locals: Optional[dict] = None, + extra_locals: dict | None = None, embed: bool = False, - code: Optional[list[str]] = None, + code: list[str] | None = None, ): import IPython from IPython.terminal.ipapp import Config as IPythonConfig @@ -228,7 +228,7 @@ def _launch_console( ipy_config: "IPythonConfig", embed: bool, banner: str, - code: Optional[list[str]], + code: list[str] | None, ): import IPython diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index 6c64dbe7c5..7d44d95460 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -2,7 +2,7 @@ from collections.abc import Iterator, Sequence from decimal import Decimal from functools import cached_property -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast import rlp # type: ignore from cchecksum import to_checksum_address @@ -106,7 +106,7 @@ class NetworkConfig(PluginConfig): considering a transaction 'confirmed'. """ - default_provider: Optional[str] = "node" + default_provider: str | None = "node" """ The default provider to use. If set to ``None``, ape will rely on an external plugin supplying the provider implementation, such as @@ -149,7 +149,7 @@ class NetworkConfig(PluginConfig): base_fee_multiplier: float = 1.0 """A multiplier to apply to a transaction base fee.""" - is_mainnet: Optional[bool] = None + is_mainnet: bool | None = None """ Set to ``True`` to declare as a mainnet or ``False`` to ensure it isn't detected as one. @@ -190,14 +190,14 @@ def validate_gas_limit(cls, value): class ForkedNetworkConfig(NetworkConfig): - upstream_provider: Optional[str] = None + upstream_provider: str | None = None """ The provider to use as the upstream-provider for this forked network. """ def create_local_network_config( - default_provider: Optional[str] = None, use_fork: bool = False, **kwargs + default_provider: str | None = None, use_fork: bool = False, **kwargs ): if "gas_limit" not in kwargs: kwargs["gas_limit"] = "max" @@ -308,7 +308,7 @@ def __contains__(self, key: str) -> bool: return super().__contains__(key) - def get(self, key: str, default: Optional[Any] = None) -> Any: + def get(self, key: str, default: Any | None = None) -> Any: net_key = key.replace("-", "_") if net_key.endswith("_fork"): if cfg := self._get_forked_config(net_key): @@ -324,7 +324,7 @@ def get(self, key: str, default: Optional[Any] = None) -> Any: except AttributeError: return default - def _get_forked_config(self, name: str) -> Optional[ForkedNetworkConfig]: + def _get_forked_config(self, name: str) -> ForkedNetworkConfig | None: live_key: str = name.replace("_fork", "") if self._forked_configs.get(live_key): return self._forked_configs[live_key] @@ -365,7 +365,7 @@ class Block(BlockAPI): uncles: list[HexBytes] = [] # Type re-declares. - hash: Optional[HexBytes] = None + hash: HexBytes | None = None parent_hash: HexBytes = Field( default=EMPTY_BYTES32, alias="parentHash" ) # NOTE: genesis block has no parent hash @@ -469,7 +469,7 @@ def encode_contract_blueprint( deploy_bytecode, contract_type.constructor, **converted_kwargs ) - def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfo]: + def get_proxy_info(self, address: AddressType) -> ProxyInfo | None: contract_code = self.chain_manager.get_code(address) if isinstance(contract_code, bytes): contract_code = to_hex(contract_code) @@ -644,7 +644,7 @@ def decode_block(self, data: dict) -> BlockAPI: return Block.model_validate(data) - def _python_type_for_abi_type(self, abi_type: ABIType) -> Union[type, Sequence]: + def _python_type_for_abi_type(self, abi_type: ABIType) -> type | Sequence: # NOTE: An array can be an array of tuples, so we start with an array check if str(abi_type.type).endswith("]"): # remove one layer of the potential onion of array @@ -681,7 +681,7 @@ def _python_type_for_abi_type(self, abi_type: ABIType) -> Union[type, Sequence]: raise ConversionError(f"Unable to convert '{abi_type}'.") - def encode_calldata(self, abi: Union[ConstructorABI, MethodABI], *args) -> HexBytes: + def encode_calldata(self, abi: ConstructorABI | MethodABI, *args) -> HexBytes: if not abi.inputs: return HexBytes("") @@ -693,7 +693,7 @@ def encode_calldata(self, abi: Union[ConstructorABI, MethodABI], *args) -> HexBy encoded_calldata = encode(input_types, converted_args) return HexBytes(encoded_calldata) - def decode_calldata(self, abi: Union[ConstructorABI, MethodABI], calldata: bytes) -> dict: + def decode_calldata(self, abi: ConstructorABI | MethodABI, calldata: bytes) -> dict: raw_input_types = [i.canonical_type for i in abi.inputs] input_types = [parse_type(i.model_dump()) for i in abi.inputs] @@ -802,8 +802,8 @@ def _enrich_value(self, value: Any, **kwargs) -> Any: return value def decode_primitive_value( - self, value: Any, output_type: Union[str, tuple, list] - ) -> Union[str, HexBytes, int, tuple, list]: + self, value: Any, output_type: str | tuple | list + ) -> str | HexBytes | int | tuple | list: if output_type == "address": try: return self.decode_address(value) @@ -989,7 +989,7 @@ def decode_logs(self, logs: Sequence[dict], *events: EventABI) -> Iterator[Contr encode_hex(keccak(text=abi.selector)): LogInputABICollection(abi) for abi in events } - def get_abi(_topic: HexStr) -> Optional[LogInputABICollection]: + def get_abi(_topic: HexStr) -> LogInputABICollection | None: return abi_inputs[_topic] if _topic in abi_inputs else None for log in logs: @@ -1160,7 +1160,7 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict: if events := call.get("events"): call["events"] = self._enrich_trace_events(events, address=address, **kwargs) - method_abi: Optional[Union[MethodABI, ConstructorABI]] = None + method_abi: MethodABI | ConstructorABI | None = None if is_create: method_abi = contract_type.constructor name = "__new__" @@ -1249,7 +1249,7 @@ def _enrich_contract_id(self, address: AddressType, **kwargs) -> str: def _enrich_calldata( self, call: dict, - method_abi: Union[MethodABI, ConstructorABI], + method_abi: MethodABI | ConstructorABI, **kwargs, ) -> dict: calldata = call["calldata"] @@ -1362,7 +1362,7 @@ def _enrich_returndata(self, call: dict, method_abi: MethodABI, **kwargs) -> dic def _enrich_trace_events( self, events: list[dict], - address: Optional[AddressType] = None, + address: AddressType | None = None, **kwargs, ) -> list[dict]: return [self._enrich_trace_event(e, address=address, **kwargs) for e in events] @@ -1370,7 +1370,7 @@ def _enrich_trace_events( def _enrich_trace_event( self, event: dict, - address: Optional[AddressType] = None, + address: AddressType | None = None, **kwargs, ) -> dict: if "topics" not in event or len(event["topics"]) < 1: @@ -1434,7 +1434,7 @@ def _enrich_revert_message(self, call: dict) -> dict: def _get_contract_type_for_enrichment( self, address: AddressType, **kwargs - ) -> Optional["ContractType"]: + ) -> "ContractType | None": if not (contract_type := kwargs.get("contract_type")): try: contract_type = self.chain_manager.contracts.get(address) @@ -1443,7 +1443,7 @@ def _get_contract_type_for_enrichment( return contract_type - def get_python_types(self, abi_type: ABIType) -> Union[type, Sequence]: + def get_python_types(self, abi_type: ABIType) -> type | Sequence: return self._python_type_for_abi_type(abi_type) def decode_custom_error( @@ -1451,7 +1451,7 @@ def decode_custom_error( data: HexBytes, address: AddressType, **kwargs, - ) -> Optional[CustomError]: + ) -> CustomError | None: # Use an instance (required for proper error caching). try: contract = self.chain_manager.contracts.instance_at(address) @@ -1519,7 +1519,7 @@ def get_deployment_address(self, address: AddressType, nonce: int) -> AddressTyp return self.decode_address(address_bytes) -def parse_type(type_: dict[str, Any]) -> Union[str, tuple, list]: +def parse_type(type_: dict[str, Any]) -> str | tuple | list: if "tuple" not in type_["type"]: return type_["type"] diff --git a/src/ape_ethereum/multicall/handlers.py b/src/ape_ethereum/multicall/handlers.py index b0b5602af2..d2c3493cc5 100644 --- a/src/ape_ethereum/multicall/handlers.py +++ b/src/ape_ethereum/multicall/handlers.py @@ -1,7 +1,7 @@ from collections.abc import Iterator from functools import cached_property from types import ModuleType -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from ethpm_types import ContractType @@ -36,7 +36,7 @@ class BaseMulticall(ManagerAccessMixin): def __init__( self, address: "AddressType" = MULTICALL3_ADDRESS, - supported_chains: Optional[list[int]] = None, + supported_chains: list[int] | None = None, ) -> None: """ Initialize a new Multicall session object. By default, there are no calls to make. @@ -164,12 +164,12 @@ class Call(BaseMulticall): def __init__( self, address: "AddressType" = MULTICALL3_ADDRESS, - supported_chains: Optional[list[int]] = None, + supported_chains: list[int] | None = None, ) -> None: super().__init__(address=address, supported_chains=supported_chains) self.abis: list[MethodABI] = [] - self._result: Union[None, list[tuple[bool, HexBytes]]] = None + self._result: list[tuple[bool, HexBytes]] | None = None @property def handler(self) -> ContractCallHandler: # type: ignore[override] diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index 13d9cbbd6c..1e36e7b8f5 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -9,7 +9,7 @@ from copy import copy from functools import cached_property, wraps from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Callable, cast import ijson # type: ignore import requests @@ -132,16 +132,16 @@ class Web3Provider(ProviderAPI, ABC): `web3.py `__ python package. """ - _web3: Optional[Web3] = None - _client_version: Optional[str] = None + _web3: Web3 | None = None + _client_version: str | None = None - _call_trace_approach: Optional[TraceApproach] = None + _call_trace_approach: TraceApproach | None = None """ Is ``None`` until known. NOTE: This gets set in `ape_ethereum.trace.Trace`. """ - _supports_debug_trace_call: Optional[bool] = None + _supports_debug_trace_call: bool | None = None _transaction_trace_cache: dict[str, TransactionTrace] = {} @@ -199,7 +199,7 @@ def _network_config(self) -> dict: return (config or {}).get(self.network.name) or {} - def _get_configured_rpc(self, key: str, validator: Callable[[str], bool]) -> Optional[str]: + def _get_configured_rpc(self, key: str, validator: Callable[[str], bool]) -> str | None: # key = "uri", "http_uri", "ws_uri", or "ipc_path" settings = self.settings # Includes self.provider_settings and top-level config. result = None @@ -223,19 +223,19 @@ def _get_configured_rpc(self, key: str, validator: Callable[[str], bool]) -> Opt return None @property - def _configured_http_uri(self) -> Optional[str]: + def _configured_http_uri(self) -> str | None: return self._get_configured_rpc("http_uri", _is_http_url) @property - def _configured_ws_uri(self) -> Optional[str]: + def _configured_ws_uri(self) -> str | None: return self._get_configured_rpc("ws_uri", _is_ws_url) @property - def _configured_ipc_path(self) -> Optional[str]: + def _configured_ipc_path(self) -> str | None: return self._get_configured_rpc("ipc_path", _is_ipc_path) @property - def _configured_uri(self) -> Optional[str]: + def _configured_uri(self) -> str | None: for key in ("uri", "url", "ipc_path", "http_uri", "ws_uri"): if rpc := self._get_configured_rpc(key, _is_uri): return rpc @@ -243,7 +243,7 @@ def _configured_uri(self) -> Optional[str]: return None @property - def _configured_rpc(self) -> Optional[str]: + def _configured_rpc(self) -> str | None: """ First of URI, HTTP_URI, WS_URI, IPC_PATH found in the provider_settings or config. @@ -270,7 +270,7 @@ def _configured_rpc(self) -> Optional[str]: return None - def _get_connected_rpc(self, validator: Callable[[str], bool]) -> Optional[str]: + def _get_connected_rpc(self, validator: Callable[[str], bool]) -> str | None: """ The connected HTTP URI. If using providers like `ape-node`, configure your URI and that will @@ -284,19 +284,19 @@ def _get_connected_rpc(self, validator: Callable[[str], bool]) -> Optional[str]: return None @property - def _connected_http_uri(self) -> Optional[str]: + def _connected_http_uri(self) -> str | None: return self._get_connected_rpc(_is_http_url) @property - def _connected_ws_uri(self) -> Optional[str]: + def _connected_ws_uri(self) -> str | None: return self._get_connected_rpc(_is_ws_url) @property - def _connected_ipc_path(self) -> Optional[str]: + def _connected_ipc_path(self) -> str | None: return self._get_connected_rpc(_is_ipc_path) @property - def _connected_uri(self) -> Optional[str]: + def _connected_uri(self) -> str | None: return self._get_connected_rpc(_is_uri) @property @@ -336,7 +336,7 @@ def network_choice(self) -> str: return super().network_choice @property - def http_uri(self) -> Optional[str]: + def http_uri(self) -> str | None: if rpc := self._connected_http_uri: return rpc @@ -351,7 +351,7 @@ def http_uri(self) -> Optional[str]: return self._default_http_uri @property - def _default_http_uri(self) -> Optional[str]: + def _default_http_uri(self) -> str | None: if self.network.is_dev: # Nothing is configured and we are running geth --dev. # Use a default localhost value. @@ -377,7 +377,7 @@ def _default_http_uri(self) -> Optional[str]: return None @property - def ws_uri(self) -> Optional[str]: + def ws_uri(self) -> str | None: if rpc := self._connected_ws_uri: return rpc @@ -393,7 +393,7 @@ def ws_uri(self) -> Optional[str]: return None @property - def ipc_path(self) -> Optional[Path]: + def ipc_path(self) -> Path | None: if rpc := self._configured_ipc_path: # "ipc_path" found in config/settings return Path(rpc) @@ -405,7 +405,7 @@ def ipc_path(self) -> Optional[Path]: return None - def _get_random_rpc(self) -> Optional[str]: + def _get_random_rpc(self) -> str | None: if self.network.is_dev: return None @@ -468,7 +468,7 @@ def base_fee(self) -> int: return self._get_last_base_fee() @property - def call_trace_approach(self) -> Optional[TraceApproach]: + def call_trace_approach(self) -> TraceApproach | None: """ The default tracing approach to use when building up a call-tree. By default, Ape attempts to use the faster approach. Meaning, if @@ -523,7 +523,7 @@ def update_settings(self, new_settings: dict): self.provider_settings.update(new_settings) self.connect() - def estimate_gas_cost(self, txn: TransactionAPI, block_id: Optional["BlockID"] = None) -> int: + def estimate_gas_cost(self, txn: TransactionAPI, block_id: "BlockID | None" = None) -> int: # NOTE: Using JSON mode since used as request data. txn_dict = txn.model_dump(by_alias=True, mode="json") @@ -646,19 +646,19 @@ def _get_latest_block(self) -> BlockAPI: def _get_latest_block_rpc(self) -> dict: return self.make_request("eth_getBlockByNumber", ["latest", False]) - def get_nonce(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: + def get_nonce(self, address: "AddressType", block_id: "BlockID | None" = None) -> int: return self.web3.eth.get_transaction_count(address, block_identifier=block_id) - def get_balance(self, address: "AddressType", block_id: Optional["BlockID"] = None) -> int: + def get_balance(self, address: "AddressType", block_id: "BlockID | None" = None) -> int: return self.web3.eth.get_balance(address, block_identifier=block_id) def get_code( - self, address: "AddressType", block_id: Optional["BlockID"] = None + self, address: "AddressType", block_id: "BlockID | None" = None ) -> "ContractCode": return self.web3.eth.get_code(address, block_identifier=block_id) def get_storage( - self, address: "AddressType", slot: int, block_id: Optional["BlockID"] = None + self, address: "AddressType", slot: int, block_id: "BlockID | None" = None ) -> HexBytes: try: return HexBytes(self.web3.eth.get_storage_at(address, slot, block_identifier=block_id)) @@ -682,8 +682,8 @@ def get_transaction_trace(self, transaction_hash: str, **kwargs) -> "TraceAPI": def send_call( self, txn: TransactionAPI, - block_id: Optional["BlockID"] = None, - state: Optional[dict] = None, + block_id: "BlockID | None" = None, + state: dict | None = None, **kwargs: Any, ) -> HexBytes: if block_id is not None: @@ -805,7 +805,7 @@ def _eth_call( return HexBytes(result) - def _prepare_call(self, txn: Union[dict, TransactionAPI], **kwargs) -> list: + def _prepare_call(self, txn: dict | TransactionAPI, **kwargs) -> list: # NOTE: Using mode="json" because used in request data. txn_dict = ( txn.model_dump(by_alias=True, mode="json") if isinstance(txn, TransactionAPI) else txn @@ -838,7 +838,7 @@ def get_receipt( self, txn_hash: str, required_confirmations: int = 0, - timeout: Optional[int] = None, + timeout: int | None = None, **kwargs, ) -> ReceiptAPI: if required_confirmations < 0: @@ -1000,9 +1000,9 @@ def _find_txn_by_account_and_nonce( def poll_blocks( self, - stop_block: Optional[int] = None, - required_confirmations: Optional[int] = None, - new_block_timeout: Optional[int] = None, + stop_block: int | None = None, + required_confirmations: int | None = None, + new_block_timeout: int | None = None, ) -> Iterator[BlockAPI]: # Wait half the time as the block time # to get data faster. @@ -1105,12 +1105,12 @@ def assert_chain_activity(): def poll_logs( self, - stop_block: Optional[int] = None, - address: Optional["AddressType"] = None, - topics: Optional[list[Union[str, list[str]]]] = None, - required_confirmations: Optional[int] = None, - new_block_timeout: Optional[int] = None, - events: Optional[list["EventABI"]] = None, + stop_block: int | None = None, + address: "AddressType | None" = None, + topics: list[str | list[str]] | None = None, + required_confirmations: int | None = None, + new_block_timeout: int | None = None, + events: list["EventABI"] | None = None, ) -> Iterator[ContractLog]: events = events or [] if required_confirmations is None: @@ -1137,7 +1137,7 @@ def poll_logs( log_filter = LogFilter(**log_params) yield from self.get_contract_logs(log_filter) - def block_ranges(self, start: int = 0, stop: Optional[int] = None, page: Optional[int] = None): + def block_ranges(self, start: int = 0, stop: int | None = None, page: int | None = None): if stop is None: stop = self.chain_manager.blocks.height if page is None: @@ -1351,10 +1351,10 @@ def _post_connect(self): network_key=self.network.name, ) - def make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any: + def make_request(self, rpc: str, parameters: Iterable | None = None) -> Any: return request_with_retry(lambda: self._make_request(rpc, parameters=parameters)) - def _make_request(self, rpc: str, parameters: Optional[Iterable] = None) -> Any: + def _make_request(self, rpc: str, parameters: Iterable | None = None) -> Any: parameters = parameters or [] try: result = self.web3.provider.make_request(RPCEndpoint(rpc), parameters) @@ -1413,7 +1413,7 @@ def stream_request(self, method: str, params: Iterable, iter_path: str = "result del results[:] def create_access_list( - self, transaction: TransactionAPI, block_id: Optional["BlockID"] = None + self, transaction: TransactionAPI, block_id: "BlockID | None" = None ) -> list[AccessList]: """ Get the access list for a transaction use ``eth_createAccessList``. @@ -1489,12 +1489,12 @@ def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMa def _handle_execution_reverted( self, - exception: Union[Exception, str], - txn: Optional[TransactionAPI] = None, + exception: Exception | str, + txn: TransactionAPI | None = None, trace: _TRACE_ARG = None, - contract_address: Optional["AddressType"] = None, + contract_address: "AddressType | None" = None, source_traceback: _SOURCE_TRACEBACK_ARG = None, - set_ape_traceback: Optional[bool] = None, + set_ape_traceback: bool | None = None, ) -> ContractLogicError: if hasattr(exception, "args") and len(exception.args) == 2: message = exception.args[0].replace("execution reverted: ", "") @@ -1553,7 +1553,7 @@ def _handle_execution_reverted( # Abstracted for unit-testing. -def _get_trace_from_revert_kwargs(**kwargs) -> Optional["TraceAPI"]: +def _get_trace_from_revert_kwargs(**kwargs) -> "TraceAPI | None": trace = kwargs.get("trace") txn = kwargs.get("txn") @@ -1589,7 +1589,7 @@ def connection_str(self) -> str: return self.uri or f"{self.ipc_path}" @property - def connection_id(self) -> Optional[str]: + def connection_id(self) -> str | None: return f"{self.network_choice}:{self.uri}" @property @@ -1605,7 +1605,7 @@ def data_dir(self) -> Path: return _get_default_data_dir() @property - def ipc_path(self) -> Optional[Path]: + def ipc_path(self) -> Path | None: if path := super().ipc_path: return path @@ -1644,7 +1644,7 @@ def has_poa_history(self) -> bool: return findings @cached_property - def _ots_api_level(self) -> Optional[int]: + def _ots_api_level(self) -> int | None: # NOTE: Returns None when OTS namespace is not enabled. try: result = self.make_request("ots_getApiLevel") @@ -1726,7 +1726,7 @@ def _log_connection(self, client_name: str): ) logger.info(f"{msg} {suffix}.") - def ots_get_contract_creator(self, address: "AddressType") -> Optional[dict]: + def ots_get_contract_creator(self, address: "AddressType") -> dict | None: if self._ots_api_level is None: return None @@ -1737,7 +1737,7 @@ def ots_get_contract_creator(self, address: "AddressType") -> Optional[dict]: return result - def _get_contract_creation_receipt(self, address: "AddressType") -> Optional[ReceiptAPI]: + def _get_contract_creation_receipt(self, address: "AddressType") -> ReceiptAPI | None: if result := self.ots_get_contract_creator(address): tx_hash = result["hash"] return self.get_receipt(tx_hash) @@ -1754,7 +1754,7 @@ def connect(self): self._complete_connect() def simulate_transaction_bundle( - self, bundle: Bundle, sim_overrides: Optional[dict] = None + self, bundle: Bundle, sim_overrides: dict | None = None ) -> SimulationReport: """ Submit a bundle and get the simulation result. @@ -1772,10 +1772,10 @@ def simulate_transaction_bundle( def _create_web3( - http_uri: Optional[str] = None, - ipc_path: Optional[Path] = None, - ws_uri: Optional[str] = None, - request_kwargs: Optional[dict] = None, + http_uri: str | None = None, + ipc_path: Path | None = None, + ws_uri: str | None = None, + request_kwargs: dict | None = None, ): # NOTE: This list is ordered by try-attempt. # Try ENV, then IPC, and then HTTP last. @@ -1826,7 +1826,7 @@ def _is_ws_url(val: str) -> bool: return val.startswith("wss://") or val.startswith("ws://") -def _is_ipc_path(val: Union[str, Path]) -> bool: +def _is_ipc_path(val: str | Path) -> bool: return f"{val}".endswith(".ipc") @@ -1843,7 +1843,7 @@ def trace(self) -> CallTrace: ) @cached_property - def source_traceback(self) -> Optional[SourceTraceback]: + def source_traceback(self) -> SourceTraceback | None: ct = self.contract_type if ct is None: return None diff --git a/src/ape_ethereum/proxies.py b/src/ape_ethereum/proxies.py index 854d8efbef..938ea53b5e 100644 --- a/src/ape_ethereum/proxies.py +++ b/src/ape_ethereum/proxies.py @@ -1,5 +1,5 @@ from enum import IntEnum, auto -from typing import Optional, cast +from typing import cast from eth_pydantic_types.hex import HexStr from ethpm_types import ContractType, MethodABI @@ -78,7 +78,7 @@ def __init__(self, **kwargs): self._abi = abi @property - def abi(self) -> Optional[MethodABI]: + def abi(self) -> MethodABI | None: return self._abi diff --git a/src/ape_ethereum/query.py b/src/ape_ethereum/query.py index 7dd8a4f9ca..d5b94cd346 100644 --- a/src/ape_ethereum/query.py +++ b/src/ape_ethereum/query.py @@ -1,6 +1,5 @@ from collections.abc import Iterator from functools import singledispatchmethod -from typing import Optional from ape.api.query import ContractCreation, ContractCreationQuery, QueryAPI, QueryType from ape.exceptions import APINotImplementedError, ProviderError, QueryEngineError @@ -16,7 +15,7 @@ def __init__(self): self.supports_contract_creation = None # will be set after we try for the first time @singledispatchmethod - def estimate_query(self, query: QueryType) -> Optional[int]: # type: ignore[override] + def estimate_query(self, query: QueryType) -> int | None: # type: ignore[override] return None @singledispatchmethod @@ -24,7 +23,7 @@ def perform_query(self, query: QueryType) -> Iterator: # type: ignore[override] raise QueryEngineError(f"Cannot handle '{type(query)}'.") @estimate_query.register - def estimate_contract_creation_query(self, query: ContractCreationQuery) -> Optional[int]: + def estimate_contract_creation_query(self, query: ContractCreationQuery) -> int | None: # NOTE: Extremely expensive query, involves binary search of all blocks in a chain # Very loose estimate of 5s per transaction for this query. if self.supports_contract_creation is False: diff --git a/src/ape_ethereum/trace.py b/src/ape_ethereum/trace.py index 164123567a..1d1badb9a9 100644 --- a/src/ape_ethereum/trace.py +++ b/src/ape_ethereum/trace.py @@ -5,7 +5,7 @@ from collections.abc import Iterable, Iterator, Sequence from enum import Enum from functools import cached_property -from typing import IO, TYPE_CHECKING, Any, Optional, Union +from typing import IO, TYPE_CHECKING, Any from eth_pydantic_types import HexStr from eth_utils import is_0x_prefixed, to_hex @@ -96,10 +96,10 @@ class Trace(TraceAPI): involved, Ape must use the ``.name`` as the identifier for all contracts. """ - call_trace_approach: Optional[TraceApproach] = None + call_trace_approach: TraceApproach | None = None """When None, attempts to deduce.""" - _enriched_calltree: Optional[dict] = None + _enriched_calltree: dict | None = None def __repr__(self) -> str: try: @@ -182,7 +182,7 @@ def addresses(self) -> Iterator["AddressType"]: yield from self.get_addresses_used() @cached_property - def root_contract_type(self) -> Optional["ContractType"]: + def root_contract_type(self) -> "ContractType | None": if address := self.transaction.get("to"): try: return self.chain_manager.contracts.get(address) @@ -192,7 +192,7 @@ def root_contract_type(self) -> Optional["ContractType"]: return None @cached_property - def root_method_abi(self) -> Optional["MethodABI"]: + def root_method_abi(self) -> "MethodABI | None": method_id = self.transaction.get("data", b"")[:10] if ct := self.root_contract_type: try: @@ -249,8 +249,8 @@ def _return_value_from_enriched_calltree(self) -> Any: return self._get_return_value_from_calltree(calltree) def _get_return_value_from_calltree( - self, calltree: Union[dict, CallTreeNode] - ) -> tuple[Optional[Any], ...]: + self, calltree: dict | CallTreeNode + ) -> tuple[Any | None, ...]: num_outputs = 1 if raw_return_data := ( calltree.get("returndata") if isinstance(calltree, dict) else calltree.returndata @@ -269,12 +269,12 @@ def _get_return_value_from_calltree( return tuple([None for _ in range(num_outputs)]) @cached_property - def revert_message(self) -> Optional[str]: + def revert_message(self) -> str | None: call = self.enriched_calltree if not call.get("failed", False): return None - def try_get_revert_msg(c) -> Optional[str]: + def try_get_revert_msg(c) -> str | None: if msg := c.get("revert_message"): return msg @@ -294,7 +294,7 @@ def try_get_revert_msg(c) -> Optional[str]: return None @cached_property - def _last_frame(self) -> Optional[dict]: + def _last_frame(self) -> dict | None: try: frame = deque(self.raw_trace_frames, maxlen=1) except Exception as err: @@ -304,7 +304,7 @@ def _last_frame(self) -> Optional[dict]: return frame[0] if frame else None @cached_property - def _revert_str_from_trace_frames(self) -> Optional[HexBytes]: + def _revert_str_from_trace_frames(self) -> HexBytes | None: if frame := self._last_frame: memory = frame.get("memory", []) if ret := "".join([x[2:] for x in memory[4:]]): @@ -313,7 +313,7 @@ def _revert_str_from_trace_frames(self) -> Optional[HexBytes]: return None @cached_property - def _return_data_from_trace_frames(self) -> Optional[HexBytes]: + def _return_data_from_trace_frames(self) -> HexBytes | None: if frame := self._last_frame: memory = frame["memory"] start_pos = int(frame["stack"][2], 16) // 32 @@ -368,13 +368,13 @@ def show(self, verbose: bool = False, file: IO[str] = sys.stdout): console.print(root) def get_gas_report( - self, exclude: Optional[Sequence["ContractFunctionPath"]] = None + self, exclude: Sequence["ContractFunctionPath"] | None = None ) -> "GasReport": call = self.enriched_calltree return self._get_gas_report_from_call(call, exclude=exclude) def _get_gas_report_from_call( - self, call: dict, exclude: Optional[Sequence["ContractFunctionPath"]] = None + self, call: dict, exclude: Sequence["ContractFunctionPath"] | None = None ) -> "GasReport": tx = self.transaction @@ -446,7 +446,7 @@ def _debug_trace_transaction_struct_logs_to_call(self) -> CallTreeNode: def _get_tree(self, verbose: bool = False) -> Tree: return parse_rich_tree(self.enriched_calltree, verbose=verbose) - def _get_abi(self, call: Union[dict, CallTreeNode]) -> Optional["MethodABI"]: + def _get_abi(self, call: dict | CallTreeNode) -> "MethodABI | None": if not (addr := call.get("address") if isinstance(call, dict) else call.address): return self.root_method_abi if not (calldata := call.get("calldata") if isinstance(call, dict) else call.calldata): @@ -542,7 +542,7 @@ def _discover_calltrace_approach(self) -> CallTreeNode: reason_str = ", ".join(f"{k}={v}" for k, v in reason_map.items()) raise ProviderError(f"Unable to create CallTreeNode. Reason(s): {reason_str}") - def _debug_trace_transaction(self, parameters: Optional[dict] = None) -> dict: + def _debug_trace_transaction(self, parameters: dict | None = None) -> dict: parameters = parameters or self.debug_trace_transaction_parameters return self.provider.make_request( "debug_traceTransaction", [self.transaction_hash, parameters] @@ -603,7 +603,7 @@ class CallTrace(Trace): call_trace_approach: TraceApproach = TraceApproach.GETH_STRUCT_LOG_PARSE """debug_traceCall must use the struct-log tracer.""" - supports_debug_trace_call: Optional[bool] = None + supports_debug_trace_call: bool | None = None @field_validator("tx", mode="before") @classmethod diff --git a/src/ape_ethereum/transactions.py b/src/ape_ethereum/transactions.py index 9884324f99..d9bc5d6114 100644 --- a/src/ape_ethereum/transactions.py +++ b/src/ape_ethereum/transactions.py @@ -1,7 +1,7 @@ import sys from enum import Enum, IntEnum from functools import cached_property -from typing import IO, TYPE_CHECKING, Any, Optional, Union +from typing import IO, TYPE_CHECKING, Any from eth_abi import decode from eth_account import Account as EthAccount @@ -140,10 +140,10 @@ class StaticFeeTransaction(BaseTransaction): Transactions that are pre-EIP-1559 and use the ``gasPrice`` field. """ - gas_price: Optional[HexInt] = Field(default=None, alias="gasPrice") - max_priority_fee: Optional[HexInt] = Field(default=None, exclude=True) # type: ignore + gas_price: HexInt | None = Field(default=None, alias="gasPrice") + max_priority_fee: HexInt | None = Field(default=None, exclude=True) # type: ignore type: HexInt = Field(default=TransactionType.STATIC.value, exclude=True) - max_fee: Optional[HexInt] = Field(default=None, exclude=True) # type: ignore + max_fee: HexInt | None = Field(default=None, exclude=True) # type: ignore @model_validator(mode="after") @classmethod @@ -159,7 +159,7 @@ class AccessListTransaction(StaticFeeTransaction): transactions are similar to legacy transaction with an added access list functionality. """ - gas_price: Optional[int] = Field(default=None, alias="gasPrice") + gas_price: int | None = Field(default=None, alias="gasPrice") type: int = TransactionType.ACCESS_LIST.value access_list: list[AccessList] = Field(default_factory=list, alias="accessList") @@ -175,8 +175,8 @@ class DynamicFeeTransaction(BaseTransaction): and ``maxPriorityFeePerGas`` fields. """ - max_priority_fee: Optional[HexInt] = Field(default=None, alias="maxPriorityFeePerGas") - max_fee: Optional[HexInt] = Field(default=None, alias="maxFeePerGas") + max_priority_fee: HexInt | None = Field(default=None, alias="maxPriorityFeePerGas") + max_fee: HexInt | None = Field(default=None, alias="maxFeePerGas") type: HexInt = TransactionType.DYNAMIC.value access_list: list[AccessList] = Field(default_factory=list, alias="accessList") @@ -302,14 +302,14 @@ def debug_logs_typed(self) -> list[tuple[Any]]: return list(trace.debug_logs) @cached_property - def contract_type(self) -> Optional["ContractType"]: + def contract_type(self) -> "ContractType | None": if address := (self.receiver or self.contract_address): return self.chain_manager.contracts.get(address) return None @cached_property - def method_called(self) -> Optional[MethodABI]: + def method_called(self) -> MethodABI | None: if not self.contract_type: return None @@ -333,7 +333,7 @@ def source_traceback(self) -> SourceTraceback: return SourceTraceback.model_validate([]) def raise_for_status(self): - err: Optional[TransactionError] = None + err: TransactionError | None = None if self.gas_limit is not None and self.ran_out_of_gas: err = OutOfGasError(txn=self) @@ -370,9 +370,9 @@ def show_events(self, file: IO[str] = sys.stdout): def decode_logs( self, - abi: Optional[ - Union[list[Union[EventABI, "ContractEvent"]], Union[EventABI, "ContractEvent"]] - ] = None, + abi: ( + "list[EventABI | ContractEvent] | EventABI | ContractEvent | None" + ) = None, ) -> ContractLogContainer: if not self.logs: # Short circuit. @@ -400,7 +400,7 @@ def decode_logs( } def get_default_log( - _log: dict, logs: ContractLogContainer, evt_name: Optional[str] = None + _log: dict, logs: ContractLogContainer, evt_name: str | None = None ) -> ContractLog: log_index = _log.get("logIndex", logs[-1].log_index + 1 if logs else 0) @@ -459,7 +459,7 @@ def get_default_log( return decoded_logs - def _decode_ds_note(self, log: dict) -> Optional[ContractLog]: + def _decode_ds_note(self, log: dict) -> ContractLog | None: if len(log["topics"]) == 0: # anon event log return None @@ -506,7 +506,7 @@ class SharedBlobReceipt(Receipt): blob transaction. """ - blob_gas_used: Optional[HexInt] = Field(default=None, alias="blobGasUsed") + blob_gas_used: HexInt | None = Field(default=None, alias="blobGasUsed") """ The total amount of blob gas consumed by the transactions within the block. """ diff --git a/src/ape_networks/config.py b/src/ape_networks/config.py index a519ed64fe..cde1369f50 100644 --- a/src/ape_networks/config.py +++ b/src/ape_networks/config.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic_settings import SettingsConfigDict @@ -19,7 +18,7 @@ class CustomNetwork(PluginConfig): ecosystem: str """The name of the ecosystem.""" - base_ecosystem_plugin: Optional[str] = None + base_ecosystem_plugin: str | None = None """The base ecosystem plugin to use, when applicable. Defaults to the default ecosystem.""" default_provider: str = "node" diff --git a/src/ape_node/provider.py b/src/ape_node/provider.py index e2283facfd..339ffcdce3 100644 --- a/src/ape_node/provider.py +++ b/src/ape_node/provider.py @@ -4,7 +4,7 @@ import shutil from pathlib import Path from subprocess import DEVNULL, PIPE, Popen -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from urllib.error import HTTPError, URLError from urllib.parse import urlparse from urllib.request import urlopen @@ -109,25 +109,25 @@ class GethDevProcess(BaseGethProcess): def __init__( self, data_dir: Path, - hostname: Optional[str] = None, - port: Optional[int] = None, - ipc_path: Optional[Path] = None, - ws_hostname: Optional[str] = None, - ws_port: Optional[str] = None, + hostname: str | None = None, + port: int | None = None, + ipc_path: Path | None = None, + ws_hostname: str | None = None, + ws_port: str | None = None, mnemonic: str = DEFAULT_TEST_MNEMONIC, number_of_accounts: int = DEFAULT_NUMBER_OF_TEST_ACCOUNTS, chain_id: int = DEFAULT_TEST_CHAIN_ID, - initial_balance: Union[str, int] = DEFAULT_TEST_ACCOUNT_BALANCE, - executable: Optional[Union[list[str], str]] = None, + initial_balance: str | int = DEFAULT_TEST_ACCOUNT_BALANCE, + executable: list[str] | str | None = None, auto_disconnect: bool = True, - extra_funded_accounts: Optional[list[str]] = None, - hd_path: Optional[str] = DEFAULT_TEST_HD_PATH, - block_time: Optional[int] = None, + extra_funded_accounts: list[str] | None = None, + hd_path: str | None = DEFAULT_TEST_HD_PATH, + block_time: int | None = None, generate_accounts: bool = True, initialize_chain: bool = True, background: bool = False, verify_bin: bool = True, - rpc_api: Optional[list[str]] = None, + rpc_api: list[str] | None = None, ): if isinstance(executable, str): # Legacy. @@ -337,19 +337,19 @@ def is_rpc_ready(self) -> bool: return True @property - def _hostname(self) -> Optional[str]: + def _hostname(self) -> str | None: return self.geth_kwargs.get("rpc_addr") @property - def _port(self) -> Optional[str]: + def _port(self) -> str | None: return self.geth_kwargs.get("rpc_port") @property - def _ws_hostname(self) -> Optional[str]: + def _ws_hostname(self) -> str | None: return self.geth_kwargs.get("ws_addr") @property - def _ws_port(self) -> Optional[str]: + def _ws_port(self) -> str | None: return self.geth_kwargs.get("ws_port") def connect(self, timeout: int = 60): @@ -457,25 +457,25 @@ class EthereumNodeConfig(PluginConfig): such as which URIs to use for each network. """ - executable: Optional[list[str]] = None + executable: list[str] | None = None """ For starting nodes, select the executable. Defaults to using ``shutil.which("geth")``. """ - data_dir: Optional[Path] = None + data_dir: Path | None = None """ For node-management, choose where the geth data directory shall be located. Defaults to using a location within Ape's DATA_FOLDER. """ - ipc_path: Optional[Path] = None + ipc_path: Path | None = None """ For IPC connections, select the IPC path. If managing a process, web3.py can determine the IPC w/o needing to manually configure. """ - call_trace_approach: Optional[TraceApproach] = None + call_trace_approach: TraceApproach | None = None """ Select the trace approach to use. Defaults to deducing one based on your node's client-version and available RPCs. @@ -486,7 +486,7 @@ class EthereumNodeConfig(PluginConfig): Optionally specify request headers to use whenever using this provider. """ - rpc_api: Optional[list[str]] = None + rpc_api: list[str] | None = None """ RPC APIs to enable. Defaults to all geth APIs. """ @@ -525,7 +525,7 @@ def __init__(self): # NOTE: Using EthereumNodeProvider because of it's geth-derived default behavior. # TODO: In 0.9, change NAME to be `gethdev`, so for local networks it is more obvious. class GethDev(EthereumNodeProvider, TestProviderAPI, SubprocessProvider): - _process: Optional[GethDevProcess] = None + _process: GethDevProcess | None = None name: str = "node" @property @@ -543,7 +543,7 @@ def chain_id(self) -> int: return self.settings.ethereum.local.get("chain_id", DEFAULT_TEST_CHAIN_ID) @property - def block_time(self) -> Optional[int]: + def block_time(self) -> int | None: return self.settings.ethereum.local.get("block_time") @property @@ -574,7 +574,7 @@ def auto_mine(self, value): raise NotImplementedError("'auto_mine' setter not implemented.") @property - def ipc_path(self) -> Optional[Path]: + def ipc_path(self) -> Path | None: if rpc := self._configured_ipc_path: # "ipc_path" found in config/settings return Path(rpc) diff --git a/src/ape_node/query.py b/src/ape_node/query.py index e231d74a39..045e49dc41 100644 --- a/src/ape_node/query.py +++ b/src/ape_node/query.py @@ -1,6 +1,5 @@ from collections.abc import Iterator from functools import singledispatchmethod -from typing import Optional from ape.api.query import ContractCreation, ContractCreationQuery, QueryAPI, QueryType from ape.exceptions import QueryEngineError @@ -10,7 +9,7 @@ class OtterscanQueryEngine(QueryAPI): @singledispatchmethod - def estimate_query(self, query: QueryType) -> Optional[int]: # type: ignore[override] + def estimate_query(self, query: QueryType) -> int | None: # type: ignore[override] return None @singledispatchmethod @@ -20,7 +19,7 @@ def perform_query(self, query: QueryType) -> Iterator: # type: ignore[override] ) @estimate_query.register - def estimate_contract_creation_query(self, query: ContractCreationQuery) -> Optional[int]: + def estimate_contract_creation_query(self, query: ContractCreationQuery) -> int | None: if getattr(self.provider, "_ots_api_level", None) is not None: return 250 return None diff --git a/src/ape_pm/_cli.py b/src/ape_pm/_cli.py index de837199e4..7b261fb569 100644 --- a/src/ape_pm/_cli.py +++ b/src/ape_pm/_cli.py @@ -1,7 +1,7 @@ import sys from importlib import import_module from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import click @@ -113,7 +113,7 @@ def rows(): click.echo(row.strip()) -def _handle_package_path(path: Path, original_value: Optional[str] = None) -> dict: +def _handle_package_path(path: Path, original_value: str | None = None) -> dict: if not path.exists(): value = original_value or path.as_posix() raise click.BadArgumentUsage(f"Unknown package '{value}'.") diff --git a/src/ape_pm/compiler.py b/src/ape_pm/compiler.py index 46444eaa3f..0e0a9ebeb8 100644 --- a/src/ape_pm/compiler.py +++ b/src/ape_pm/compiler.py @@ -2,7 +2,7 @@ from collections.abc import Iterable, Iterator from json import JSONDecodeError from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from eth_pydantic_types import HexBytes from eth_utils import is_0x_prefixed @@ -35,8 +35,8 @@ def get_versions(self, all_paths: Iterable[Path]) -> set[str]: def compile( self, contract_filepaths: Iterable[Path], - project: Optional["ProjectManager"], - settings: Optional[dict] = None, + project: "ProjectManager | None", + settings: dict | None = None, ) -> Iterator[ContractType]: project = project or self.local_project source_ids = { @@ -68,7 +68,7 @@ def compile( def compile_code( self, code: str, - project: Optional["ProjectManager"] = None, + project: "ProjectManager | None" = None, **kwargs, ) -> ContractType: code = code or "[]" diff --git a/src/ape_pm/dependency.py b/src/ape_pm/dependency.py index d5b9a26cd4..d228246db2 100644 --- a/src/ape_pm/dependency.py +++ b/src/ape_pm/dependency.py @@ -4,7 +4,6 @@ from collections.abc import Iterable from functools import cached_property from pathlib import Path -from typing import Optional, Union import requests from pydantic import model_validator @@ -18,7 +17,7 @@ from ape.utils.os import _remove_readonly, clean_path, extract_archive, get_package_path, in_tempdir -def _fetch_local(src: Path, destination: Path, config_override: Optional[dict] = None): +def _fetch_local(src: Path, destination: Path, config_override: dict | None = None): if src.is_dir(): project = ManagerAccessMixin.Project(src, config_override=config_override) project.unpack(destination) @@ -42,7 +41,7 @@ class LocalDependency(DependencyAPI): The root path (and API defining key) to the dependency files. """ - version: Optional[str] = None + version: str | None = None """ Specified version. """ @@ -119,7 +118,7 @@ class GithubDependency(DependencyAPI): such as ``dapphub/erc20``. """ - ref: Optional[str] = None + ref: str | None = None """ The branch or tag to use. When using this field instead of the 'release' field, the repository @@ -129,7 +128,7 @@ class GithubDependency(DependencyAPI): **NOTE**: Will be ignored if given a 'release'. """ - version: Optional[str] = None + version: str | None = None """ The release version to use. When using this field instead of the 'ref' field, the GitHub @@ -307,7 +306,7 @@ class NpmDependency(DependencyAPI): The package must already be installed! """ - version: Optional[str] = None + version: str | None = None """ Specify the version, if not wanting to use discovered version from install. @@ -360,7 +359,7 @@ def package_id(self) -> str: return str(self.npm).split("node_modules")[-1].strip(os.path.sep) @cached_property - def version_from_installed_package_json(self) -> Optional[str]: + def version_from_installed_package_json(self) -> str | None: """ The version from package.json in the installed package. Requires having run `npm install`. @@ -368,7 +367,7 @@ def version_from_installed_package_json(self) -> Optional[str]: return _get_version_from_package_json(self.npm) @cached_property - def version_from_project_package_json(self) -> Optional[str]: + def version_from_project_package_json(self) -> str | None: """ The version from your project's package.json, if exists. """ @@ -391,8 +390,8 @@ def fetch(self, destination: Path): def _get_version_from_package_json( - base_path: Path, dict_path: Optional[Iterable[Union[str, Path]]] = None -) -> Optional[str]: + base_path: Path, dict_path: Iterable[str | Path] | None = None +) -> str | None: package_json = base_path / "package.json" if not package_json.is_file(): return None @@ -423,21 +422,21 @@ class PythonDependency(DependencyAPI): A dependency installed from Python tooling, such as `pip`. """ - site_package: Optional[str] = None + site_package: str | None = None """ The Python site-package name, such as ``"snekmate"``. Cannot use with ``pypi:``. Requires the dependency to have been installed either via ``pip`` or something alike. """ - pypi: Optional[str] = None + pypi: str | None = None """ The ``pypi`` reference, such as ``"snekmate"``. Cannot use with ``python:``. When set, downloads the dependency from ``pypi`` using HTTP directly (not ``pip``). """ - version: Optional[str] = None + version: str | None = None """ Optionally specify the version expected to be installed. """ @@ -462,7 +461,7 @@ def validate_model(cls, values): return values @cached_property - def path(self) -> Optional[Path]: + def path(self) -> Path | None: if self.pypi: # Is pypi: specified; has no special path. return None @@ -483,7 +482,7 @@ def package_id(self) -> str: raise ProjectError("Must provide either 'pypi:' or 'python:' for python-base dependencies.") @property - def python(self) -> Optional[str]: + def python(self) -> str | None: # For backwards-compat; serves as an undocumented alias. return self.site_package @@ -551,7 +550,7 @@ def package_data(self) -> dict: return response.json() @cached_property - def version_from_package_data(self) -> Optional[str]: + def version_from_package_data(self) -> str | None: return self.package_data.get("info", {}).get("version") @cached_property diff --git a/src/ape_pm/project.py b/src/ape_pm/project.py index fe62a3b6b4..89930b8b1d 100644 --- a/src/ape_pm/project.py +++ b/src/ape_pm/project.py @@ -1,7 +1,7 @@ import sys from collections.abc import Iterable from pathlib import Path -from typing import Any, Optional +from typing import Any from ape.utils._github import _GithubClient, github_client @@ -255,7 +255,7 @@ def _parse_solidity_config( data: dict, dependencies: list[dict], lib_paths: Iterable[str], - contracts_folder: Optional[str] = None, + contracts_folder: str | None = None, ) -> dict: sol_cfg: dict = {} @@ -288,7 +288,7 @@ def _parse_remappings( foundry_remappings: list[str], dependencies: list[dict], lib_paths: Iterable[str], - contracts_folder: Optional[str] = None, + contracts_folder: str | None = None, ) -> list[str]: ape_sol_remappings: set[str] = set() diff --git a/src/ape_run/_cli.py b/src/ape_run/_cli.py index c15dea1391..df2255d344 100644 --- a/src/ape_run/_cli.py +++ b/src/ape_run/_cli.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from pathlib import Path from runpy import run_module -from typing import Any, Union +from typing import Any import click @@ -95,7 +95,7 @@ def invoke(self, ctx: click.Context) -> Any: # Don't handle error - raise exception as normal. raise - def _get_command(self, filepath: Path) -> Union[click.Command, click.Group, None]: + def _get_command(self, filepath: Path) -> click.Command | click.Group | None: scripts_folder = Path.cwd() / "scripts" relative_filepath = filepath.relative_to(scripts_folder) @@ -176,7 +176,7 @@ def call(): return call @property - def commands(self) -> dict[str, Union[click.Command, click.Group]]: + def commands(self) -> dict[str, click.Command | click.Group]: # perf: Don't reference `.local_project.scripts_folder` here; # it's too slow when doing just doing `--help`. scripts_folder = Path.cwd() / "scripts" diff --git a/src/ape_test/accounts.py b/src/ape_test/accounts.py index d2bed151f3..1abc228ed2 100644 --- a/src/ape_test/accounts.py +++ b/src/ape_test/accounts.py @@ -1,6 +1,6 @@ from collections.abc import Iterator from functools import cached_property -from typing import Optional, cast +from typing import cast from ape.api.accounts import TestAccountAPI, TestAccountContainerAPI from ape.exceptions import ProviderNotConnectedError @@ -59,7 +59,7 @@ def get_test_account(self, index: int) -> TestAccountAPI: except (NotImplementedError, ProviderNotConnectedError): return self.generate_account(index=index) - def generate_account(self, index: Optional[int] = None) -> "TestAccountAPI": + def generate_account(self, index: int | None = None) -> "TestAccountAPI": new_index = ( self.number_of_accounts + len(self.generated_accounts) if index is None else index ) diff --git a/src/ape_test/config.py b/src/ape_test/config.py index adb2541a6f..f7a030f5e1 100644 --- a/src/ape_test/config.py +++ b/src/ape_test/config.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, NewType, Optional, Union +from typing import TYPE_CHECKING, NewType from pydantic import NonNegativeInt, field_validator from pydantic_settings import SettingsConfigDict @@ -25,7 +25,7 @@ class EthTesterProviderConfig(PluginConfig): class GasExclusion(PluginConfig): contract_name: str = "*" # If only given method, searches across all contracts. - method_name: Optional[str] = None # By default, match all methods in a contract + method_name: str | None = None # By default, match all methods in a contract model_config = SettingsConfigDict(extra="allow", env_prefix="APE_TEST_") @@ -70,7 +70,7 @@ def show(self) -> bool: return "terminal" in self.reports -_ReportType = Union[bool, dict] +_ReportType = bool | dict """Dict is for extra report settings.""" @@ -214,7 +214,7 @@ class ApeTestConfig(PluginConfig): useful for debugging the framework itself. """ - isolation: Union[bool, IsolationConfig] = True + isolation: bool | IsolationConfig = True """ Configure which scope-specific isolation to enable. Set to ``False`` to disable all and ``True`` (default) to disable all. diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index af2fa12adc..9e71017348 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -4,7 +4,7 @@ from functools import cached_property from pathlib import Path from re import Pattern -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from eth.exceptions import HeaderNotFound from eth_pydantic_types import HexBytes @@ -112,8 +112,8 @@ class ApeTester(EthereumTesterProvider): def __init__(self, config: "ApeTestConfig", chain_id: int): self.config = config self.chain_id = chain_id - self._backend: Optional[ApeEVMBackend] = None - self._ethereum_tester: Optional[EthereumTester] = None + self._backend: ApeEVMBackend | None = None + self._ethereum_tester: EthereumTester | None = None @property def ethereum_tester(self) -> EthereumTester: @@ -159,7 +159,7 @@ def api_endpoints(self) -> dict: # type: ignore class LocalProvider(TestProviderAPI, Web3Provider): - _evm_backend: Optional[PyEVMBackend] = None + _evm_backend: PyEVMBackend | None = None _CANNOT_AFFORD_GAS_PATTERN: Pattern = re.compile( r"Sender b'[\\*|\w]*' cannot afford txn gas (\d+) with account balance (\d+)" ) @@ -197,15 +197,15 @@ def max_gas(self) -> int: return self.evm_backend.get_block_by_number("latest")["gas_limit"] @property - def http_uri(self) -> Optional[str]: + def http_uri(self) -> str | None: return None @property - def ws_uri(self) -> Optional[str]: + def ws_uri(self) -> str | None: return None @property - def ipc_path(self) -> Optional[Path]: + def ipc_path(self) -> Path | None: return None def connect(self): @@ -230,7 +230,7 @@ def update_settings(self, new_settings: dict): self.connect() def estimate_gas_cost( - self, txn: "TransactionAPI", block_id: Optional["BlockID"] = None, **kwargs + self, txn: "TransactionAPI", block_id: "BlockID | None" = None, **kwargs ) -> int: if isinstance(self.network.gas_limit, int): return self.network.gas_limit @@ -311,8 +311,8 @@ def base_fee(self) -> int: def send_call( self, txn: "TransactionAPI", - block_id: Optional["BlockID"] = None, - state: Optional[dict] = None, + block_id: "BlockID | None" = None, + state: dict | None = None, **kwargs, ) -> HexBytes: # NOTE: Using JSON mode since used as request data. @@ -472,13 +472,13 @@ def set_timestamp(self, new_timestamp: int): def mine(self, num_blocks: int = 1): self.evm_backend.mine_blocks(num_blocks) - def get_balance(self, address: AddressType, block_id: Optional["BlockID"] = None) -> int: + def get_balance(self, address: AddressType, block_id: "BlockID | None" = None) -> int: # perf: Using evm_backend directly instead of going through web3. return self.evm_backend.get_balance( HexBytes(address), block_number="latest" if block_id is None else block_id ) - def get_nonce(self, address: AddressType, block_id: Optional["BlockID"] = None) -> int: + def get_nonce(self, address: AddressType, block_id: "BlockID | None" = None) -> int: return self.evm_backend.get_nonce( HexBytes(address), block_number="latest" if block_id is None else block_id )