Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 78 additions & 10 deletions src/tq_oracle/adapters/asset_adapters/idle_balances.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@

logger = get_logger(__name__)

# Minimal ABI for validating extra_addresses
EXTRA_ADDRESS_VALIDATION_ABI = [
{
"inputs": [],
"name": "subvault",
"outputs": [{"internalType": "address", "name": "", "type": "address"}],
"stateMutability": "view",
"type": "function",
},
]


class IdleBalancesAdapter(BaseAssetAdapter):
"""Adapter for querying idle balances on the vault chain."""
Expand Down Expand Up @@ -88,26 +99,24 @@ def __init__(self, config: OracleSettings):
self._extra_additional_assets_by_symbol,
)

extra_address_candidates = (
[
self.w3.to_checksum_address(address)
for address in idle_cfg.extra_addresses
if address
]
if config.additional_asset_support
else []
)
extra_address_candidates = [
self.w3.to_checksum_address(address)
for address in idle_cfg.extra_addresses
if address
]

deduped_addresses: dict[str, str] = {}
for checksum in extra_address_candidates:
deduped_addresses.setdefault(checksum.lower(), checksum)

self._extra_addresses = list(deduped_addresses.values())
self._extra_addresses_lookup = set(deduped_addresses.keys())
self._skip_extra_address_validation = idle_cfg.skip_extra_address_validation
if self._extra_addresses:
logger.debug(
"Idle balances extra addresses configured: %s",
"Idle balances extra addresses configured: %s (validation=%s)",
self._extra_addresses,
"skipped" if self._skip_extra_address_validation else "enabled",
)

self._rpc_sem = asyncio.Semaphore(getattr(self.config, "max_calls", 5))
Expand Down Expand Up @@ -186,6 +195,10 @@ async def fetch_all_assets(self) -> list[AssetData]:
List of AssetData objects from main vault and all subvaults
"""
subvault_addresses = await self._fetch_subvault_addresses()

if self._extra_addresses:
await self._validate_extra_addresses(subvault_addresses)

vault_addresses = [self.config.vault_address_required] + subvault_addresses
seen_addresses = {addr.lower() for addr in vault_addresses}
for extra_address in self._extra_addresses:
Expand Down Expand Up @@ -293,6 +306,61 @@ async def _fetch_subvault_addresses(self) -> list[str]:
"""Get the subvault addresses for the given vault."""
return await fetch_subvault_addresses(self.config)

async def _validate_extra_addresses(
self,
subvault_addresses: list[str],
) -> None:
"""Validate that extra_addresses return correct subvault values."""
if not self._extra_addresses or self._skip_extra_address_validation:
return

logger.debug(
"Validating %d extra_addresses against %d subvaults",
len(self._extra_addresses),
len(subvault_addresses),
)

normalized_subvaults = {addr.lower() for addr in subvault_addresses}

async def validate_one(extra_addr: str) -> str | None:
"""Returns error message or None if valid."""
checksum_addr = self.w3.to_checksum_address(extra_addr)
contract = self.w3.eth.contract(
address=checksum_addr, abi=EXTRA_ADDRESS_VALIDATION_ABI
)

try:
returned_subvault: str = await self._rpc(
contract.functions.subvault().call,
block_identifier=self.block_number,
)
except (ProviderConnectionError, ValueError) as e:
return f"{extra_addr}: failed to call .subvault() - {e}"

if returned_subvault.lower() not in normalized_subvaults:
return (
f"{extra_addr}: .subvault() returned {returned_subvault} "
f"which is not in auto-discovered subvaults"
)
return None

results = await asyncio.gather(
*[validate_one(addr) for addr in self._extra_addresses]
)
validation_errors = [e for e in results if e is not None]

if validation_errors:
error_list = "\n - ".join(validation_errors)
raise ValueError(
f"extra_address validation failed:\n - {error_list}\n"
f"If these addresses are intentional, set 'skip_extra_address_validation = true' "
f"in [adapters.idle_balances] and pass the --allow-dangerous CLI flag."
)

logger.debug(
"All %d extra_addresses validated successfully", len(self._extra_addresses)
)

async def _fetch_supported_assets(self) -> list[str]:
"""Get the supported assets for the given vault."""
oracle_abi = load_oracle_abi()
Expand Down
8 changes: 8 additions & 0 deletions src/tq_oracle/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ def report(
param_hint=["--private-key", "TQ_ORACLE_PRIVATE_KEY"],
)

dangerous_options = state.settings.get_dangerous_options()
if dangerous_options and not state.settings.allow_dangerous:
options_list = "\n - ".join(dangerous_options)
raise typer.BadParameter(
f"Configuration uses dangerous options:\n - {options_list}\n"
f"Pass --allow-dangerous to enable these options."
)

from .pipeline.run import run_report

asyncio.run(run_report(state, state.settings.vault_address_required))
Expand Down
18 changes: 2 additions & 16 deletions src/tq_oracle/pipeline/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def _process_adapter_results(
asset_data.append(assets)

if failures:
failure_list = ", ".join(name for name, _ in failures)
error_details = "\n - ".join(f"{name}: {e}" for name, e in failures)
raise ValueError(
f"Failed to collect assets from {len(failures)} adapter(s): {failure_list}"
f"Failed to collect assets from {len(failures)} adapter(s):\n - {error_details}"
)


Expand All @@ -100,20 +100,6 @@ async def collect_assets(ctx: PipelineContext) -> None:
cfg["subvault_address"].lower(): cfg for cfg in s.subvault_adapters
}

dangerous_configs = [
cfg
for cfg in s.subvault_adapters
if cfg.get("skip_subvault_existence_check", False)
]
if dangerous_configs and not s.allow_dangerous:
addresses = [cfg["subvault_address"] for cfg in dangerous_configs]
raise ValueError(
f"Configuration uses 'skip_subvault_existence_check' for subvault(s): "
f"{', '.join(addresses)}. This is a dangerous operation that bypasses "
f"subvault existence validation. You must explicitly allow this by "
f"passing the --allow-dangerous CLI flag."
)

# Validate subvault_adapters config references existing subvaults
if s.subvault_adapters:
normalized_subvault_addrs = {addr.lower() for addr in subvault_addresses}
Expand Down
28 changes: 28 additions & 0 deletions src/tq_oracle/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ class IdleBalancesAdapterSettings(BaseModel):

extra_tokens: dict[str, str] = Field(default_factory=dict)
extra_addresses: list[str] = Field(default_factory=list)
skip_extra_address_validation: bool = Field(
default=False,
description=(
"[DANGEROUS] Skip validation that extra_addresses.subvault() returns an address "
"in the auto-discovered subvault list. Requires --allow-dangerous CLI flag."
),
)

model_config = ConfigDict(extra="ignore")

Expand Down Expand Up @@ -389,3 +396,24 @@ def streth_redemption_asset(self) -> str:
@property
def multicall(self) -> str:
return self._resolve_streth_addresses()["multicall"]

def get_dangerous_options(self) -> list[str]:
"""Return a list of enabled dangerous configuration options.

