Skip to content

Commit

Permalink
Merge pull request #178 from plotly/andrew/refactor
Browse files Browse the repository at this point in the history
Andrew/refactor

A lot here, some kind of spooky behavior with pytest.

In general, reflects small changes to the api:

There is now an `open()` function so you don't have to `await` the constructure, `BrowserSync` has been factored out.
  • Loading branch information
ayjayt authored Jan 27, 2025
2 parents 6b44fde + 2bef2d9 commit f02b05b
Show file tree
Hide file tree
Showing 116 changed files with 30,781 additions and 1,945 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish_testpypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
# don't modify sync file! messes up version!
- run: uv sync --all-extras --frozen # does order matter?
- run: uv build
- run: uv run --no-sync choreo_get_browser -v --i ${{ matrix.chrome_v }}
- run: uv run --no-sync choreo_get_chrome -v --i ${{ matrix.chrome_v }}
- name: Reinstall from wheel
run: >
uv pip install dist/choreographer-$(uv
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Install choreographer
run: uv sync --all-extras
- name: Install google-chrome-for-testing
run: uv run choreo_get_browser
run: uv run choreo_get_chrome
- name: Diagnose
run: uv run choreo_diagnose --no-run
timeout-minutes: 1
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
%YAML 1.2
---
exclude: "site/.*"
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
Expand Down
31 changes: 0 additions & 31 deletions choreographer/DIR_INDEX.txt

This file was deleted.

35 changes: 18 additions & 17 deletions choreographer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
"""choreographer is a browser controller for python."""
"""
choreographer is a browser controller for python.
import choreographer._devtools_protocol_layer as protocol
choreographer is natively async, so while there are two main entrypoints:
classes `Browser` and `BrowserSync`, the sync version is very limited, functioning
as a building block for more featureful implementations.
from ._browser import Browser, BrowserClosedError, browser_which, get_browser_path
from ._cli_utils import get_browser, get_browser_sync
from ._pipe import BlockWarning, PipeClosedError
from ._system_utils._tempfile import TempDirectory, TempDirWarning
from ._tab import Tab
See the main README for a quickstart.
"""

from .browser_async import (
Browser,
Tab,
)
from .browser_sync import (
BrowserSync,
TabSync,
)

__all__ = [
"BlockWarning",
"Browser",
"BrowserClosedError",
"PipeClosedError",
"BrowserSync",
"Tab",
"TempDirWarning",
"TempDirectory",
"browser_which",
"get_browser",
"get_browser_path",
"get_browser_sync",
"protocol",
"TabSync",
]
10 changes: 10 additions & 0 deletions choreographer/_brokers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from ._async import Broker
from ._sync import BrokerSync

__all__ = [
"Broker",
"BrokerSync",
]

# note: should brokers be responsible for closing browser on bad pipe?
# note: should the broker be the watchdog, in that case?
268 changes: 268 additions & 0 deletions choreographer/_brokers/_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
from __future__ import annotations

import asyncio
import warnings
from typing import TYPE_CHECKING

import logistro

from choreographer import channels, protocol

# afrom choreographer.channels import ChannelClosedError

if TYPE_CHECKING:
from collections.abc import MutableMapping
from typing import Any

from choreographer.browser_async import Browser
from choreographer.channels._interface_type import ChannelInterface
from choreographer.protocol.devtools_async import Session, Target


_logger = logistro.getLogger(__name__)


class UnhandledMessageWarning(UserWarning):
pass


class Broker:
"""Broker is a middleware implementation for asynchronous implementations."""

_browser: Browser
"""Browser is a reference to the Browser object this broker is brokering for."""
_channel: ChannelInterface
"""
Channel will be the ChannelInterface implementation (pipe or websocket)
that the broker communicates on.
"""
futures: MutableMapping[protocol.MessageKey, asyncio.Future[Any]]
"""A mapping of all the futures for all sent commands."""

_subscriptions_futures: MutableMapping[
str,
MutableMapping[
str,
list[asyncio.Future[Any]],
],
]
"""A mapping of session id: subscription: list[futures]"""

def __init__(self, browser: Browser, channel: ChannelInterface) -> None:
"""
Construct a broker for a synchronous arragenment w/ both ends.
Args:
browser: The sync browser implementation.
channel: The channel the browser uses to talk on.
"""
self._browser = browser
self._channel = channel
self._background_tasks: set[asyncio.Task[Any]] = set()
# if its a task you dont want canceled at close (like the close task)
self._background_tasks_cancellable: set[asyncio.Task[Any]] = set()
# if its a user task, can cancel
self._current_read_task: asyncio.Task[Any] | None = None
self.futures = {}
self._subscriptions_futures = {}

def new_subscription_future(
self,
session_id: str,
subscription: str,
) -> asyncio.Future[Any]:
if session_id not in self._subscriptions_futures:
self._subscriptions_futures[session_id] = {}
if subscription not in self._subscriptions_futures[session_id]:
self._subscriptions_futures[session_id][subscription] = []
future = asyncio.get_running_loop().create_future()
self._subscriptions_futures[session_id][subscription].append(future)
return future

def clean(self) -> None:
_logger.debug("Cancelling message futures")
for future in self.futures.values():
if not future.done():
_logger.debug(f"Cancelling {future}")
future.cancel()
_logger.debug("Cancelling read task")
if self._current_read_task and not self._current_read_task.done():
_logger.debug(f"Cancelling read: {self._current_read_task}")
self._current_read_task.cancel()
_logger.debug("Cancelling subscription-futures")
for session in self._subscriptions_futures.values():
for query in session.values():
for future in query:
if not future.done():
_logger.debug(f"Cancelling {future}")
future.cancel()
_logger.debug("Cancelling background tasks")
for task in self._background_tasks_cancellable:
if not task.done():
_logger.debug(f"Cancelling {task}")
task.cancel()

def run_read_loop(self) -> None: # noqa: C901, PLR0915 complexity
def check_error(result: asyncio.Future[Any]) -> None:
try:
e = result.exception()
if e:
self._background_tasks.add(
asyncio.create_task(self._browser.close()),
)
if not isinstance(e, asyncio.CancelledError):
_logger.error(f"Error in run_read_loop: {e!s}")
raise e
except asyncio.CancelledError:
self._background_tasks.add(asyncio.create_task(self._browser.close()))

async def read_loop() -> None: # noqa: PLR0912, C901
try:
responses = await asyncio.to_thread(
self._channel.read_jsons,
blocking=True,
)
for response in responses:
error = protocol.get_error_from_result(response)
key = protocol.calculate_message_key(response)
if not key and error:
raise protocol.DevtoolsProtocolError(response)
self._check_for_closed_session(response)
# surrounding lines overlap in idea
if protocol.is_event(response):
event_session_id = response.get(
"sessionId",
"",
)
x = self._get_target_session_by_session_id(
event_session_id,
)
if not x:
continue
_, event_session = x
if not event_session:
_logger.error("Found an event that returned no session.")
continue

session_futures = self._subscriptions_futures.get(
event_session_id,
)
if session_futures:
for query in session_futures:
match = (
query.endswith("*")
and response["method"].startswith(query[:-1])
) or (response["method"] == query)
if match:
for future in session_futures[query]:
if not future.done():
future.set_result(response)
session_futures[query] = []

for query in list(event_session.subscriptions):
match = (
query.endswith("*")
and response["method"].startswith(query[:-1])
) or (response["method"] == query)
_logger.debug2(
f"Checking subscription key: {query} "
f"against event method {response['method']}",
)
if match:
t: asyncio.Task[Any] = asyncio.create_task(
event_session.subscriptions[query][0](response),
)
self._background_tasks_cancellable.add(t)
if not event_session.subscriptions[query][1]:
event_session.unsubscribe(query)

elif key:
if key in self.futures:
_logger.debug(f"run_read_loop() found future for key {key}")
future = self.futures.pop(key)
elif "error" in response:
raise protocol.DevtoolsProtocolError(response)
else:
raise RuntimeError(f"Couldn't find a future for key: {key}")
future.set_result(response)
else:
warnings.warn( # noqa: B028
f"Unhandled message type:{response!s}",
UnhandledMessageWarning,
)
except channels.ChannelClosedError:
_logger.debug("PipeClosedError caught")
self._background_tasks.add(asyncio.create_task(self._browser.close()))
return
read_task = asyncio.create_task(read_loop())
read_task.add_done_callback(check_error)
self._current_read_task = read_task

read_task = asyncio.create_task(read_loop())
read_task.add_done_callback(check_error)
self._current_read_task = read_task

async def write_json(
self,
obj: protocol.BrowserCommand,
) -> protocol.BrowserResponse:
_logger.debug2(f"In broker.write_json for {obj}")
protocol.verify_params(obj)
key = protocol.calculate_message_key(obj)
if not key:
raise RuntimeError(
"Message strangely formatted and "
"choreographer couldn't figure it out why.",
)
loop = asyncio.get_running_loop()
future: asyncio.Future[protocol.BrowserResponse] = loop.create_future()
self.futures[key] = future
_logger.debug(f"Created future: {key} {future}")
try:
await asyncio.to_thread(self._channel.write_json, obj)
except BaseException as e: # noqa: BLE001
future.set_exception(e)
del self.futures[key]
_logger.debug(f"Future for {key} deleted.")
return await future

def _get_target_session_by_session_id(
self,
session_id: str,
) -> tuple[Target, Session] | None:
if session_id == "":
return (self._browser, self._browser.sessions[session_id])
for tab in self._browser.tabs.values():
if session_id in tab.sessions:
return (tab, tab.sessions[session_id])
if session_id in self._browser.sessions:
return (self._browser, self._browser.sessions[session_id])
return None

def _check_for_closed_session(self, response: protocol.BrowserResponse) -> bool:
if "method" in response and response["method"] == "Target.detachedFromTarget":
session_closed = response["params"].get(
"sessionId",
"",
)
if session_closed == "":
return True

x = self._get_target_session_by_session_id(session_closed)
if x:
target_closed, _ = x
else:
return False

if target_closed:
target_closed._remove_session(session_closed) # noqa: SLF001
_logger.debug(
"Using intern subscription key: "
"'Target.detachedFromTarget'. "
f"Session {session_closed} was closed.",
)
return True
return False
else:
return False
Loading

0 comments on commit f02b05b

Please sign in to comment.