From 57da056e7a0f9f9ebc0025f66982e1fa5ba78ab7 Mon Sep 17 00:00:00 2001 From: Ali Abid Date: Fri, 2 May 2025 15:19:14 -0700 Subject: [PATCH 1/2] changes --- demo/master_worker/run.ipynb | 1 + demo/master_worker/run.py | 21 +++ gradio/blocks.py | 43 +++++- gradio/data_classes.py | 6 + gradio/queueing.py | 260 ++++++++++++++++++++++++++++++----- gradio/routes.py | 77 +++++++++++ 6 files changed, 367 insertions(+), 41 deletions(-) create mode 100644 demo/master_worker/run.ipynb create mode 100644 demo/master_worker/run.py diff --git a/demo/master_worker/run.ipynb b/demo/master_worker/run.ipynb new file mode 100644 index 0000000000..7f311b7129 --- /dev/null +++ b/demo/master_worker/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: master_worker"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import sys\n", "import time\n", "\n", "def greet(name, prog = gr.Progress()):\n", " print(\"processing\", name)\n", " prog(0, desc=\"Starting...\")\n", " time.sleep(2)\n", " prog(0.5, desc=\"Halfway there...\")\n", " time.sleep(2)\n", " return \"Hello \" + name + \"!\"\n", "\n", "with gr.Blocks() as demo:\n", " name = gr.Textbox(label=\"Name\")\n", " output = gr.Textbox(label=\"Output Box\")\n", " name.submit(fn=greet, inputs=name, outputs=output, api_name=\"greet\") \n", "\n", "role = \"worker\" if \"-w\" in sys.argv else \"hybrid\"\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch(role=role, master_url=\"http://localhost:7860/\", app_key=\"test123\")\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/master_worker/run.py b/demo/master_worker/run.py new file mode 100644 index 0000000000..c9014e6345 --- /dev/null +++ b/demo/master_worker/run.py @@ -0,0 +1,21 @@ +import gradio as gr +import sys +import time + +def greet(name, prog = gr.Progress()): + print("processing", name) + prog(0, desc="Starting...") + time.sleep(2) + prog(0.5, desc="Halfway there...") + time.sleep(2) + return "Hello " + name + "!" + +with gr.Blocks() as demo: + name = gr.Textbox(label="Name") + output = gr.Textbox(label="Output Box") + name.submit(fn=greet, inputs=name, outputs=output, api_name="greet") + +role = "worker" if "-w" in sys.argv else "hybrid" + +if __name__ == "__main__": + demo.launch(role=role, master_url="http://localhost:7860/", app_key="test123") diff --git a/gradio/blocks.py b/gradio/blocks.py index da4d111934..2a2fde289a 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -502,10 +502,10 @@ def __init__( inputs_as_dict: bool, targets: list[tuple[int | None, str]], _id: int, + concurrency_id: str, + concurrency_limit: int | None | Literal["default"] = "default", batch: bool = False, max_batch_size: int = 4, - concurrency_limit: int | None | Literal["default"] = "default", - concurrency_id: str | None = None, tracks_progress: bool = False, api_name: str | Literal[False] = False, js: str | Literal[True] | None = None, @@ -538,7 +538,7 @@ def __init__( self.postprocess = postprocess self.tracks_progress = tracks_progress self.concurrency_limit: int | None | Literal["default"] = concurrency_limit - self.concurrency_id = concurrency_id or str(id(fn)) + self.concurrency_id = concurrency_id self.batch = batch self.max_batch_size = max_batch_size self.total_runtime = 0 @@ -865,6 +865,19 @@ def set_event_trigger( "Cannot create event: events with js=True cannot have inputs." ) + if concurrency_id is None: + concurrency_hash = str(id(fn)) if fn is not None else None + if concurrency_hash in self.root_block.concurrency_hash_to_id: + concurrency_id = self.root_block.concurrency_hash_to_id[ + concurrency_hash + ] + else: + concurrency_id = str(self.root_block.running_concurrency_id) + self.root_block.concurrency_hash_to_id[concurrency_hash] = ( + concurrency_id + ) + self.root_block.running_concurrency_id += 1 + block_fn = BlockFunction( fn, inputs, @@ -1200,6 +1213,8 @@ def __init__( self.blocked_paths = [] self.root_path = os.environ.get("GRADIO_ROOT_PATH", "") self.proxy_urls = set() + self.concurrency_hash_to_id = {} + self.running_concurrency_id = 0 self.pages: list[tuple[str, str]] = [("", "Home")] self.current_page = "" @@ -2432,6 +2447,9 @@ def launch( node_port: int | None = None, ssr_mode: bool | None = None, pwa: bool | None = None, + app_key: str | None = None, + role: Literal["hybrid", "master", "worker"] | None = None, + master_url: str | None = None, _frontend: bool = True, ) -> tuple[App, str, str]: """ @@ -2472,6 +2490,9 @@ def launch( strict_cors: If True, prevents external domains from making requests to a Gradio server running on localhost. If False, allows requests to localhost that originate from localhost but also, crucially, from "null". This parameter should normally be True to prevent CSRF attacks but may need to be False when embedding a *locally-running Gradio app* using web components. ssr_mode: If True, the Gradio app will be rendered using server-side rendering mode, which is typically more performant and provides better SEO, but this requires Node 20+ to be installed on the system. If False, the app will be rendered using client-side rendering mode. If None, will use GRADIO_SSR_MODE environment variable or default to False. pwa: If True, the Gradio app will be set up as an installable PWA (Progressive Web App). If set to None (default behavior), then the PWA feature will be enabled if this Gradio app is launched on Spaces, but not otherwise. + app_key: Used for communication in master/worker setups - must be the same across master and all workers. If not provided, will use the GRADIO_APP_KEY environment variable. + role: Role in master/worker setup. "hybrid" (default) means this app will both receive and process tasks, and other workers can attach to this app to process tasks. "master" means this app will only receive tasks. "worker" means this app will only process tasks and not receive them, and this requires master_url to be set. Will load from GRADIO_ROLE environment variable if not provided. + master_url: The URL of the master app in a master/worker setup. This is required if role is set to "worker", otherwise this is ignored. Will load from GRADIO_MASTER_URL environment variable if not provided. Returns: app: FastAPI app object that is running the demo local_url: Locally accessible link to the demo @@ -2531,6 +2552,22 @@ def reverse(text): self.favicon_path = favicon_path self.ssl_verify = ssl_verify self.state_session_capacity = state_session_capacity + self.app_key = app_key or os.environ.get("GRADIO_APP_KEY", None) + role = role or os.environ.get("GRADIO_ROLE", "hybrid") + if role in ("master", "worker") and self.app_key is None: + raise ValueError( + "You must provide a secret app_key if you are launching a master/worker setup. This can be set via `.launch(app_key='...')` or through the GRADIO_APP_KEY environment variable. Masters and workers must all share the same app key." + ) + if role == "worker" and master_url is None: + raise ValueError( + "You must provide a master_url if you are launching a worker." + ) + if role == "master" and any(fn.renderable for fn in self.fns.values()): + raise ValueError( + "You cannot use `role='master'` with renderable components, use role='hybrid' instead. This machine will handle all renders." + ) + self._queue.role = role + self._queue.master_url = master_url.strip("/") if root_path is None: self.root_path = os.environ.get("GRADIO_ROOT_PATH", "") else: diff --git a/gradio/data_classes.py b/gradio/data_classes.py index 7fac5a7ab8..ddadecaa77 100644 --- a/gradio/data_classes.py +++ b/gradio/data_classes.py @@ -428,3 +428,9 @@ class ImageData(GradioModel): class Base64ImageData(GradioModel): url: str = Field(description="base64 encoded image") + + +class EventAnalytics(BaseModel): + event_id: str + key: str + value: str | float | int | None diff --git a/gradio/queueing.py b/gradio/queueing.py index f2585c2bf4..786bfeff03 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -2,6 +2,7 @@ import asyncio import copy +import json import os import random import time @@ -12,6 +13,7 @@ from typing import TYPE_CHECKING, Literal, cast import fastapi +import requests from gradio import route_utils, routes, wasm_utils from gradio.data_classes import ( @@ -19,6 +21,7 @@ ) from gradio.exceptions import Error from gradio.helpers import TrackedIterable +from gradio.route_utils import API_PREFIX from gradio.server_messages import ( EstimationMessage, EventMessage, @@ -64,6 +67,52 @@ def __init__( self.run_time: float = 0 self.signal = asyncio.Event() + def json(self): + return { + "id": self._id, + "session_hash": self.session_hash, + "fn_index": self.fn._id, + "data": self.data.model_dump_json(exclude=["request"]) + if self.data + else None, + "username": self.username, + "concurrency_id": self.concurrency_id, + "request_scope": { + "type": "http", + "headers": self.request.headers.items() + if self.request.headers + else None, + "method": self.request.method, + "path": self.request.url.path, + "query_string": dict(self.request.query_params) + if self.request.query_params + else None, + "server": self.request.scope.get("server"), + "path_params": self.request.path_params, + "cookies": self.request.cookies, + }, + } + + @classmethod + def from_json(cls, blocks: Blocks, json_data: dict) -> Event: + json_data["request_scope"]["headers"] = [ + (key.encode("utf-8"), value.encode("utf-8")) + for key, value in json_data["request_scope"]["headers"] + ] + request = fastapi.Request( + scope=json_data["request_scope"], + ) + event = cls( + session_hash=json_data["session_hash"], + fn=blocks.fns[json_data["fn_index"]], + request=request, + username=json_data["username"], + ) + event._id = json_data["id"] + event.concurrency_id = json_data["concurrency_id"] + event.data = PredictBodyInternal.model_validate_json(json_data["data"]) + return event + @property def streaming(self): return self.fn.connection == "stream" @@ -83,6 +132,8 @@ def __init__(self, concurrency_id: str, concurrency_limit: int | None): self.concurrency_id = concurrency_id self.concurrency_limit = concurrency_limit self.current_concurrency = 0 + self.handled_fn_ids: set[str] = set() + self.includes_render: bool = False self.start_times_per_fn: defaultdict[BlockFunction, set[float]] = defaultdict( set ) @@ -120,6 +171,7 @@ def __init__( self.stopped = False self.max_thread_count = concurrency_count self.update_intervals = update_intervals + self.active_jobs: list[None | list[Event]] = [] self.delete_lock = safe_get_lock() self.server_app = None @@ -129,6 +181,7 @@ def __init__( self.live_updates = live_updates self.sleep_when_free = 0.05 self.progress_update_sleep_when_free = 0.1 + self.check_workers_live_sleep = 1 self.max_size = max_size self.blocks = blocks self._asyncio_tasks: list[asyncio.Task] = [] @@ -136,16 +189,33 @@ def __init__( default_concurrency_limit ) self.event_analytics: dict[str, dict[str, float | str | None]] = {} + self.role: Literal["master", "worker", "hybrid"] + self.master_url: str + self.connected_to_master: bool = False + self.events_by_id_per_worker: dict[int, dict[str, Event]] = {} def start(self): - self.active_jobs = [None] * self.max_thread_count - - run_coro_in_background(self.start_processing) + if self.role in ["worker", "hybrid"]: + self.active_jobs = [None] * self.max_thread_count + run_coro_in_background(self.start_processing) + if self.role in ["master", "hybrid"] and not self.live_updates: + run_coro_in_background(self.notify_clients) + if self.role == "worker": + run_coro_in_background(self.connect_to_master) run_coro_in_background(self.start_progress_updates) - if not self.live_updates: - run_coro_in_background(self.notify_clients) + + for fn in self.blocks.fns.values(): + self.create_event_queue_for_fn(fn) def create_event_queue_for_fn(self, block_fn: BlockFunction): + if ( + block_fn.concurrency_id in self.event_queue_per_concurrency_id + and block_fn._id + in self.event_queue_per_concurrency_id[ + block_fn.concurrency_id + ].handled_fn_ids + ): + return concurrency_id = block_fn.concurrency_id concurrency_limit: int | None if block_fn.concurrency_limit == "default": @@ -165,6 +235,11 @@ def create_event_queue_for_fn(self, block_fn: BlockFunction): or concurrency_limit < existing_event_queue.concurrency_limit ): existing_event_queue.concurrency_limit = concurrency_limit + self.event_queue_per_concurrency_id[concurrency_id].handled_fn_ids.add( + block_fn._id + ) + if block_fn.renderable: + self.event_queue_per_concurrency_id[concurrency_id].includes_render = True def close(self): self.stopped = True @@ -177,8 +252,25 @@ def send_message( if not event.alive: return event_message.event_id = event._id - messages = self.pending_messages_per_session[event.session_hash] - messages.put_nowait(event_message) + if self.role == "worker": + response = requests.post( + self.master_url + API_PREFIX + "/queue/message", + params={ + "app_key": self.blocks.app_key, + "worker_id": self.blocks.app_id, + }, + json={ + "event": event.json(), + "event_message": json.loads(event_message.model_dump_json()), + }, + ) + if not response.ok: + raise Exception( + f"Failed to send message to master: {response.status_code} {response.text}" + ) + else: + messages = self.pending_messages_per_session[event.session_hash] + messages.put_nowait(event_message) def _resolve_concurrency_limit( self, default_concurrency_limit: int | None | Literal["not_set"] @@ -258,9 +350,31 @@ async def push( "session_hash": body.session_hash, } - self.broadcast_estimations(event.concurrency_id, len(event_queue.queue) - 1) + if self.role in ["master", "hybrid"]: + self.broadcast_estimations(event.concurrency_id, len(event_queue.queue) - 1) return True, event._id + def update_event_analytics(self, event_id, key: str, value: float | str | None): + if self.role == "worker": + response = requests.post( + self.master_url + API_PREFIX + "/queue/event_analytics", + params={ + "app_key": self.blocks.app_key, + "worker_id": self.blocks.app_id, + }, + json={ + "event_id": event_id, + "key": key, + "value": value, + }, + ) + if not response.ok: + raise Exception( + f"Failed to update event analytics: {response.status_code} {response.text}" + ) + elif event_id in self.event_analytics: + self.event_analytics[event_id][key] = value + def _cancel_asyncio_tasks(self): for task in self._asyncio_tasks: task.cancel() @@ -276,35 +390,96 @@ def get_active_worker_count(self) -> int: count += 1 return count - def get_events(self) -> tuple[list[Event], bool, str] | None: - concurrency_ids = list(self.event_queue_per_concurrency_id.keys()) - random.shuffle(concurrency_ids) - for concurrency_id in concurrency_ids: - event_queue = self.event_queue_per_concurrency_id[concurrency_id] - if len(event_queue.queue) and ( - event_queue.concurrency_limit is None - or event_queue.current_concurrency < event_queue.concurrency_limit - ): - first_event = event_queue.queue[0] - block_fn = first_event.fn - events = [first_event] - batch = block_fn.batch - if batch: - events += [ - event - for event in event_queue.queue[1:] - if event.fn == first_event.fn - ][: block_fn.max_batch_size - 1] + async def get_events_from_master( + self, concurrency_ids: list[str] + ) -> tuple[list[Event], bool, str] | None: + response = requests.get( + self.master_url + API_PREFIX + "/queue/events", + headers={"Content-Type": "application/json"}, + params={"app_key": self.blocks.app_key, "worker_id": self.blocks.app_id}, + data=json.dumps(concurrency_ids), + ) + if not response.ok: + raise Exception( + f"Failed to get events from master: {response.status_code} {response.text}" + ) + data = response.json() + if data: + return ( + [Event.from_json(self.blocks, event) for event in data["events"]], + data["batch"], + data["concurrency_id"], + ) + else: + return - for event in events: - event_queue.queue.remove(event) + async def get_events( + self, concurrency_ids: list[str] | None = None + ) -> tuple[list[Event], bool, str] | None: + if concurrency_ids is None: + concurrency_ids = [] + for ( + concurrency_id, + event_queue, + ) in self.event_queue_per_concurrency_id.items(): + if self.role == "worker" and event_queue.includes_render: + continue + if ( + event_queue.concurrency_limit is None + or event_queue.current_concurrency < event_queue.concurrency_limit + ): + concurrency_ids.append(concurrency_id) + if self.role == "worker": + if len(concurrency_ids) == 0: + return + return await self.get_events_from_master(concurrency_ids) + async with self.delete_lock: + random.shuffle(concurrency_ids) + for concurrency_id in concurrency_ids: + event_queue = self.event_queue_per_concurrency_id[concurrency_id] + if len(event_queue.queue): + first_event = event_queue.queue[0] + block_fn = first_event.fn + events = [first_event] + batch = block_fn.batch + if batch: + events += [ + event + for event in event_queue.queue[1:] + if event.fn == first_event.fn + ][: block_fn.max_batch_size - 1] + + for event in events: + event_queue.queue.remove(event) - return events, batch, concurrency_id + return events, batch, concurrency_id + + async def connect_to_master(self) -> None: + try: + while True: + response = requests.get( + self.master_url + API_PREFIX + "/queue/attach_worker", + params={ + "app_key": self.blocks.app_key, + "worker_id": self.blocks.app_id, + }, + stream=True, + ) + if response.ok: + break + + # for chunk in response.iter_lines(): + # print(chunk) + # pass + + except BaseException as e: + print(f"Error connecting to master: {e}") + return async def start_processing(self) -> None: try: while not self.stopped: - if len(self) == 0: + if self.role != "worker" and len(self) == 0: await asyncio.sleep(self.sleep_when_free) continue @@ -313,8 +488,7 @@ async def start_processing(self) -> None: continue # Using mutex to avoid editing a list in use - async with self.delete_lock: - event_batch = self.get_events() + event_batch = await self.get_events() if event_batch: events, batch, concurrency_id = event_batch @@ -324,7 +498,7 @@ async def start_processing(self) -> None: start_time = time.time() event_queue.start_times_per_fn[events[0].fn].add(start_time) for event in events: - self.event_analytics[event._id]["status"] = "processing" + self.update_event_analytics(event._id, "status", "processing") process_event_task = run_coro_in_background( self.process_events, events, batch, start_time ) @@ -341,6 +515,8 @@ async def start_processing(self) -> None: self.broadcast_estimations(concurrency_id) else: await asyncio.sleep(self.sleep_when_free) + except Exception: + traceback.print_exc() finally: self.stopped = True self._cancel_asyncio_tasks() @@ -442,6 +618,14 @@ async def notify_clients(self) -> None: for concurrency_id in self.event_queue_per_concurrency_id: self.broadcast_estimations(concurrency_id) + def remove_worker(self, worker_id: int) -> None: + events = self.events_by_id_per_worker[worker_id] + for event in events.values(): + self.event_queue_per_concurrency_id[event.concurrency_id].queue.append( + event + ) + del self.events_by_id_per_worker[worker_id] + def broadcast_estimations( self, concurrency_id: str, after: int | None = None ) -> None: @@ -767,7 +951,7 @@ async def process_events( ) self.process_time_per_fn[events[0].fn].add(duration) for event in events: - self.event_analytics[event._id]["process_time"] = duration + self.update_event_analytics(event._id, "process_time", duration) except Exception as e: if not isinstance(e, Error) or e.print_exception: traceback.print_exc() @@ -793,11 +977,11 @@ async def process_events( await self.reset_iterators(event._id) if event in awake_events: - self.event_analytics[event._id]["status"] = ( - "success" if success else "failed" + self.update_event_analytics( + event._id, "status", "success" if success else "failed" ) else: - self.event_analytics[event._id]["status"] = "cancelled" + self.update_event_analytics(event._id, "status", "cancelled") async def reset_iterators(self, event_id: str): # Do the same thing as the /reset route diff --git a/gradio/routes.py b/gradio/routes.py index 613e07428e..7cc3a05421 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -71,6 +71,7 @@ ComponentServerJSONBody, DataWithFiles, DeveloperPath, + EventAnalytics, PredictBody, PredictBodyInternal, ResetBody, @@ -82,6 +83,7 @@ start_node_server, ) from gradio.oauth import attach_oauth +from gradio.queueing import Event from gradio.route_utils import ( # noqa: F401 API_PREFIX, CustomCORSMiddleware, @@ -1405,6 +1407,81 @@ async def component_server( async def get_queue_status(): return app.get_blocks()._queue.get_status() + @router.get("/queue/attach_worker") + async def attach_worker(app_key: str, worker_id: int): + blocks = app.get_blocks() + if app_key != blocks.app_key: + raise HTTPException(status_code=403, detail="Invalid app key.") + + blocks._queue.events_by_id_per_worker[worker_id] = {} + + async def worker_heartbeat(): + try: + while not blocks._queue.stopped: + yield "data: heartbeat\n\n" + await asyncio.sleep(5) + except BaseException: + pass + # blocks._queue.remove_worker(worker_id) + + return StreamingResponse( + worker_heartbeat(), + media_type="text/event-stream", + ) + + @router.post( + "/queue/message", + ) + async def queue_message( + app_key: str, + worker_id: int, + event: dict[str, Any], + event_message: EventMessage, + ): + blocks = app.get_blocks() + if app_key != blocks.app_key: + raise HTTPException(status_code=403, detail="Invalid app key.") + + event_object = Event.from_json(blocks, event) + blocks._queue.send_message(event_object, event_message) + if event_message.msg == ServerMessage.process_completed: + del blocks._queue.events_by_id_per_worker[worker_id][event_object._id] + + @router.get("/queue/events") + async def get_events(app_key: str, worker_id: int, concurrency_ids: list[str]): + blocks = app.get_blocks() + if app_key != blocks.app_key: + raise HTTPException(status_code=403, detail="Invalid app key.") + + events_response = await blocks._queue.get_events(concurrency_ids) + if events_response: + events, batch, concurrency_id = events_response + blocks._queue.events_by_id_per_worker[worker_id].update( + {event._id: event for event in events} + ) + events = [event.json() for event in events] + return { + "events": events, + "batch": batch, + "concurrency_id": concurrency_id, + } + else: + return + + @router.post( + "/queue/event_analytics", + ) + async def queue_analytics( + app_key: str, event_analytics: EventAnalytics + ): + blocks = app.get_blocks() + if app_key != blocks.app_key: + raise HTTPException(status_code=403, detail="Invalid app key.") + + blocks._queue.update_event_analytics( + event_analytics.event_id, event_analytics.key, event_analytics.value + ) + @router.get("/upload_progress") def get_upload_progress(upload_id: str, request: fastapi.Request): async def sse_stream(request: fastapi.Request): From 730f198aa2f68b22c6bfe0ba5019fd9a9e9f295d Mon Sep 17 00:00:00 2001 From: Ali Abid Date: Sat, 3 May 2025 06:17:10 -0700 Subject: [PATCH 2/2] changes --- gradio/blocks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gradio/blocks.py b/gradio/blocks.py index 2a2fde289a..2922f9a644 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -2567,7 +2567,8 @@ def reverse(text): "You cannot use `role='master'` with renderable components, use role='hybrid' instead. This machine will handle all renders." ) self._queue.role = role - self._queue.master_url = master_url.strip("/") + if master_url: + self._queue.master_url = master_url.strip("/") if root_path is None: self.root_path = os.environ.get("GRADIO_ROOT_PATH", "") else: