Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 57 additions & 38 deletions src/transformers/generation/continuous_batching/continuous_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from itertools import count
from math import ceil
from time import perf_counter
from tqdm.contrib.logging import logging_redirect_tqdm
from typing import Optional

import torch
Expand Down Expand Up @@ -813,6 +814,7 @@ def is_running(self) -> bool:
"""Check if the background generation thread is running."""
return self._generation_thread is not None and self._generation_thread.is_alive()

# NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition
def stop(self, block: bool = True, timeout: Optional[float] = None) -> None:
"""Signal the background thread to stop.

Expand Down Expand Up @@ -1063,14 +1065,33 @@ class ContinuousMixin:
"""Mixin class for models to add continuous batching capabilities."""

@contextmanager
def continuous_batching_context_manager(
    self,
    generation_config: Optional[GenerationConfig] = None,
    manual_eviction: bool = False,
    max_queue_size: int = 0,
    num_q_cuda_graphs: int = 0,
    num_kv_cuda_graphs: int = 0,
    allow_prefix_sharing: bool = True,
    block: bool = True,
    timeout: Optional[float] = None,
) -> Generator[ContinuousBatchingManager]:
    """Yield a started continuous batching manager and stop it on exit.

    The manager is built with `init_continuous_batching`, started, and handed to the `with` body. When the body
    exits, normally or through an exception, the manager is stopped exactly once with the `block` and `timeout`
    values given here.

    Args:
        generation_config: Forwarded unchanged to `init_continuous_batching`.
        manual_eviction: Forwarded unchanged to `init_continuous_batching`.
        max_queue_size: Forwarded unchanged to `init_continuous_batching`.
        num_q_cuda_graphs: Forwarded unchanged to `init_continuous_batching`.
        num_kv_cuda_graphs: Forwarded unchanged to `init_continuous_batching`.
        allow_prefix_sharing: Forwarded unchanged to `init_continuous_batching`.
        block: Forwarded to `manager.stop` when the context exits.
        timeout: Forwarded to `manager.stop` when the context exits.

    Yields:
        The started `ContinuousBatchingManager`.
    """
    # NOTE: arguments are passed positionally, so their order must mirror the signature of
    # `init_continuous_batching`; keep both definitions in sync.
    manager = self.init_continuous_batching(
        generation_config,
        manual_eviction,
        max_queue_size,
        num_q_cuda_graphs,
        num_kv_cuda_graphs,
        allow_prefix_sharing,
    )
    manager.start()
    try:
        yield manager
    finally:
        # Stop once, honouring the caller's choice; an extra unconditional `stop(block=True)` before this call
        # would block without a timeout and make `block`/`timeout` meaningless.
        manager.stop(block=block, timeout=timeout)
        logger.debug("Continuous batching manager stopped")

# NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition
def init_continuous_batching(
self,
generation_config: Optional[GenerationConfig] = None,
Expand Down Expand Up @@ -1147,42 +1168,40 @@ def generate_batch(
progress_bar = False

# Initialize manager with the batch inputs
manager = self.init_continuous_batching(
generation_config=generation_config,
num_q_cuda_graphs=num_q_cuda_graphs,
num_kv_cuda_graphs=num_kv_cuda_graphs,
)
manager.start()
results = {}
num_requests = len(inputs)
try:
from tqdm.contrib.logging import logging_redirect_tqdm

with logging_redirect_tqdm([logger]):
with tqdm(
total=num_requests,
disable=(not progress_bar),
desc=f"Solving {num_requests} requests",
unit="request",
) as pbar:
manager.add_requests(inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"))
finished_count = 0
while finished_count < num_requests:
result = manager.get_result(timeout=1)
if result:
req_id = result.request_id
if result.is_finished():
results[req_id] = result
finished_count += 1
pbar.update(1)
else:
if not manager.is_running():
logger.error("Generation thread terminated unexpectedly.")
break
with (
self.continuous_batching_context_manager(
generation_config=generation_config,
num_q_cuda_graphs=num_q_cuda_graphs,
num_kv_cuda_graphs=num_kv_cuda_graphs,
block=True,
timeout=5,
) as manager,
logging_redirect_tqdm([logger]),
tqdm(
total=num_requests,
disable=(not progress_bar),
desc=f"Solving {num_requests} requests",
unit="request",
) as pbar,
):
try:
manager.add_requests(inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"))
finished_count = 0
while finished_count < num_requests:
result = manager.get_result(timeout=1)
if result:
req_id = result.request_id
if result.is_finished():
results[req_id] = result
finished_count += 1
pbar.update(1)
else:
if not manager.is_running():
logger.error("Generation thread terminated unexpectedly.")
break

except Exception as e:
logger.error(f"Error during batch generation: {e}", exc_info=True)
finally:
logger.debug("Generate batch is finished.") # a dummy log needed for the logs of stop to show. Won't show.
manager.stop(block=True, timeout=5.0)
except Exception as e:
logger.error(f"Error during batch generation: {e}", exc_info=True)
return results