2 changes: 1 addition & 1 deletion tiled/_tests/conftest.py
@@ -357,7 +357,7 @@ def tiled_websocket_context(tmpdir, redis_uri):
f"file://localhost{str(tmpdir / 'data')}",
f"duckdb:///{tmpdir / 'data.duckdb'}",
],
readable_storage=[tempfile.gettempdir()],
readable_storage=[str(Path(tempfile.gettempdir()).resolve())],
init_if_not_exists=True,
# This uses shorter defaults than the production defaults. Nothing in
# the test suite should be going on for more than ten minutes.
68 changes: 66 additions & 2 deletions tiled/_tests/test_subscription.py
@@ -1,13 +1,15 @@
import copy
import sys
import threading
import time
import uuid

import numpy as np
import pandas as pd
import pyarrow
import pytest
import tifffile
import websockets.exceptions
from pandas.testing import assert_frame_equal
from starlette.testclient import WebSocketDenialResponse

@@ -24,6 +26,15 @@
)


@pytest.fixture
def stamina_active():
import stamina

stamina.set_active(True)
yield
stamina.set_active(False)


def test_subscribe_immediately_after_creation_websockets(tiled_websocket_context):
context = tiled_websocket_context
client = from_context(context)
@@ -236,8 +247,6 @@ def child_metadata_updated_cb(update):
for i in range(3):
# This is exposing fragility in SQLite database connection handling.
# Once that is resolved, remove the sleep.
import time

time.sleep(0.1)
unique_key = f"{uuid.uuid4().hex[:8]}"
uploaded_nodes.append(client.create_container(unique_key))
@@ -570,3 +579,58 @@ def collect(update):
)
expected_combined = pyarrow.concat_tables([table1, table2])
assert streaming_combined == expected_combined


def test_subscription_auto_reconnect_on_network_failure(
tiled_websocket_context, stamina_active, monkeypatch
):
"""Test that subscription automatically reconnects after network failure."""
context = tiled_websocket_context
client = from_context(context)

# Create streaming array node
arr = np.arange(10)
streaming_node = client.write_array(arr, key="test_reconnect")

# Track received updates
received = []

def callback(update):
received.append(update)

subscription = streaming_node.subscribe()
subscription.new_data.add_callback(callback)

with subscription.start_in_thread():
# Send first 3 updates
for i in range(1, 4):
streaming_node.write(np.arange(10) + i)
time.sleep(0.1)

# Simulate network failure once, then restore normal behavior
original_recv = subscription._websocket.recv

class FailOnce:
"""Fails on first call, then delegates to original."""

def __init__(self):
self.call_count = 0

def __call__(self, timeout=None):
self.call_count += 1
if self.call_count == 1:
raise websockets.exceptions.ConnectionClosedError(None, None)
return original_recv(timeout)

monkeypatch.setattr(subscription._websocket, "recv", FailOnce())

# Send more updates after simulated disconnect
for i in range(4, 7):
streaming_node.write(np.arange(10) + i)
time.sleep(0.1)

# Give time for reconnection and receiving all updates
time.sleep(2)

# Verify we received at least the 6 updates written (3 before the simulated disconnect + 3 after)
assert len(received) >= 6, f"Expected at least 6 updates, got {len(received)}"
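
Not part of the diff: a minimal, standalone sketch of the stamina retry pattern that the new Subscription._receive loop in tiled/client/stream.py relies on. flaky_recv and its call counter are invented for illustration (they mirror the FailOnce helper in the test above), and attempts/timeout are placeholder values rather than tiled's TILED_RETRY_* defaults.

import stamina
import websockets.exceptions

calls = {"count": 0}

def flaky_recv():
    # Hypothetical stand-in for a websocket recv(): fail once, then succeed.
    calls["count"] += 1
    if calls["count"] == 1:
        raise websockets.exceptions.ConnectionClosedError(None, None)
    return b"payload"

for attempt in stamina.retry_context(
    on=(websockets.exceptions.ConnectionClosedError, OSError, TimeoutError),
    attempts=3,
    timeout=10,
):
    with attempt:
        if attempt.num > 1:
            # In Subscription._receive, this is where _reconnect(attempt.num)
            # re-opens the websocket before recv() is retried.
            pass
        data = flaky_recv()

assert data == b"payload"  # first attempt raised, second succeeded

A matching exception raised inside the `with attempt:` block is swallowed, stamina backs off, and the loop yields another attempt; once the block exits cleanly the iteration ends, which is why the production code only falls through after a successful recv (or a timeout it converts to a sentinel).
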
66 changes: 59 additions & 7 deletions tiled/client/stream.py
@@ -15,6 +15,7 @@
import anyio
import httpx
import msgpack
import stamina
import websockets.exceptions
from pydantic import ConfigDict
from websockets.sync.client import connect
@@ -36,7 +37,14 @@
)
from ..structures.core import STRUCTURE_TYPES, StructureFamily
from .context import Context
from .utils import client_for_item, handle_error, normalize_specs, retry_context
from .utils import (
TILED_RETRY_ATTEMPTS,
TILED_RETRY_TIMEOUT,
client_for_item,
handle_error,
normalize_specs,
retry_context,
)

T = TypeVar("T")
Callback = Callable[[T], None]
@@ -48,7 +56,6 @@

logger = logging.getLogger(__name__)


__all__ = ["Subscription"]


@@ -259,6 +266,7 @@ def __init__(
self._disconnect_lock = threading.Lock()
self._disconnect_event = threading.Event()
self._thread = None
self._last_received_sequence = None # Track last sequence for reconnection
if getattr(self.context.http_client, "app", None):
self._websocket = _TestClientWebsocketWrapper(
context.http_client, self._uri
@@ -287,10 +295,19 @@ def context(self) -> Context:
def segments(self) -> List[str]:
return self._segments

def _websocket_retry_context(self):
return stamina.retry_context(
on=(
websockets.exceptions.ConnectionClosedError,
OSError,
TimeoutError,
),
attempts=TILED_RETRY_ATTEMPTS,
timeout=TILED_RETRY_TIMEOUT,
)

def _connect(self, start: Optional[int] = None) -> None:
"Connect to websocket"
if self._disconnect_event.is_set():
raise RuntimeError("Cannot be restarted once stopped.")
needs_api_key = self.context.server_info.authentication.providers
if needs_api_key:
# Request a short-lived API key to use for authenticating the WS connection.
@@ -311,17 +328,50 @@ def _connect(self, start: Optional[int] = None) -> None:
# necessary.
self.context.revoke_api_key(key_info["first_eight"])

def _reconnect(self, attempt_num: int) -> None:
"""Reconnect to websocket after connection failure."""
logger.debug(f"Reconnecting after connection failure (attempt {attempt_num})")

try:
self._websocket.close()
except Exception:
pass

self._schema = None # Server will resend schema

if self._last_received_sequence is not None:
logger.debug(f"Resuming from sequence {self._last_received_sequence + 1}")
self._connect(start=self._last_received_sequence + 1)
else:
logger.debug("Reconnecting from start (no sequence yet)")
self._connect()

def _receive(self) -> None:
"Blocking loop that receives and processes updates"
while not self._disconnect_event.is_set():
try:
data = self._websocket.recv(timeout=RECEIVE_TIMEOUT)
except (TimeoutError, anyio.EndOfStream):
data = None
for attempt in self._websocket_retry_context():
with attempt:
if attempt.num > 1:
self._reconnect(attempt.num)

logger.debug(f"Receive attempt {attempt.num}")
try:
data = self._websocket.recv(timeout=RECEIVE_TIMEOUT)
except (TimeoutError, anyio.EndOfStream):
data = "timeout"
break

break

if data == "timeout":
continue

if data is None:
self.stream_closed.process(self)
self._disconnect()
return

try:
if self._schema is None:
self._schema = parse_schema(data)
Expand All @@ -333,6 +383,8 @@ def _receive(self) -> None:
"A websocket message will be ignored because it could not be parsed."
)
continue

self._last_received_sequence = update.sequence
self.process(update)

@abc.abstractmethod
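
Not part of the diff: a condensed consumer-side sketch adapted from the new test above, to show that callers keep the existing subscribe / add_callback / start_in_thread flow; reconnection, schema re-negotiation, and resuming from _last_received_sequence + 1 all happen inside _receive/_reconnect. Here `context` is assumed to come from a running Tiled server (the tests get it from the tiled_websocket_context fixture), and the key name is arbitrary.

import numpy as np
from tiled.client import from_context

client = from_context(context)  # `context` provided by a running server/fixture
node = client.write_array(np.arange(10), key="reconnect_demo")

received = []
subscription = node.subscribe()
subscription.new_data.add_callback(received.append)

with subscription.start_in_thread():
    for i in range(1, 4):
        # If the websocket drops between writes, _receive() retries via
        # stamina, reconnects, and resumes from the last sequence it saw,
        # so callbacks keep firing without any action from the caller.
        node.write(np.arange(10) + i)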