diff --git a/setup.py b/setup.py index 6f85642d60..f79c2da3e4 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ "pytest-timeout>=2.2.0,<3", # For avoiding timing out during tests "hypothesis>=6.2.0,<7.0", # Strategy-based fuzzer "hypothesis-jsonschema==0.19.0", # JSON Schema fuzzer extension + "pandas", # Needed to test w/ narwhals ], "lint": [ "ruff>=0.10.0", # Unified linter and formatter @@ -92,10 +93,9 @@ "lazyasd>=0.1.4", "asttokens>=2.4.1,<3", # Peer dependency; w/o pin container build fails. "cchecksum>=0.0.3,<1", - # Pandas peer-dep: Numpy 2.0 causes issues for some users. - "numpy<2", + "more-itertools; python_version<'3.10'", # backport for `itertools.pairwise` + "narwhals>=1.29,<2", "packaging>=23.0,<24", - "pandas>=2.2.2,<3", "pluggy>=1.3,<2", "pydantic>=2.10.0,<3", "pydantic-settings>=2.5.2,<3", diff --git a/src/ape/api/config.py b/src/ape/api/config.py index 01113c053c..a49ade0454 100644 --- a/src/ape/api/config.py +++ b/src/ape/api/config.py @@ -7,7 +7,8 @@ import yaml from ethpm_types import PackageManifest, PackageMeta, Source -from pydantic import ConfigDict, Field, ValidationError, model_validator +from narwhals.stable.v1 import Implementation as DataframeImplementation +from pydantic import ConfigDict, Field, ValidationError, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from ape.exceptions import ConfigError @@ -202,6 +203,17 @@ class DeploymentConfig(PluginConfig): """ +class QueryConfig(PluginConfig): + """Add 'query:' key to your config.""" + + backend: DataframeImplementation = DataframeImplementation.PANDAS + """Which Narwhals backend implementation to use.""" + + @field_validator("backend", mode="before") + def convert_backend_str(cls, value: Any) -> DataframeImplementation: + return DataframeImplementation.from_backend(value) + + def _get_problem_with_config(errors: list, path: Path) -> Optional[str]: # Attempt to find line numbers in the config matching. cfg_content = Source(content=path.read_text(encoding="utf8")).content @@ -368,6 +380,8 @@ def __init__(self, *args, **kwargs): The version of the project. """ + query: QueryConfig = QueryConfig() + # NOTE: Plugin configs are technically "extras". model_config = SettingsConfigDict(extra="allow", env_prefix="APE_") diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index 922d4b96cb..5e028bbf81 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -19,9 +19,6 @@ from eth_utils import to_hex from pydantic import Field, computed_field, field_serializer, model_validator -from ape.api.networks import NetworkAPI -from ape.api.query import BlockTransactionQuery -from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.exceptions import ( APINotImplementedError, ProviderError, @@ -44,6 +41,9 @@ from ape.utils.process import JoinableQueue, spawn from ape.utils.rpc import RPCHeaders +from .networks import NetworkAPI +from .transactions import ReceiptAPI, TransactionAPI + if TYPE_CHECKING: from eth_pydantic_types import HexBytes from ethpm_types.abi import EventABI @@ -151,12 +151,21 @@ def transactions(self) -> list[TransactionAPI]: """ All transactions in a block. """ + from ape.api.query import BlockTransactionQuery + if self.hash is None: - # Unable to query transactions. 
+ # NOTE: Only "unsealed" blocks do not have a hash + raise ProviderError("Unable to find block transactions: not sealed yet") + + elif self.num_transactions == 0: return [] try: - query = BlockTransactionQuery(columns=["*"], block_id=self.hash) + query = BlockTransactionQuery( + columns=["*"], + num_transactions=self.num_transactions, + block_id=self.hash, + ) return cast(list[TransactionAPI], list(self.query_manager.query(query))) except QueryEngineError as err: # NOTE: Re-raising a better error here because was confusing diff --git a/src/ape/api/query.py b/src/ape/api/query.py index 5198bd5324..17c9852447 100644 --- a/src/ape/api/query.py +++ b/src/ape/api/query.py @@ -1,24 +1,30 @@ from abc import abstractmethod from collections.abc import Iterator, Sequence from functools import cache, cached_property -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union +import narwhals as nw from ethpm_types.abi import EventABI, MethodABI -from pydantic import NonNegativeInt, PositiveInt, model_validator +from pydantic import NonNegativeInt, PositiveInt, field_validator, model_validator -from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.logging import logger +from ape.types import ContractLog from ape.types.address import AddressType from ape.utils.basemodel import BaseInterface, BaseInterfaceModel, BaseModel -QueryType = Union[ - "BlockQuery", - "BlockTransactionQuery", - "AccountTransactionQuery", - "ContractCreationQuery", - "ContractEventQuery", - "ContractMethodQuery", -] +from .providers import BlockAPI +from .transactions import ReceiptAPI, TransactionAPI + +if TYPE_CHECKING: + from narwhals.typing import Frame + + from ape.managers.query import QueryResult + + try: + # Only on Python 3.11 + from typing import Self # type: ignore + except ImportError: + from typing_extensions import Self # type: ignore @cache @@ -93,13 +99,58 @@ def extract_fields(item: BaseInterfaceModel, columns: Sequence[str]) -> list[Any return [getattr(item, col, None) for col in columns] -class _BaseQuery(BaseModel): - columns: Sequence[str] +ModelType = TypeVar("ModelType", bound=BaseInterfaceModel) + + +class _BaseQuery(BaseModel, Generic[ModelType]): + Model: ClassVar[Optional[type[BaseInterfaceModel]]] = None + + columns: set[str] + + @field_validator("columns", mode="before") + def expand_wildcard(cls, value: Any) -> Any: + if cls.Model: + return validate_and_expand_columns(value, cls.Model) + + return value + + # Methods for determining query "coverage" and constraining search + @property + def start_index(self) -> int: + raise NotImplementedError() + + @property + def end_index(self) -> int: + raise NotImplementedError() + + def __len__(self) -> int: + return self.end_index - self.start_index - # TODO: Support "*" from getting the EcosystemAPI fields + def __contains__(self, other: Any) -> bool: + if not isinstance(other, _BaseQuery): + raise ValueError() + + # NOTE: Return True if `other` is "covered by" `self` + return other.start_index >= self.start_index and other.end_index <= self.end_index + + # Methods for determining query "ordering" + def __lt__(self, other: Any) -> bool: + if not isinstance(other, _BaseQuery): + raise ValueError() + + if self.start_index < other.start_index: + return True + + elif self.start_index == other.start_index: + # NOTE: If start matches, return True for smaller range covered + return self.end_index < other.end_index + + else: + return False class _BaseBlockQuery(_BaseQuery): + Model = 
BlockAPI
    start_block: NonNegativeInt = 0
    stop_block: NonNegativeInt
    step: PositiveInt = 1
@@ -121,35 +172,55 @@ def check_start_block_before_stop_block(cls, values):
         return values
 
+    @property
+    def start_index(self) -> int:
+        return self.start_block
+
+    @property
+    def end_index(self) -> int:
+        return self.stop_block
+
 
-class BlockQuery(_BaseBlockQuery):
+class BlockQuery(_BaseBlockQuery, _BaseQuery[BlockAPI]):
     """
     A ``QueryType`` that collects properties of ``BlockAPI`` over a range of
     blocks between ``start_block`` and ``stop_block``.
     """
 
 
-class BlockTransactionQuery(_BaseQuery):
+class BlockTransactionQuery(_BaseQuery[TransactionAPI]):
     """
     A ``QueryType`` that collects properties of ``TransactionAPI`` over a range
     of transactions collected inside the ``BlockAPI`` object represented by
     ``block_id``.
     """
 
+    Model = TransactionAPI
+
     block_id: Any
+    num_transactions: NonNegativeInt
+
+    @property
+    def start_index(self) -> int:
+        return 0
+
+    @property
+    def end_index(self) -> int:
+        return self.num_transactions - 1
 
 
-class AccountTransactionQuery(_BaseQuery):
+class AccountTransactionQuery(_BaseQuery[TransactionAPI]):
     """
     A ``QueryType`` that collects properties of ``TransactionAPI`` over a range
     of transactions made by ``account`` between ``start_nonce`` and ``stop_nonce``.
     """
 
+    Model = TransactionAPI
+
     account: AddressType
     start_nonce: NonNegativeInt = 0
     stop_nonce: NonNegativeInt
 
     @model_validator(mode="before")
-    @classmethod
     def check_start_nonce_before_stop_nonce(cls, values: dict) -> dict:
         if values["stop_nonce"] < values["start_nonce"]:
             raise ValueError(
@@ -159,17 +230,16 @@ def check_start_nonce_before_stop_nonce(cls, values: dict) -> dict:
 
         return values
 
+    @property
+    def start_index(self) -> int:
+        return self.start_nonce
 
-class ContractCreationQuery(_BaseQuery):
-    """
-    A ``QueryType`` that obtains information about contract deployment.
-    Returns ``ContractCreation(txn_hash, block, deployer, factory)``.
-    """
-
-    contract: AddressType
+    @property
+    def end_index(self) -> int:
+        return self.stop_nonce
 
 
-class ContractCreation(BaseModel, BaseInterface):
+class ContractCreation(BaseInterfaceModel):
     """
     Contract-creation metadata, such as the transaction and deployer.
     Useful for contract-verification,
@@ -228,18 +298,40 @@ def from_receipt(cls, receipt: ReceiptAPI) -> "ContractCreation":
         )
 
 
-class ContractEventQuery(_BaseBlockQuery):
+class ContractCreationQuery(_BaseQuery[ContractCreation]):
+    """
+    A ``QueryType`` that obtains information about contract deployment.
+    Returns ``ContractCreation(txn_hash, block, deployer, factory)``.
+    """
+
+    Model = ContractCreation
+
+    contract: AddressType
+
+    @property
+    def start_index(self) -> int:
+        return 0
+
+    @property
+    def end_index(self) -> int:
+        # TODO: Can this support multiple instances? Do we care anymore?
+        return 1
+
+
+class ContractEventQuery(_BaseBlockQuery, _BaseQuery[ContractLog]):
     """
     A ``QueryType`` that collects members from ``event`` over a range of
     logs emitted by ``contract`` between ``start_block`` and ``stop_block``.
     """
 
+    Model = ContractLog
+
     contract: Union[list[AddressType], AddressType]
     event: EventABI
     search_topics: Optional[dict[str, Any]] = None
 
 
-class ContractMethodQuery(_BaseBlockQuery):
+class ContractMethodQuery(_BaseBlockQuery, _BaseQuery[Any]):
     """
     A ``QueryType`` that collects return values from calling ``method`` in ``contract``
     over a range of blocks between ``start_block`` and ``stop_block``.
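The window and ordering semantics above drive the solver introduced later in this diff. A minimal illustrative sketch (not part of the changeset) of what they guarantee, assuming queries are constructed directly::

    from ape.api.query import BlockQuery

    outer = BlockQuery(columns=["number"], start_block=0, stop_block=100)
    inner = BlockQuery(columns=["number"], start_block=10, stop_block=50)

    assert inner in outer    # `outer`'s window covers `inner`'s
    assert outer < inner     # earlier start sorts first; ties break on smaller span
    assert len(inner) == 40  # end_index - start_index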
@@ -250,7 +342,162 @@ class ContractMethodQuery(_BaseBlockQuery):
     method_args: dict[str, Any]
 
 
-class QueryAPI(BaseInterface):
+class CursorAPI(BaseInterfaceModel, Generic[ModelType]):
+    query: _BaseQuery[ModelType]
+
+    def shrink(
+        self,
+        start_index: Optional[int] = None,
+        end_index: Optional[int] = None,
+    ) -> "Self":
+        """
+        Create a copy of this object with the query window shrunk inwards to `start_index` and/or
+        `end_index`. Note that `.shrink` should always be called with strictly less coverage than
+        the original query window of this cursor model, for use in the `QueryManager`'s solver
+        algorithm.
+
+        Args:
+            start_index (Optional[int]): The new `start_index` that this cursor should start at.
+            end_index (Optional[int]): The new `end_index` that this cursor should end at.
+
+        Returns:
+            Self: a copy of itself, only with the smaller query window applied.
+        """
+        raise NotImplementedError
+
+    @property
+    def total_time(self) -> float:
+        """
+        The estimated total time that this cursor would take to execute. Note that this is only an
+        approximation, but should be relatively accurate for the `QueryManager`'s solver algorithm
+        to work well. It is used for printing metrics to the user.
+
+        The default implementation of this property is the span of this cursor times
+        `.time_per_row`.
+
+        Returns:
+            float: Time (in seconds) that the query should take to execute fully.
+        """
+        return (self.query.end_index - self.query.start_index) * (self.time_per_row)
+
+    @property
+    @abstractmethod
+    def time_per_row(self) -> float:
+        """
+        The estimated average time spent (per row) that this cursor would take to execute. Note
+        that this is only an approximation, but should be relatively accurate for the
+        `QueryManager`'s solver algorithm to work well. It is used for determining the correct
+        ordering of cursors within the solver algorithm.
+
+        Returns:
+            float: Average time (in seconds) that the query should take to execute a single row.
+        """
+
+    # Conversion out to fulfill user query requirements
+    def as_dataframe(self, backend: nw.Implementation) -> "Frame":
+        """
+        Execute and return this Cursor as a `~narwhals.v1.DataFrame` or `~narwhals.v1.LazyFrame`
+        object. The use of ``backend`` is exactly as mentioned in the `narwhals` documentation:
+        https://narwhals-dev.github.io/narwhals/api-reference/typing/#narwhals.typing.Frame
+
+        It is recommended to use whatever method of conversion makes sense within your query
+        plugin, for example you can use `~narwhals.from_dict` to convert results into a Frame:
+        https://narwhals-dev.github.io/narwhals/api-reference/narwhals/#narwhals.from_dict
+
+        The default implementation of this method uses `.as_model_iter()` to fulfill this
+        requirement.
+
+        Args:
+            backend (:class:`~narwhals.Implementation`): A Narwhals-compatible backend specifier.
+                See: https://narwhals-dev.github.io/narwhals/api-reference/implementation/
+
+        Returns:
+            (`~narwhals.v1.DataFrame` | `~narwhals.v1.LazyFrame`): A narwhals dataframe.
+        """
+        data: dict[str, list] = {column: [] for column in self.query.columns}
+
+        for item in self.as_model_iter():
+            for column in data:
+                data[column].append(getattr(item, column))
+
+        return nw.from_dict(data, backend=backend)
+
+    @abstractmethod
+    def as_model_iter(self) -> Iterator[ModelType]:
+        """
+        Execute and return this Cursor as an iterated sequence of `ModelType` objects. This will
+        be used for Ape's internal APIs in order to fulfill certain higher-level use cases within
+        Ape.
+        Note that a plugin is expected to assemble this iterated sequence in the most
+        efficient manner possible.
+
+        Returns:
+            `Iterator[ModelType]`: A sequence of Ape API models.
+        """
+
+
+QueryType = Union[
+    AccountTransactionQuery,
+    BlockQuery,
+    BlockTransactionQuery,
+    ContractCreationQuery,
+    ContractEventQuery,
+    ContractMethodQuery,
+]
+
+
+class QueryEngineAPI(BaseInterface):
+    def exec(self, query: QueryType) -> Iterator[CursorAPI]:
+        """
+        Obtain `CursorAPI` object(s) that may cover (a subset of) `query`. A plugin should yield
+        one or more cursor(s), each covering some subset of the length of `query`'s row-space, as
+        indicated by `QueryType.start_index` and `QueryType.end_index`. These query types will
+        either be fed into an algorithm to determine the cheapest possible coverage of the query,
+        or be sourced directly from the provider in response to a user-specified query.
+
+        Note that this method encourages the use of the `@singledispatchmethod` decorator to make
+        it possible to specify only the types of queries that your plugin is able to handle. By
+        default, this method yields an empty iterator, which indicates that your plugin should be
+        skipped for any non-overridden query types.
+
+        Add `exec = functools.singledispatchmethod(QueryEngineAPI.exec)` to your subclass, and then
+        `@exec.register` as a decorator on your method in order to support particular query types.
+
+        Args:
+            query (`~QueryType`): The query being handled by this method.
+
+        Returns:
+            Iterator[`~CursorAPI`]: Zero (or more) cursor(s) that provide data for a portion of
+            `query`'s range. Defaults to not providing any coverage.
+
+        Usage example::
+
+            >>> from functools import singledispatchmethod
+            >>> from ape.api import CursorAPI, QueryEngineAPI
+            >>> class PluginCursor(CursorAPI):
+            ...     ...  # See `CursorAPI`'s documentation for methods to implement
+            >>> class PluginQueryEngine(QueryEngineAPI):
+            ...     # NOTE: Do this if you want to define multiple dispatch handlers easily
+            ...     exec = singledispatchmethod(QueryEngineAPI.exec)
+            ...     # NOTE: Do *not* name your registered handler `exec`!
+            ...     @exec.register
+            ...     def exec_queryX(self, query: SomethingQuery) -> Iterator[PluginCursor]:
+            ...         yield PluginCursor(query=query, ...)
+            ...         # NOTE: Can yield more cursors if plugin does not have full coverage,
+            ...         # or has piece-wise coverage of the query space
+
+        """
+        return iter([])  # Will avoid using any cursors from this plugin for querying this type
+
+    def cache(self, result: "QueryResult"):
+        """
+        Once a query is solved, this method will be called on every query plugin as a callback for
+        whatever application logic you might want to perform using the final `QueryResult` cursor.
+        By default, this method does nothing, so only override it if your plugin needs to perform
+        specific application logic (caching, pre-indexing, etc.).
+
+        Args:
+            result (`~ape.managers.query.QueryResult`): the final solved Cursor representing all
+            the data that most efficiently covers the original `~QueryType`.
+ """ + + # TODO: Deprecate below in v0.9 @abstractmethod def estimate_query(self, query: QueryType) -> Optional[int]: """ @@ -287,3 +534,7 @@ def update_cache(self, query: QueryType, result: Iterator[BaseInterfaceModel]): query (``QueryType``): query that was executed result (``Iterator``): the result of the query """ + + +# TODO: Remove in v0.9 +QueryAPI = QueryEngineAPI diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index c6532a2149..40f15fd64d 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -1,23 +1,19 @@ import difflib import types from collections.abc import Callable, Iterator -from functools import cached_property, partial, singledispatchmethod +from functools import cached_property, singledispatchmethod from itertools import islice from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, cast import click +import narwhals.stable.v1 as nw from eth_pydantic_types import HexBytes from eth_utils import to_hex from ethpm_types.abi import EventABI from ape.api.address import Address, BaseAddress -from ape.api.query import ( - ContractCreation, - ContractEventQuery, - extract_fields, - validate_and_expand_columns, -) +from ape.api.query import ContractCreation, ContractEventQuery, validate_and_expand_columns from ape.exceptions import ( ApeAttributeError, ArgumentsLengthError, @@ -46,7 +42,7 @@ if TYPE_CHECKING: from ethpm_types.abi import ConstructorABI, ErrorABI, MethodABI from ethpm_types.contract_type import ABI_W_SELECTOR_T, ContractType - from pandas import DataFrame + from narwhals.typing import Frame from ape.api.networks import ProxyInfoAPI from ape.api.transactions import ReceiptAPI, TransactionAPI @@ -623,7 +619,8 @@ def query( stop_block: Optional[int] = None, step: int = 1, engine_to_use: Optional[str] = None, - ) -> "DataFrame": + backend: Union[str, nw.Implementation, None] = None, + ) -> "Frame": """ Iterate through blocks for log events @@ -638,13 +635,12 @@ def query( Defaults to ``1``. engine_to_use (Optional[str]): query engine to use, bypasses query engine selection algorithm. + backend (Union[:object:`~narwhals.Implementation, str None]): A Narwhals-compatible + backend. See: https://narwhals-dev.github.io/narwhals/api-reference/implementation Returns: - pd.DataFrame + :class:`~narwhals.typing.Frame` """ - # perf: pandas import is really slow. Avoid importing at module level. - import pandas as pd - HEAD = self.chain_manager.blocks.height if start_block < 0: start_block = HEAD + start_block @@ -670,13 +666,25 @@ def query( # Only query for a specific contract when checking an instance. 
query["contract"] = self.contract.address + # TODO: In v0.9, just use `result.as_dataframe(backend=backend)` API contract_event_query = ContractEventQuery(**query) contract_events = self.query_manager.query( contract_event_query, engine_to_use=engine_to_use ) columns_ls = validate_and_expand_columns(columns, ContractLog) - data = map(partial(extract_fields, columns=columns_ls), contract_events) - return pd.DataFrame(columns=columns_ls, data=data) + + data: dict[str, list] = {column: [] for column in columns_ls} + for log in contract_events: + for column in data: + data[column].append(getattr(log, column)) + + if backend is None: + backend = cast(nw.Implementation, self.config_manager.query.backend) + + elif isinstance(backend, str): + backend = nw.Implementation.from_backend(backend) + + return nw.from_dict(data=data, backend=backend) def range( self, diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index b285e0a2a2..e7f8393c9e 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -1,22 +1,17 @@ from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager -from functools import cached_property, partial, singledispatchmethod +from functools import cached_property, singledispatchmethod from statistics import mean, median from typing import IO, TYPE_CHECKING, Optional, Union, cast -import pandas as pd +import narwhals.stable.v1 as nw from rich.box import SIMPLE from rich.table import Table from ape.api.address import BaseAddress from ape.api.providers import BlockAPI -from ape.api.query import ( - AccountTransactionQuery, - BlockQuery, - extract_fields, - validate_and_expand_columns, -) +from ape.api.query import AccountTransactionQuery, BlockQuery, validate_and_expand_columns from ape.api.transactions import ReceiptAPI from ape.exceptions import ( APINotImplementedError, @@ -35,6 +30,7 @@ from ape.utils.misc import ZERO_ADDRESS, is_evm_precompile, is_zero_hex, log_instead_of_fail if TYPE_CHECKING: + from narwhals.typing import Frame from rich.console import Console as RichConsole from ape.types import BlockID, ContractCode, GasReport, SnapshotID, SourceTraceback @@ -123,7 +119,8 @@ def query( stop_block: Optional[int] = None, step: int = 1, engine_to_use: Optional[str] = None, - ) -> pd.DataFrame: + backend: Union[str, nw.Implementation, None] = None, + ) -> "Frame": """ A method for querying blocks and returning an Iterator. If you do not provide a starting block, the 0 block is assumed. If you do not @@ -144,9 +141,11 @@ def query( Defaults to ``1``. engine_to_use (Optional[str]): query engine to use, bypasses query engine selection algorithm. + backend (Union[:object:`~narwhals.Implementation, str None]): A Narwhals-compatible + backend. 
 
         Returns:
-            pd.DataFrame
+            :class:`~narwhals.typing.Frame`
         """
 
         if start_block < 0:
@@ -170,13 +169,23 @@ def query(
             step=step,
         )
 
+        # TODO: In v0.9, just use `result.as_dataframe(backend=backend)` API
         blocks = self.query_manager.query(query, engine_to_use=engine_to_use)
         columns: list[str] = validate_and_expand_columns(  # type: ignore
             columns, self.head.__class__
         )
-        extraction = partial(extract_fields, columns=columns)
-        data = map(lambda b: extraction(b), blocks)
-        return pd.DataFrame(columns=columns, data=data)
+        data: dict[str, list] = {column: [] for column in columns}
+        for block in blocks:
+            for column in data:
+                data[column].append(getattr(block, column))
+
+        if backend is None:
+            backend = cast(nw.Implementation, self.config_manager.query.backend)
+
+        elif isinstance(backend, str):
+            backend = nw.Implementation.from_backend(backend)
+
+        return nw.from_dict(data=data, backend=backend)
 
     def range(
         self,
@@ -351,7 +360,8 @@ def query(
         start_nonce: int = 0,
         stop_nonce: Optional[int] = None,
         engine_to_use: Optional[str] = None,
-    ) -> pd.DataFrame:
+        backend: Union[str, nw.Implementation, None] = None,
+    ) -> "Frame":
         """
         A method for querying transactions made by an account and returning an Iterator.
         If you do not provide a starting nonce, the first transaction is assumed.
@@ -370,9 +380,11 @@ def query(
                 in the query. Defaults to the latest transaction.
             engine_to_use (Optional[str]): query engine to use, bypasses query
                 engine selection algorithm.
+            backend (Union[str, :class:`~narwhals.Implementation`, None]): A Narwhals-compatible
+                backend. See: https://narwhals-dev.github.io/narwhals/api-reference/implementation
 
         Returns:
-            pd.DataFrame
+            :class:`~narwhals.typing.Frame`
         """
 
         if start_nonce < 0:
@@ -396,11 +408,21 @@ def query(
             stop_nonce=stop_nonce,
         )
 
+        # TODO: In v0.9, just use `result.as_dataframe(backend=backend)` API
         txns = self.query_manager.query(query, engine_to_use=engine_to_use)
         columns = validate_and_expand_columns(columns, ReceiptAPI)  # type: ignore
-        extraction = partial(extract_fields, columns=columns)
-        data = map(lambda tx: extraction(tx), txns)
-        return pd.DataFrame(columns=columns, data=data)
+        data: dict[str, list] = {column: [] for column in columns}
+        for txn in txns:
+            for column in data:
+                data[column].append(getattr(txn, column))
+
+        if backend is None:
+            backend = cast(nw.Implementation, self.config_manager.query.backend)
+
+        elif isinstance(backend, str):
+            backend = nw.Implementation.from_backend(backend)
+
+        return nw.from_dict(data=data, backend=backend)
 
     def __iter__(self) -> Iterator[ReceiptAPI]:  # type: ignore[override]
         yield from self.outgoing
diff --git a/src/ape/managers/query.py b/src/ape/managers/query.py
index 86cd971b2c..efc7119746 100644
--- a/src/ape/managers/query.py
+++ b/src/ape/managers/query.py
@@ -1,17 +1,25 @@
 import difflib
+import os
 import time
 from collections.abc import Iterator
 from functools import cached_property, singledispatchmethod
 from itertools import tee
-from typing import Optional
+from typing import TYPE_CHECKING, Optional, Union, cast
 
+import narwhals as nw
+from pydantic import model_validator
+
+# TODO: Switch to `import narwhals.stable.v1 as nw` per narwhals documentation
 from ape.api.query import (
     AccountTransactionQuery,
     BaseInterfaceModel,
     BlockQuery,
     BlockTransactionQuery,
     ContractEventQuery,
+    CursorAPI,
+    ModelType,
     QueryAPI,
+    QueryEngineAPI,
     QueryType,
 )
 from ape.api.transactions import ReceiptAPI, TransactionAPI
@@ -21,16 +29,173 @@ from ape.plugins._utils import clean_plugin_name
 from ape.utils.basemodel import ManagerAccessMixin
 
+try:
+    from itertools import pairwise
+
+except ImportError:
+    # TODO: Remove when 3.9 is dropped (`itertools.pairwise` introduced in 3.10)
+    from more_itertools import pairwise  # type: ignore[import-not-found,no-redef,assignment]
+
+
+if TYPE_CHECKING:
+    from narwhals.typing import Frame
+
+    from ape.api.providers import BlockAPI
+
+    try:
+        # Only available in Python 3.11+
+        from typing import Self  # type: ignore
+    except ImportError:
+        from typing_extensions import Self  # type: ignore
+
+
+class _RpcCursor(CursorAPI):
+    def shrink(
+        self,
+        start_index: Optional[int] = None,
+        end_index: Optional[int] = None,
+    ) -> "Self":
+        copy = self.model_copy(deep=True)
+
+        if start_index is not None:
+            copy.query.start_block = start_index
+
+        if end_index is not None:
+            copy.query.stop_block = end_index
+
+        return copy
+
+    @property
+    def time_per_row(self) -> float:
+        # NOTE: Very loose estimate of 100ms per item
+        return 0.1  # seconds
+
+    def as_dataframe(self, backend: nw.Implementation) -> "Frame":
+        data: dict[str, list] = {column: [] for column in self.query.columns}
+
+        for item in self.as_model_iter():
+            for column in data:
+                data[column].append(getattr(item, column))
+
+        return nw.from_dict(data, backend=backend)
+
+
+class _RpcBlockCursor(_RpcCursor):
+    query: BlockQuery
+
+    def as_model_iter(self) -> Iterator["BlockAPI"]:
+        return map(
+            self.provider.get_block,
+            # NOTE: `range` treats its stop as non-inclusive,
+            #       whereas the query's `stop_block` is inclusive.
+            range(self.query.start_block, self.query.stop_block + 1, self.query.step),
+        )
+
+
+class _RpcBlockTransactionCursor(_RpcCursor):
+    query: BlockTransactionQuery
+
+    # TODO: Move to default implementation in `CursorAPI`? (remove `@abstractmethod`)
+    def shrink(
+        self,
+        start_index: Optional[int] = None,
+        end_index: Optional[int] = None,
+    ) -> "Self":
+        if (start_index is not None and start_index != self.query.start_index) or (
+            end_index is not None and end_index != self.query.end_index
+        ):
+            # NOTE: Not possible to shrink this query (also, should never need to be shrunk unless
+            #       different engines disagree on the number of transactions in the block)
+            raise NotImplementedError
+
+        return self
+
+    def as_model_iter(self) -> Iterator[TransactionAPI]:
+        if self.query.num_transactions > 0:
+            yield from self.provider.get_transactions_by_block(self.query.block_id)
+
+
+class _RpcContractEventCursor(_RpcCursor):
+    query: ContractEventQuery
+
+    def as_model_iter(self) -> Iterator[ContractLog]:
+        addresses = self.query.contract
+        if not isinstance(addresses, list):
+            addresses = [self.query.contract]  # type: ignore
+
+        log_filter = LogFilter.from_event(
+            event=self.query.event,
+            search_topics=self.query.search_topics,
+            addresses=addresses,
+            start_block=self.query.start_block,
+            stop_block=self.query.stop_block,
+        )
+        return self.provider.get_contract_logs(log_filter)
+
+
+class _RpcAccountTransactionCursor(_RpcCursor):
+    query: AccountTransactionQuery
+
+    def shrink(
+        self,
+        start_index: Optional[int] = None,
+        end_index: Optional[int] = None,
+    ) -> "Self":
+        copy = self.model_copy(deep=True)
+
+        if start_index is not None:
+            copy.query.start_nonce = start_index
+
+        if end_index is not None:
+            copy.query.stop_nonce = end_index
+
+        return copy
+
+    @property
+    def time_per_row(self) -> float:
+        # NOTE: Extremely expensive query, involves binary search of all blocks in a chain.
+        #       Very loose estimate of 5s per transaction for this query.
+        return 5.0
+
+    def as_model_iter(self) -> Iterator[TransactionAPI]:
+        yield from self.provider.get_transactions_by_account_nonce(
+            self.query.account, self.query.start_nonce, self.query.stop_nonce
+        )
+
+
+class DefaultQueryProvider(QueryEngineAPI):
     """
-    Default implementation of the :class:`~ape.api.query.QueryAPI`.
+    Default implementation of the :class:`~ape.api.query.QueryEngineAPI`.
     Allows for the query of blockchain data using connected provider.
     """
 
-    def __init__(self):
-        self.supports_contract_creation = None
+    @singledispatchmethod
+    def exec(self, query: QueryType) -> Iterator[CursorAPI]:  # type: ignore[override]
+        return super().exec(query)
+
+    @exec.register
+    def exec_block_query(self, query: BlockQuery) -> Iterator[_RpcBlockCursor]:
+        yield _RpcBlockCursor(query=query)
+
+    @exec.register
+    def exec_block_transaction_query(
+        self, query: BlockTransactionQuery
+    ) -> Iterator[_RpcBlockTransactionCursor]:
+        yield _RpcBlockTransactionCursor(query=query)
+
+    @exec.register
+    def exec_contract_event_query(
+        self, query: ContractEventQuery
+    ) -> Iterator[_RpcContractEventCursor]:
+        yield _RpcContractEventCursor(query=query)
+
+    @exec.register
+    def exec_account_transaction_query(
+        self, query: AccountTransactionQuery
+    ) -> Iterator[_RpcAccountTransactionCursor]:
+        yield _RpcAccountTransactionCursor(query=query)
+
+    # TODO: Remove below in v0.9
     @singledispatchmethod
     def estimate_query(self, query: QueryType) -> Optional[int]:  # type: ignore
         return None  # can't handle this query
@@ -99,6 +264,63 @@ def perform_account_transactions_query(
         )
 
 
+class QueryResult(CursorAPI[ModelType]):
+    cursors: list[CursorAPI[ModelType]]
+    """The optimal set of cursors (in sorted order) that fulfill this query."""
+
+    @model_validator(mode="after")
+    def validate_coverage(self):
+        # NOTE: This is done to assert that we have full coverage of queries during testing
+        #       (both testing Core and in 2nd/3rd party plugins)
+        current_pos = self.query.start_index
+        for i, cursor in enumerate(self.cursors):
+            logger.debug(
+                f"Start: {cursor.query.start_index} "
+                f"End: {cursor.query.end_index} "
+                f"Total: {cursor.total_time} seconds"
+            )
+            assert cursor.query.start_index == current_pos, (
+                f"Cursor {i} starts at {cursor.query.start_index}, expected {current_pos}"
+            )
+            current_pos = cursor.query.end_index + 1
+
+        assert current_pos == self.query.end_index + 1, (
+            f"Coverage ended at {current_pos - 1}, expected {self.query.end_index}"
+        )
+
+        return self
+
+    @property
+    def total_time(self) -> float:
+        return sum(c.total_time for c in self.cursors)
+
+    @property
+    def time_per_row(self) -> float:
+        return self.total_time / sum(len(c.query) for c in self.cursors)
+
+    # Conversion out to fulfill user query requirements
+    def as_dataframe(
+        self,
+        backend: Union[str, nw.Implementation, None] = None,
+    ) -> "Frame":
+        if backend is None:
+            backend = cast(nw.Implementation, self.config_manager.query.backend)
+
+        elif isinstance(backend, str):
+            backend = nw.Implementation.from_backend(backend)
+
+        return nw.concat([c.as_dataframe(backend=backend) for c in self.cursors], how="vertical")
+
+    def as_model_iter(self) -> Iterator[ModelType]:
+        for result in self.cursors:
+            yield from result.as_model_iter()
+
+
 class QueryManager(ManagerAccessMixin):
     """
     A singleton that manages query engines and performs queries.
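A hypothetical end-to-end sketch of the experimental path this file introduces (not part of the diff; assumes a connected provider and that pandas is installed for narwhals to use as a backend)::

    import os

    # Opt in to the new cursor-based solver (off by default in this changeset):
    os.environ["APE_ENABLE_EXPERIMENTAL_QUERY_BACKEND"] = "1"

    from ape import chain
    from ape.api.query import BlockQuery

    query = BlockQuery(columns=["number", "timestamp"], start_block=0, stop_block=100)

    # The solver picks the cheapest cursor coverage across all installed engines:
    result = chain.query_manager._experimental_query(query)  # -> QueryResult
    df = result.as_dataframe(backend="pandas")  # or any narwhals backend, e.g. "polars"
    blocks = list(result.as_model_iter())  # or consume `BlockAPI` models lazily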
@@ -132,6 +354,114 @@ def engines(self) -> dict[str, QueryAPI]:
     def _suggest_engines(self, engine_selection):
         return difflib.get_close_matches(engine_selection, list(self.engines), cutoff=0.6)
 
+    def _solve_optimal_coverage(
+        self,
+        query: QueryType,
+        all_cursors: list[CursorAPI],
+    ) -> Iterator[CursorAPI]:
+        # NOTE: Use this to reduce the amount of brute force iteration over query window
+        query_segments = sorted(
+            set(
+                [c.query.start_index for c in all_cursors]
+                + [c.query.end_index for c in all_cursors]
+            )
+        )
+
+        # Find the best cursor that fits each path segment
+        # NOTE: Prime these variables for every time "best cursor" gets yielded
+        last_start_index = query.start_index
+        # NOTE: Start with smallest cursor by coverage and total time
+        #       (resolves corner case when `query.start_index` == `query.end_index`)
+        last_best_cursor = min(all_cursors, key=lambda c: (c.query, c.total_time))
+        for start_index, end_index in pairwise(query_segments):
+            lowest_unit_time = float("inf")
+            best_cursor = None
+            for cursor in all_cursors:
+                # NOTE: Cursor window must at least contain path segment
+                if cursor.query.start_index <= start_index and cursor.query.end_index >= end_index:
+                    # NOTE: Allow cursor to use previous segment(s) if it was the last best
+                    #       since time should typically be better with larger coverage
+                    shrunk_cursor = cursor.shrink(
+                        start_index=(
+                            last_start_index
+                            if last_best_cursor and last_best_cursor is cursor
+                            else start_index
+                        )
+                    )
+                    if shrunk_cursor.time_per_row < lowest_unit_time:
+                        lowest_unit_time = shrunk_cursor.time_per_row
+                        # NOTE: Save original cursor to shrink later (not shrunk one)
+                        best_cursor = cursor
+
+            if best_cursor is None:
+                # NOTE: `AssertionError` because this should not be possible due to RPC engine
+                raise AssertionError(
+                    f"Could not solve, missing coverage in window [{start_index}:{end_index}]."
+                )
+            logger.debug(f"Best cursor for segment [{start_index}:{end_index}]: {best_cursor}")
+
+            if last_best_cursor is None:
+                # NOTE: Should only execute first time
+                last_best_cursor = best_cursor
+
+            elif last_best_cursor != best_cursor:
+                # NOTE: Yield whatever the last "best cursor" was,
+                #       shrunk up to just before current segment
+                yield last_best_cursor.shrink(
+                    start_index=last_start_index,
+                    end_index=start_index - 1,
+                )
+                # NOTE: Update our yield variables for next time
+                last_start_index = start_index
+                last_best_cursor = best_cursor
+
+            # else: last best is also current best, keep iterating until better one is found
+
+        # NOTE: Always yield last best after loop ends, which contains the final part of the query
+        assert last_best_cursor, "Shouldn't happen: >=2 segment endpoints exist"  # mypy happy
+        yield last_best_cursor.shrink(start_index=last_start_index)
+
+    def _experimental_query(
+        self,
+        query: QueryType,
+        engine_to_use: Optional[str] = None,
+    ) -> QueryResult:
+        if not engine_to_use:
+            # Sort by earliest point in cursor window (then by shortest coverage if same start)
+            # NOTE: We will iterate over this >1 times, so collect our iterator here
+            all_cursors = sorted(
+                (c for engine in self.engines.values() for c in engine.exec(query)),
+                key=lambda c: c.query,
+            )
+
+        elif selected_engine := self.engines.get(engine_to_use):
+            all_cursors = list(selected_engine.exec(query))
+
+        else:
+            raise QueryEngineError(
+                f"Query engine `{engine_to_use}` not found. "
+                f"Did you mean {' or '.join(self._suggest_engines(engine_to_use))}?"
+ ) + + if len(all_cursors) == 0: + # NOTE: Likely indicates a problem with the default or selected query engine + raise QueryEngineError(f"No data available for {query.__class__.__name__}") + + logger.debug("Sorted cursors:\n " + "\n ".join(map(str, all_cursors))) + result: QueryResult = QueryResult( + query=query, + cursors=list(self._solve_optimal_coverage(query, all_cursors)), + ) + + # TODO: Execute in background thread when async support introduced + for engine_name, engine in self.engines.items(): + logger.debug(f"Caching w/ '{engine_name}' ...") + engine.cache(result) + logger.debug(f"Caching done for '{engine_name}'") + + return result + + # TODO: Replace `.query` with `._experimental_query` and remove this in v0.9 def query( self, query: QueryType, @@ -150,6 +480,8 @@ def query( Returns: Iterator[``BaseInterfaceModel``] """ + if os.environ.get("APE_ENABLE_EXPERIMENTAL_QUERY_BACKEND", False): + return self._experimental_query(query, engine_to_use=engine_to_use).as_model_iter() if engine_to_use: if engine_to_use not in self.engines: diff --git a/src/ape_cache/query.py b/src/ape_cache/query.py index 6f7e0ac036..1704715f2d 100644 --- a/src/ape_cache/query.py +++ b/src/ape_cache/query.py @@ -1,478 +1,129 @@ from collections.abc import Iterator from functools import singledispatchmethod from pathlib import Path -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Optional -from sqlalchemy import create_engine, func -from sqlalchemy.engine import CursorResult -from sqlalchemy.sql import column, insert, select -from sqlalchemy.sql.expression import Insert, Select +import narwhals as nw from ape.api.providers import BlockAPI -from ape.api.query import ( - BaseInterfaceModel, - BlockQuery, - BlockTransactionQuery, - ContractEventQuery, - QueryAPI, - QueryType, -) -from ape.api.transactions import TransactionAPI +from ape.api.query import BaseInterfaceModel, BlockQuery, CursorAPI, QueryEngineAPI, QueryType from ape.exceptions import QueryEngineError -from ape.logging import logger -from ape.types.events import ContractLog -from ape.utils.misc import LOCAL_NETWORK_NAME -from . import models -from .models import Blocks, ContractEvents, Transactions +if TYPE_CHECKING: + from narwhals.typing import Frame + try: + # Only on Python 3.11 + from typing import Self # type: ignore + except ImportError: + from typing_extensions import Self # type: ignore -class CacheQueryProvider(QueryAPI): - """ - Default implementation of the :class:`~ape.api.query.QueryAPI`. - Allows for the query of blockchain data using a connected provider. - """ - - # Class var for tracking if we detect a scenario where the cache db isn't working - database_bypass = False - - def _get_database_file(self, ecosystem_name: str, network_name: str) -> Path: - """ - Allows us to figure out what the file *will be*, mostly used for database management. - - Args: - ecosystem_name (str): Name of the ecosystem to store data for (ex: ethereum) - network_name (str): name of the network to store data for (ex: mainnet) - - Raises: - :class:`~ape.exceptions.QueryEngineError`: If a local network is provided. 
- """ - - if network_name == LOCAL_NETWORK_NAME: - # NOTE: no need to cache local network, no use for data - raise QueryEngineError("Cannot cache local data") - - if "-fork" in network_name: - # NOTE: send query to pull from upstream - network_name = network_name.replace("-fork", "") - return self.config_manager.DATA_FOLDER / ecosystem_name / network_name / "cache.db" - - def _get_sqlite_uri(self, database_file: Path) -> str: - """ - Gets a string for the sqlite db URI. - - Args: - database_file (`pathlib.Path`): A path to the database file. - - Returns: - str - """ - - return f"sqlite:///{database_file}" - - def init_database(self, ecosystem_name: str, network_name: str): - """ - Initialize the SQLite database for caching of provider data. - - Args: - ecosystem_name (str): Name of the ecosystem to store data for (ex: ethereum) - network_name (str): name of the network to store data for (ex: mainnet) - - Raises: - :class:`~ape.exceptions.QueryEngineError`: When the database has not been initialized - """ - - database_file = self._get_database_file(ecosystem_name, network_name) - if database_file.is_file(): - raise QueryEngineError("Database has already been initialized") - - # NOTE: Make sure database folder location has been created - database_file.parent.mkdir(exist_ok=True, parents=True) - - models.Base.metadata.create_all( # type: ignore - bind=create_engine(self._get_sqlite_uri(database_file), pool_pre_ping=True) - ) - - def purge_database(self, ecosystem_name: str, network_name: str): - """ - Removes the SQLite database file from disk. - - Args: - ecosystem_name (str): Name of the ecosystem to store data for (ex: ethereum) - network_name (str): name of the network to store data for (ex: mainnet) - - Raises: - :class:`~ape.exceptions.QueryEngineError`: When the database has not been initialized - """ - - database_file = self._get_database_file(ecosystem_name, network_name) - if not database_file.is_file(): - raise QueryEngineError("Database must be initialized") - - database_file.unlink() +class _BaseCursor(CursorAPI): + cache_folder: Path @property - def database_connection(self): - """ - Returns a connection for the currently active network. - - **NOTE**: Creates a database if it doesn't exist. - - Raises: - :class:`~ape.exceptions.QueryEngineError`: If you are not connected to a provider, - or if the database has not been initialized. 
+    def total_time(self) -> float:
+        return (self.query.end_index - self.query.start_index) * (self.time_per_row)
 
-        Returns:
-            Optional[`sqlalchemy.engine.Connection`]
-        """
-        if self.provider.network.is_local:
-            return None
+    @property
+    def time_per_row(self) -> float:
+        return 0.01  # 10ms per row to parse file w/ Pydantic
 
-        if not self.network_manager.connected:
-            raise QueryEngineError("Not connected to a provider")
 
-        database_file = self._get_database_file(
-            self.provider.network.ecosystem.name, self.provider.network.name
-        )
+class BlockCursor(_BaseCursor):
+    query: BlockQuery
 
-        if not database_file.is_file():
-            # NOTE: Raising `info` here hints user that they can initialize the cache db
-            logger.info("`ape-cache` database has not been initialized")
-            self.database_bypass = True
-            return None
+    def shrink(self, start_index: Optional[int] = None, end_index: Optional[int] = None) -> "Self":
+        copy = self.model_copy(deep=True)
 
-        try:
-            sqlite_uri = self._get_sqlite_uri(database_file)
-            return create_engine(sqlite_uri, pool_pre_ping=True).connect()
+        if start_index is not None:
+            copy.query.start_block = start_index
 
-        except QueryEngineError as e:
-            logger.debug(f"Exception when querying:\n{e}")
-            return None
+        if end_index is not None:
+            copy.query.stop_block = end_index
 
-        except Exception as e:
-            logger.warning(f"Unhandled exception when querying:\n{e}")
-            self.database_bypass = True
-            return None
+        return copy
 
-    @singledispatchmethod
-    def _estimate_query_clause(self, query: QueryType) -> Select:
-        """
-        A singledispatchmethod that returns a select statement.
+    def as_dataframe(self, backend: nw.Implementation) -> "Frame":
+        return super().as_dataframe(backend)
 
-        Args:
-            query (QueryType): Choice of query type to perform a
-            check of the number of rows that match the clause.
-
-        Raises:
-            :class:`~ape.exceptions.QueryEngineError`: When given an
-            incompatible QueryType.
-
-        Returns:
-            `sqlalchemy.sql.expression.Select`
-        """
+    def as_model_iter(self) -> Iterator[BlockAPI]:
+        block_index_folder = self.cache_folder / ".number"
+        for block_number in range(self.query.start_block, self.query.stop_block + 1):
+            # NOTE: Parse each cached block file directly (one JSON document per block)
+            yield self.provider.network.ecosystem.block_class.model_validate_json(
+                (block_index_folder / str(block_number)).read_text()
+            )
 
-        raise QueryEngineError(
-            """
-            Not a compatible QueryType. For more details see our docs
-            https://docs.apeworx.io/ape/stable/methoddocs/exceptions.html#ape.exceptions.QueryEngineError
-            """
-        )
 
-    @_estimate_query_clause.register
-    def _block_estimate_query_clause(self, query: BlockQuery) -> Select:
-        return (
-            select(func.count())
-            .select_from(Blocks)
-            .where(Blocks.number >= query.start_block)
-            .where(Blocks.number <= query.stop_block)
-            .where(Blocks.number % query.step == 0)
-        )
+class CacheQueryProvider(QueryEngineAPI):
+    """
+    An implementation of :class:`~ape.api.query.QueryEngineAPI` that
+    allows querying of blockchain data cached on disk.
+ """ - @_estimate_query_clause.register - def _transaction_estimate_query_clause(self, query: BlockTransactionQuery) -> Select: - return ( - select(func.count()) - .select_from(Transactions) - .where(Transactions.block_hash == query.block_id) - ) + exec = singledispatchmethod(QueryEngineAPI.exec) - @_estimate_query_clause.register - def _contract_events_estimate_query_clause(self, query: ContractEventQuery) -> Select: + def cache_folder(self) -> Path: return ( - select(func.count()) - .select_from(ContractEvents) - .where(ContractEvents.block_number >= query.start_block) - .where(ContractEvents.block_number <= query.stop_block) - .where(ContractEvents.block_number % query.step == 0) - ) - - @singledispatchmethod - def _compute_estimate(self, query: QueryType, result: CursorResult) -> Optional[int]: - """ - A singledispatchemethod that computes the time a query - will take to perform from the caching database - """ - - return None # can't handle this query - - @_compute_estimate.register - def _compute_estimate_block_query( - self, - query: BlockQuery, - result: CursorResult, - ) -> Optional[int]: - if result.scalar() == (1 + query.stop_block - query.start_block) // query.step: - # NOTE: Assume 200 msec to get data from database - return 200 - - # Can't handle this query - # TODO: Allow partial queries - return None - - @_compute_estimate.register - def _compute_estimate_block_transaction_query( - self, - query: BlockTransactionQuery, - result: CursorResult, - ) -> Optional[int]: - # TODO: Update `transactions` table schema so this query functions properly - # Uncomment below after https://github.com/ApeWorX/ape/issues/994 - # if result.scalar() > 0: # type: ignore - # # NOTE: Assume 200 msec to get data from database - # return 200 - - # Can't handle this query - return None - - @_compute_estimate.register - def _compute_estimate_contract_events_query( - self, - query: ContractEventQuery, - result: CursorResult, - ) -> Optional[int]: - if result.scalar() == (query.stop_block - query.start_block) // query.step: - # NOTE: Assume 200 msec to get data from database - return 200 - - # Can't handle this query - # TODO: Allow partial queries - return None - - def estimate_query(self, query: QueryType) -> Optional[int]: - """ - Method called by the client to return a query time estimate. - - Args: - query (QueryType): Choice of query type to perform a - check of the number of rows that match the clause. - - Returns: - Optional[int] - """ - - # NOTE: Because of Python shortcircuiting, the first time `database_connection` is missing - # this will lock the class var `database_bypass` in place for the rest of the session - if self.database_bypass or self.database_connection is None: - # No database, or some other issue - return None - - try: - with self.database_connection as conn: - result = conn.execute(self._estimate_query_clause(query)) - if not result: - return None - - return self._compute_estimate(query, result) - - except QueryEngineError as err: - logger.debug(f"Bypassing cache database: {err}") - # Note: The reason we return None instead of failing is that we want - # a failure of the query to bypass the query logic so that the - # estimation phase does not fail in `QueryManager`. - return None - - @singledispatchmethod - def perform_query(self, query: QueryType) -> Iterator: # type: ignore - """ - Performs the requested query from cache. - - Args: - query (QueryType): Choice of query type to perform a - check of the number of rows that match the clause. 
- - Raises: - :class:`~ape.exceptions.QueryEngineError`: When given an - incompatible QueryType, or encounters some sort of error - in the database or estimation logic. - - Returns: - Iterator - """ - - raise QueryEngineError( - "Not a compatible QueryType. For more details see our docs " - "https://docs.apeworx.io/ape/stable/methoddocs/" - "exceptions.html#ape.exceptions.QueryEngineError" + self.config_manager.DATA_FOLDER + / self.provider.network.ecosystem.name + / self.provider.network.name ) - @perform_query.register - def _perform_block_query(self, query: BlockQuery) -> Iterator[BlockAPI]: - with self.database_connection as conn: - result = conn.execute( - select([column(c) for c in query.columns]) - .where(Blocks.number >= query.start_block) - .where(Blocks.number <= query.stop_block) - .where(Blocks.number % query.step == 0) - ) - - if not result: - # NOTE: Should be unreachable if estimated correctly - raise QueryEngineError(f"Could not perform query:\n{query}") - - yield from map( - lambda row: self.provider.network.ecosystem.decode_block(dict(row.items())), result - ) + def find_ranges( + self, index_folder: Path, start: int = 0, end: int = -1 + ) -> Iterator[tuple[int, int]]: + all_indices = sorted(int(p.name) for p in index_folder.glob("*")) + last_index = max(start, min(all_indices)) - @perform_query.register - def _perform_transaction_query(self, query: BlockTransactionQuery) -> Iterator[dict]: - with self.database_connection as conn: - result = conn.execute( - select([Transactions]).where(Transactions.block_hash == query.block_id) - ) + for index in all_indices: + if index <= last_index: + continue # NOTE: Skip past `last_index` - if not result: - # NOTE: Should be unreachable if estimated correctly - raise QueryEngineError(f"Could not perform query:\n{query}") + elif end != -1 and index >= end: + # NOTE: Yield last range in `[start, end]` + yield start, end + break - yield from map(lambda row: dict(row.items()), result) + elif index - last_index > 1: + # NOTE: Gap identified + yield start, last_index + start = index - @perform_query.register - def _perform_contract_events_query(self, query: ContractEventQuery) -> Iterator[ContractLog]: - with self.database_connection as conn: - result = conn.execute( - select([column(c) for c in query.columns]) - .where(ContractEvents.block_number >= query.start_block) - .where(ContractEvents.block_number <= query.stop_block) - .where(ContractEvents.block_number % query.step == 0) - ) + last_index = index - if not result: - # NOTE: Should be unreachable if estimated correctly - raise QueryEngineError(f"Could not perform query:\n{query}") + @exec.register + def exec_block_query(self, query: BlockQuery) -> Iterator[BlockCursor]: + if not (block_folder := self.cache_folder() / "blocks").exists(): + return - yield from map(lambda row: ContractLog.model_validate(dict(row.items())), result) + for block_range in self.find_ranges( + block_folder / ".number", + start=query.start_block, + end=query.stop_block, + ): + yield BlockCursor(query=query, cache_folder=block_folder).shrink(*block_range) - @singledispatchmethod - def _cache_update_clause(self, query: QueryType) -> Insert: + def prune_database(self, ecosystem_name: str, network_name: str): """ - Update cache database Insert statement. + Removes the SQLite database file from disk. Args: - query (QueryType): Choice of query type to perform a - check of the number of rows that match the clause. 
+ ecosystem_name (str): Name of the ecosystem to store data for (ex: ethereum) + network_name (str): name of the network to store data for (ex: mainnet) Raises: - :class:`~ape.exceptions.QueryEngineError`: When given an - incompatible QueryType, or encounters some sort of error - in the database or estimation logic. - - Returns: - `sqlalchemy.sql.Expression.Insert` + :class:`~ape.exceptions.QueryEngineError`: When the database has not been initialized """ - # Can't cache this query - raise QueryEngineError( - "Not a compatible QueryType. For more details see our docs " - "https://docs.apeworx.io/ape/stable/methoddocs/" - "exceptions.html#ape.exceptions.QueryEngineError" - ) - @_cache_update_clause.register - def _cache_update_block_clause(self, query: BlockQuery) -> Insert: - return insert(Blocks) - - # TODO: Update `transactions` table schema so we can use `EcosystemAPI.decode_receipt` - # Uncomment below after https://github.com/ApeWorX/ape/issues/994 - # @_cache_update_clause.register - # def _cache_update_block_txns_clause(self, query: BlockTransactionQuery) -> Insert: - # return insert(Transactions) # type: ignore - - @_cache_update_clause.register - def _cache_update_events_clause(self, query: ContractEventQuery) -> Insert: - return insert(ContractEvents) - - @singledispatchmethod - def _get_cache_data( - self, query: QueryType, result: Iterator[BaseInterfaceModel] - ) -> Optional[list[dict[str, Any]]]: - raise QueryEngineError( - """ - Not a compatible QueryType. For more details see our docs - https://docs.apeworx.io/ape/stable/methoddocs/exceptions.html#ape.exceptions.QueryEngineError - """ - ) + # NOTE: Delete below after v0.9 + def estimate_query(self, query: QueryType) -> Optional[int]: + return None - @_get_cache_data.register - def _get_block_cache_data( - self, query: BlockQuery, result: Iterator[BaseInterfaceModel] - ) -> Optional[list[dict[str, Any]]]: - return [m.model_dump(mode="json", by_alias=False) for m in result] - - @_get_cache_data.register - def _get_block_txns_data( - self, query: BlockTransactionQuery, result: Iterator[BaseInterfaceModel] - ) -> Optional[list[dict[str, Any]]]: - new_result = [] - table_columns = [c.key for c in Transactions.__table__.columns] # type: ignore - txns: list[TransactionAPI] = cast(list[TransactionAPI], result) - for val in [m for m in txns]: - new_dict = { - k: v - for k, v in val.model_dump(mode="json", by_alias=False).items() - if k in table_columns - } - for col in table_columns: - if col == "txn_hash": - new_dict["txn_hash"] = val.txn_hash - elif col == "sender": - new_dict["sender"] = new_dict["sender"].encode() - elif col == "receiver" and "receiver" in new_dict: - new_dict["receiver"] = new_dict["receiver"].encode() - elif col == "receiver" and "receiver" not in new_dict: - new_dict["receiver"] = b"" - elif col == "block_hash": - new_dict["block_hash"] = query.block_id - elif col == "signature" and val.signature is not None: - new_dict["signature"] = val.signature.encode_rsv() - elif col not in new_dict: - new_dict[col] = None - new_result.append(new_dict) - return new_result - - @_get_cache_data.register - def _get_cache_events_data( - self, query: ContractEventQuery, result: Iterator[BaseInterfaceModel] - ) -> Optional[list[dict[str, Any]]]: - return [m.model_dump(mode="json", by_alias=False) for m in result] + def perform_query(self, query: QueryType) -> Iterator: + raise QueryEngineError("Cannot use this engine in legacy mode") def update_cache(self, query: QueryType, result: Iterator[BaseInterfaceModel]): - try: - clause 
= self._cache_update_clause(query) - except QueryEngineError: - # Cannot handle query type - return - - # NOTE: Because of Python shortcircuiting, the first time `database_connection` is missing - # this will lock the class var `database_bypass` in place for the rest of the session - if not self.database_bypass and self.database_connection is not None: - logger.debug(f"Caching query: {query}") - with self.database_connection as conn: - try: - conn.execute( - clause.values( # type: ignore - self._get_cache_data(query, result) - ).prefix_with("OR IGNORE") - ) - - except QueryEngineError as err: - logger.warning(f"Database corruption: {err}") + pass # TODO: Add legacy cache support diff --git a/src/ape_ethereum/query.py b/src/ape_ethereum/query.py index 7dd8a4f9ca..2bbf451c4d 100644 --- a/src/ape_ethereum/query.py +++ b/src/ape_ethereum/query.py @@ -1,17 +1,191 @@ from collections.abc import Iterator from functools import singledispatchmethod -from typing import Optional +from typing import TYPE_CHECKING, Optional -from ape.api.query import ContractCreation, ContractCreationQuery, QueryAPI, QueryType +import narwhals as nw + +from ape.api.query import ( + ContractCreation, + ContractCreationQuery, + CursorAPI, + QueryEngineAPI, + QueryType, +) from ape.exceptions import APINotImplementedError, ProviderError, QueryEngineError -from ape.types.address import AddressType +from ape.types import AddressType + +if TYPE_CHECKING: + from narwhals.typing import Frame + + try: + # Only on Python 3.11 + from typing import Self # type: ignore + except ImportError: + from typing_extensions import Self # type: ignore + + +class ContractCreationCursor(CursorAPI[ContractCreation]): + query: ContractCreationQuery + + use_debug_trace: bool + + def shrink( + self, + start_index: Optional[int] = None, + end_index: Optional[int] = None, + ) -> "Self": + if (start_index is not None and start_index != self.query.start_index) or ( + end_index is not None and end_index != self.query.end_index + ): + raise NotImplementedError + + return self + + @property + def total_time(self) -> float: + # NOTE: 1 row + return self.time_per_row + + @property + def time_per_row(self) -> float: + # NOTE: Extremely expensive query, involves binary search of all blocks in a chain + # Very loose estimate of 5s per call for this query. 
+
+    def _find_creation_in_block_via_parity(self, block, contract_address):
+        # NOTE: Requires the `trace_` namespace
+        traces = self.provider.make_request("trace_replayBlockTransactions", [block, ["trace"]])
+
+        for tx in traces:
+            for trace in tx["trace"]:
+                if (
+                    "error" not in trace
+                    and trace["type"] == "create"
+                    and trace["result"]["address"] == contract_address.lower()
+                ):
+                    receipt = self.chain_manager.get_receipt(tx["transactionHash"])
+                    creator = self.conversion_manager.convert(trace["action"]["from"], AddressType)
+                    yield ContractCreation(
+                        txn_hash=tx["transactionHash"],
+                        block=block,
+                        deployer=receipt.sender,
+                        factory=creator if creator != receipt.sender else None,
+                    )
+
+    def _find_creation_in_block_via_geth(self, block, contract_address):
+        # NOTE: Requires the `debug_` namespace
+        traces = self.provider.make_request(
+            "debug_traceBlockByNumber", [hex(block), {"tracer": "callTracer"}]
+        )
+
+        def flatten(call):
+            if call["type"] in ("CREATE", "CREATE2"):
+                yield call["from"], call["to"]
+
+            if "error" in call:
+                # NOTE: Skip reverted sub-trees; their creations never persisted.
+                return
+
+            # NOTE: Recurse into *every* sub-call, including CREATE frames,
+            #       so deployments made by a just-deployed contract are found.
+            for sub in call.get("calls", []):
+                yield from flatten(sub)
+
+        for tx in traces:
+            call = tx["result"]
+            sender = call["from"]
+            for factory, contract in flatten(call):
+                if contract == contract_address.lower():
+                    yield ContractCreation(
+                        txn_hash=tx["txHash"],
+                        block=block,
+                        deployer=self.conversion_manager.convert(sender, AddressType),
+                        factory=(
+                            self.conversion_manager.convert(factory, AddressType)
+                            if factory != sender
+                            else None
+                        ),
+                    )
+
+    def as_model_iter(self) -> Iterator[ContractCreation]:
+        # Skip the search if there is still no code at the address at head.
+        if not self.chain_manager.get_code(self.query.contract):
+            return
+
+        def find_creation_block(lo, hi):
+            # Binary search for the block where the contract was deployed.
+            # Takes log2(height) `get_code` calls; does not work for contracts
+            # that have been self-destructed and redeployed.
+            while hi - lo > 1:
+                mid = (lo + hi) // 2
+                code = self.chain_manager.get_code(self.query.contract, block_id=mid)
+                if not code:
+                    lo = mid
+                else:
+                    hi = mid
+
+            if self.chain_manager.get_code(self.query.contract, block_id=hi):
+                return hi
+
+            return None
+
+        if not (block := find_creation_block(0, self.chain_manager.blocks.height)):
+            return
+
+        if self.use_debug_trace:
+            yield from self._find_creation_in_block_via_geth(block, self.query.contract)
+
+        else:
+            yield from self._find_creation_in_block_via_parity(block, self.query.contract)
+
+    def as_dataframe(self, backend: nw.Implementation) -> "Frame":
+        # NOTE: Only 1 item; `nw.from_dict` expects each column to be a sequence.
+        item = next(self.as_model_iter())
+        data = {column: [getattr(item, column)] for column in self.query.columns}
+        return nw.from_dict(data, backend=backend)

-class EthereumQueryProvider(QueryAPI):
+
+class EthereumQueryProvider(QueryEngineAPI):
     """
     Implements more advanced queries specific to Ethereum clients.
     """

+    def _has_method(self, rpc_method: str) -> bool:
+        try:
+            self.provider.make_request(rpc_method, [])
+            return True
+
+        except APINotImplementedError:
+            return False
+
+        except ProviderError as e:
+            return "Method not found" not in str(e)
+
+    @property
+    def use_debug_trace(self) -> bool:
+        return "geth" in self.provider.client_version.lower() and self._has_method(
+            "debug_traceBlockByNumber"
+        )
+
+    @property
+    def use_trace_replay(self) -> bool:
+        return self._has_method("trace_replayBlockTransactions")
+
+    @singledispatchmethod
+    def exec(self, query: QueryType) -> Iterator[CursorAPI]:  # type: ignore[override]
+        return super().exec(query)
+
+    @exec.register
+    def exec_contract_creation(
+        self, query: ContractCreationQuery
+    ) -> Iterator[ContractCreationCursor]:
+        if (use_debug_trace := self.use_debug_trace) or self.use_trace_replay:
+            yield ContractCreationCursor(query=query, use_debug_trace=use_debug_trace)
+
+    # TODO: Delete all of below in v0.9
     def __init__(self):
         self.supports_contract_creation = None  # will be set after we try for the first time
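Because `flatten` now recurses into the children of `CREATE`/`CREATE2` frames as well, deployments made by a just-deployed contract are no longer missed. A toy `callTracer` result demonstrates the difference (addresses are fabricated; the `type`/`from`/`to`/`calls` shape follows geth's `callTracer` output):

```python
def flatten(call):
    # Same walk as in the cursor above.
    if call["type"] in ("CREATE", "CREATE2"):
        yield call["from"], call["to"]

    if "error" in call:
        return  # Reverted sub-trees leave no deployed code behind.

    for sub in call.get("calls", []):
        yield from flatten(sub)


trace = {
    "type": "CALL",
    "from": "0xEOA",
    "to": "0xFactory",
    "calls": [
        {
            "type": "CREATE",  # The factory deploys a child contract...
            "from": "0xFactory",
            "to": "0xChild",
            "calls": [
                # ...whose constructor deploys another one.
                {"type": "CREATE2", "from": "0xChild", "to": "0xGrandchild"},
            ],
        },
    ],
}

assert list(flatten(trace)) == [
    ("0xFactory", "0xChild"),
    ("0xChild", "0xGrandchild"),  # Lost if CREATE frames are not recursed into.
]
```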
""" + def _has_method(self, rpc_method: str) -> bool: + try: + self.provider.make_request(rpc_method, []) + return True + + except APINotImplementedError: + return False + + except ProviderError as e: + return "Method not found" not in str(e) + + @property + def use_debug_trace(self) -> bool: + return "geth" in self.provider.client_version.lower() and self._has_method( + "debug_traceBlockByNumber" + ) + + @property + def use_trace_replay(self) -> bool: + return self._has_method("trace_replayBlockTransactions") + + @singledispatchmethod + def exec(self, query: QueryType) -> Iterator[CursorAPI]: # type: ignore[override] + return super().exec(query) + + @exec.register + def exec_contract_creation( + self, query: ContractCreationQuery + ) -> Iterator[ContractCreationCursor]: + if (use_debug_trace := self.use_debug_trace) or self.use_trace_replay: + yield ContractCreationCursor(query=query, use_debug_trace=use_debug_trace) + + # TODO: Delete all of below in v0.9 def __init__(self): self.supports_contract_creation = None # will be set after we try for the first time diff --git a/src/ape_node/query.py b/src/ape_node/query.py index e231d74a39..2b4e056889 100644 --- a/src/ape_node/query.py +++ b/src/ape_node/query.py @@ -1,14 +1,78 @@ from collections.abc import Iterator from functools import singledispatchmethod -from typing import Optional +from typing import TYPE_CHECKING, Optional -from ape.api.query import ContractCreation, ContractCreationQuery, QueryAPI, QueryType +import narwhals as nw + +from ape.api.query import ( + ContractCreation, + ContractCreationQuery, + CursorAPI, + QueryEngineAPI, + QueryType, +) from ape.exceptions import QueryEngineError -from ape.types.address import AddressType +from ape.types import AddressType from ape_ethereum.provider import EthereumNodeProvider +if TYPE_CHECKING: + from narwhals.typing import Frame + + +class ContractCreationCursor(CursorAPI): + query: ContractCreationQuery + + def shrink( + self, + start_index: Optional[int] = None, + end_index: Optional[int] = None, + ) -> "ContractCreationCursor": + if start_index or end_index: + raise NotImplementedError + + return self + + @property + def total_time(self) -> float: + return 0.25 + + @property + def time_per_row(self) -> float: + return 0.25 + + def _get_ots_contract_creation(self) -> ContractCreation: + result = self.provider.make_request("ots_getContractCreator", [self.query.contract]) + creator = self.conversion_manager.convert(result["creator"], AddressType) + receipt = self.provider.get_receipt(result["hash"]) + return ContractCreation( + txn_hash=result["hash"], + block=receipt.block_number, + deployer=receipt.sender, + factory=creator if creator != receipt.sender else None, + ) + + def as_dataframe(self, backend: nw.Implementation) -> "Frame": + return nw.from_dict(self._get_ots_contract_creation().model_dump(), backend=backend) + + def as_model_iter(self) -> Iterator[ContractCreation]: + yield self._get_ots_contract_creation() + + +class OtterscanQueryEngine(QueryEngineAPI): + @singledispatchmethod + def exec(self, query: QueryType) -> Iterator[CursorAPI]: # type: ignore[override] + return super().exec(query) + + @property + def supports_ots_namespace(self) -> bool: + return getattr(self.provider, "_ots_api_level", None) is not None + + @exec.register + def exec_creation_query(self, query: ContractCreationQuery) -> Iterator[ContractCreationCursor]: + if self.supports_ots_namespace: + yield ContractCreationCursor(query=query) -class OtterscanQueryEngine(QueryAPI): + # TODO: Delete below in v0.9 
     @singledispatchmethod
     def estimate_query(self, query: QueryType) -> Optional[int]:  # type: ignore[override]
         return None
diff --git a/tests/functional/test_block.py b/tests/functional/test_block.py
index d5912ab0a9..7b0130e225 100644
--- a/tests/functional/test_block.py
+++ b/tests/functional/test_block.py
@@ -128,14 +128,3 @@ def test_model_validate_web3_block():
     data = BlockData(number=123, timestamp=123, gasLimit=123, gasUsed=100)  # type: ignore
     actual = Block.model_validate(data)
     assert actual.number == 123
-
-
-def test_transactions(block):
-    actual = block.transactions
-    expected: list = []
-    assert actual == expected
-
-    # Ensure still works when hash is None (was a bug where this crashed).
-    block.hash = None
-    block.__dict__.pop("transactions", None)  # Ensure not cached.
-    assert block.transactions == []
diff --git a/tests/functional/test_query.py b/tests/functional/test_query.py
index fad3e49806..c79b8a646e 100644
--- a/tests/functional/test_query.py
+++ b/tests/functional/test_query.py
@@ -1,6 +1,6 @@
 import time

-import pandas as pd
+import narwhals as nw
 import pytest

 from ape.api.query import validate_and_expand_columns
@@ -12,15 +12,15 @@ def test_basic_query(chain, eth_tester_provider):
     blocks_df0 = chain.blocks.query("*")
     blocks_df1 = chain.blocks.query("number", "timestamp")
-    assert list(blocks_df0["number"].values)[:4] == [0, 1, 2, 3]
+    assert blocks_df0["number"].to_list()[:4] == [0, 1, 2, 3]
     assert len(blocks_df1) == len(chain.blocks)
     assert (
-        blocks_df1.iloc[3]["timestamp"]
-        >= blocks_df1.iloc[2]["timestamp"]
-        >= blocks_df1.iloc[1]["timestamp"]
-        >= blocks_df1.iloc[0]["timestamp"]
+        blocks_df1["timestamp"][3]
+        >= blocks_df1["timestamp"][2]
+        >= blocks_df1["timestamp"][1]
+        >= blocks_df1["timestamp"][0]
     )
-    assert list(blocks_df0.columns) == [
+    assert blocks_df0.columns == [
         "base_fee",
         "difficulty",
         "gas_limit",
@@ -40,8 +40,8 @@ def test_relative_block_query(chain, eth_tester_provider):
     chain.mine(10)
     df = chain.blocks.query("*", start_block=-8, stop_block=-2)
     assert len(df) == 7
-    assert df.number.min() == chain.blocks[-8].number == start_block + 3
-    assert df.number.max() == chain.blocks[-2].number == start_block + 9
+    assert df["number"].min() == chain.blocks[-8].number == start_block + 3
+    assert df["number"].max() == chain.blocks[-2].number == start_block + 9


 def test_block_transaction_query(chain, eth_tester_provider, sender, receiver):
@@ -56,8 +56,8 @@ def test_transaction_contract_event_query(contract_instance, owner, eth_tester_p
     contract_instance.fooAndBar(sender=owner)
     time.sleep(0.1)
     df_events = contract_instance.FooHappened.query("*", start_block=-1)
-    assert isinstance(df_events, pd.DataFrame)
-    assert df_events.event_name[0] == "FooHappened"
+    assert isinstance(df_events, nw.DataFrame)
+    assert df_events["event_name"][0] == "FooHappened"


 def test_transaction_contract_event_query_starts_query_at_deploy_tx(
@@ -66,8 +66,8 @@
     contract_instance.fooAndBar(sender=owner)
     time.sleep(0.1)
     df_events = contract_instance.FooHappened.query("*")
-    assert isinstance(df_events, pd.DataFrame)
-    assert df_events.event_name[0] == "FooHappened"
+    assert isinstance(df_events, nw.DataFrame)
+    assert df_events["event_name"][0] == "FooHappened"


 class Model(BaseInterfaceModel):
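For anyone porting other call sites, the test changes above reduce to a handful of pandas-to-Narwhals translations, sketched here against a throwaway pandas-backed frame:

```python
import narwhals as nw

df = nw.from_dict(
    {"number": [0, 1, 2, 3], "timestamp": [10, 20, 30, 40]},
    backend=nw.Implementation.PANDAS,
)

df["number"].to_list()  # was: list(df["number"].values)
df["timestamp"][3]      # was: df.iloc[3]["timestamp"]
df["number"].min()      # was: df.number.min()  (attribute access is pandas-only)
df.columns              # already a list[str] in Narwhals; compare with == directly
```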