These options bypass safety checks and require --allow-dangerous to use.
"""
dangerous: list[str] = []

# Subvault adapter dangerous options
for cfg in self.subvault_adapters:
addr = cfg.get("subvault_address", "unknown")
if cfg.get("skip_subvault_existence_check", False):
dangerous.append(
f"subvault_adapters[{addr}].skip_subvault_existence_check"
)

# Idle balances adapter dangerous options
if self.adapters.idle_balances.skip_extra_address_validation:
dangerous.append("adapters.idle_balances.skip_extra_address_validation")

return dangerous
67 changes: 66 additions & 1 deletion tests/adapters/asset_adapters/test_idle_balances.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ async def fake_fetch_asset_balance(_w3, _subvault, asset_address, tvl_only=False
async def test_fetch_all_assets_includes_extra_addresses(config, monkeypatch):
extra_address = "0x0000000000000000000000000000000000000009"
config.adapters.idle_balances.extra_addresses = [extra_address]
# Skip validation since test address doesn't have .subvault()/.oracle() methods
config.adapters.idle_balances.skip_extra_address_validation = True
adapter = IdleBalancesAdapter(config)

async def fake_fetch_subvault_addresses():
Expand Down Expand Up @@ -209,6 +211,66 @@ async def fake_fetch_assets(address):
assert set(recorded) == expected


@pytest.mark.asyncio
async def test_extra_address_validation_fails_on_mismatched_subvault(
config, monkeypatch
):
"""Test that validation fails when extra_address returns subvault not in list."""
extra_address = "0x0000000000000000000000000000000000000009"
config.adapters.idle_balances.extra_addresses = [extra_address]
adapter = IdleBalancesAdapter(config)

subvault_addresses = ["0x00000000000000000000000000000000000000AA"]

# Mock _rpc to return subvault not in list
async def fake_rpc(fn, *args, **kwargs):
fn_name = str(fn)
if "subvault" in fn_name:
return "0x00000000000000000000000000000000000000DD" # Not in subvault list
return await fn(*args, **kwargs)

monkeypatch.setattr(adapter, "_rpc", fake_rpc)

with pytest.raises(ValueError, match=r"which is not in auto-discovered subvaults"):
await adapter._validate_extra_addresses(subvault_addresses)


@pytest.mark.asyncio
async def test_extra_address_validation_passes_when_valid(config, monkeypatch):
"""Test that validation passes when extra_address returns correct subvault."""
extra_address = "0x0000000000000000000000000000000000000009"
config.adapters.idle_balances.extra_addresses = [extra_address]
adapter = IdleBalancesAdapter(config)

subvault_addresses = ["0x00000000000000000000000000000000000000AA"]

# Mock _rpc to return correct subvault
async def fake_rpc(fn, *args, **kwargs):
fn_name = str(fn)
if "subvault" in fn_name:
return subvault_addresses[0] # Valid subvault
return await fn(*args, **kwargs)

monkeypatch.setattr(adapter, "_rpc", fake_rpc)

# Should not raise
await adapter._validate_extra_addresses(subvault_addresses)


@pytest.mark.asyncio
async def test_extra_address_validation_skipped_when_flag_set(config):
"""Test that validation is skipped when skip_extra_address_validation is True."""
extra_address = "0x0000000000000000000000000000000000000009"
config.adapters.idle_balances.extra_addresses = [extra_address]
config.adapters.idle_balances.skip_extra_address_validation = True
adapter = IdleBalancesAdapter(config)

# Should not make any RPC calls, just return
await adapter._validate_extra_addresses(
["0x00000000000000000000000000000000000000AA"],
)


@pytest.mark.asyncio
async def test_fetch_all_assets_fails_if_any_vault_fails(config, monkeypatch):
adapter = IdleBalancesAdapter(config)
Expand Down Expand Up @@ -253,5 +315,8 @@ def test_additional_assets_can_be_disabled(config):

assert adapter._default_additional_assets == []
assert adapter._extra_additional_assets_by_symbol == {}
assert adapter._extra_addresses == []
# extra_addresses is decoupled from additional_asset_support
assert adapter._extra_addresses == [
adapter.w3.to_checksum_address("0x00000000000000000000000000000000000000BB")
]
assert merged == [base_asset]
5 changes: 3 additions & 2 deletions tests/pipeline/test_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def test_process_adapter_results_raises_on_adapter_failure():
asset_data = []

with pytest.raises(
ValueError, match=r"Failed to collect assets from 1 adapter\(s\): stakewise"
ValueError,
match=r"(?s)Failed to collect assets from 1 adapter\(s\):.*stakewise: RPC connection failed",
):
_process_adapter_results(tasks_info, results, asset_data, log)

Expand All @@ -59,7 +60,7 @@ def test_process_adapter_results_raises_on_multiple_failures():

with pytest.raises(
ValueError,
match=r"Failed to collect assets from 2 adapter\(s\): idle_balances, custom_adapter",
match=r"(?s)Failed to collect assets from 2 adapter\(s\):.*idle_balances:.*custom_adapter:",
):
_process_adapter_results(tasks_info, results, asset_data, log)

Expand Down
45 changes: 45 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,48 @@ def test_additional_asset_support_toggle(tmp_path, monkeypatch):
settings = OracleSettings()

assert settings.additional_asset_support is False


def test_get_dangerous_options_empty_by_default():
"""No dangerous options enabled by default."""
settings = OracleSettings()
assert settings.get_dangerous_options() == []


def test_get_dangerous_options_skip_extra_address_validation(tmp_path, monkeypatch):
"""Detects skip_extra_address_validation as dangerous."""
config_path = tmp_path / "config.toml"
config_path.write_text(
dedent(
"""
[adapters.idle_balances]
skip_extra_address_validation = true
"""
).strip()
)
monkeypatch.setenv("TQ_ORACLE_CONFIG", str(config_path))

settings = OracleSettings()

dangerous = settings.get_dangerous_options()
assert dangerous == ["adapters.idle_balances.skip_extra_address_validation"]


def test_get_dangerous_options_skip_subvault_existence_check(tmp_path, monkeypatch):
"""Detects skip_subvault_existence_check as dangerous."""
config_path = tmp_path / "config.toml"
config_path.write_text(
dedent(
"""
[[subvault_adapters]]
subvault_address = "0xABC"
skip_subvault_existence_check = true
"""
).strip()
)
monkeypatch.setenv("TQ_ORACLE_CONFIG", str(config_path))

settings = OracleSettings()

dangerous = settings.get_dangerous_options()
assert dangerous == ["subvault_adapters[0xABC].skip_subvault_existence_check"]