diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index f8d154f4d664..7351f53ef2a1 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -22,6 +22,7 @@ from itertools import count from math import ceil from time import perf_counter +from tqdm.contrib.logging import logging_redirect_tqdm from typing import Optional import torch @@ -813,6 +814,7 @@ def is_running(self) -> bool: """Check if the background generation thread is running.""" return self._generation_thread is not None and self._generation_thread.is_alive() + # NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition def stop(self, block: bool = True, timeout: Optional[float] = None) -> None: """Signal the background thread to stop. @@ -1063,14 +1065,33 @@ class ContinuousMixin: """Mixin class for models to add continuous batching capabilities.""" @contextmanager - def continuous_batching_context_manager(self, **kwargs) -> Generator[ContinuousBatchingManager]: - manager = self.init_continuous_batching(**kwargs) + def continuous_batching_context_manager( + self, + generation_config: GenerationConfig | None = None, + manual_eviction: bool = False, + max_queue_size: int = 0, + num_q_cuda_graphs: int = 0, + num_kv_cuda_graphs: int = 0, + allow_prefix_sharing: bool = True, + block: bool = True, + timeout: Optional[float] = None, + ) -> Generator[ContinuousBatchingManager]: + manager = self.init_continuous_batching( + generation_config, + manual_eviction, + max_queue_size, + num_q_cuda_graphs, + num_kv_cuda_graphs, + allow_prefix_sharing, + ) manager.start() try: yield manager finally: - manager.stop(block=True) + logger.debug("Continuous batching loop finished") # a dummy log needed for the logs of stop to show. Won't show + manager.stop(block=block, timeout=timeout) + # NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition def init_continuous_batching( self, generation_config: Optional[GenerationConfig] = None, @@ -1147,42 +1168,40 @@ def generate_batch( progress_bar = False # Initialize manager with the batch inputs - manager = self.init_continuous_batching( - generation_config=generation_config, - num_q_cuda_graphs=num_q_cuda_graphs, - num_kv_cuda_graphs=num_kv_cuda_graphs, - ) - manager.start() results = {} num_requests = len(inputs) - try: - from tqdm.contrib.logging import logging_redirect_tqdm - - with logging_redirect_tqdm([logger]): - with tqdm( - total=num_requests, - disable=(not progress_bar), - desc=f"Solving {num_requests} requests", - unit="request", - ) as pbar: - manager.add_requests(inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens")) - finished_count = 0 - while finished_count < num_requests: - result = manager.get_result(timeout=1) - if result: - req_id = result.request_id - if result.is_finished(): - results[req_id] = result - finished_count += 1 - pbar.update(1) - else: - if not manager.is_running(): - logger.error("Generation thread terminated unexpectedly.") - break + with ( + self.continuous_batching_context_manager( + generation_config=generation_config, + num_q_cuda_graphs=num_q_cuda_graphs, + num_kv_cuda_graphs=num_kv_cuda_graphs, + block=True, + timeout=5, + ) as manager, + logging_redirect_tqdm([logger]), + tqdm( + total=num_requests, + disable=(not progress_bar), + desc=f"Solving {num_requests} requests", + unit="request", + ) as pbar, + ): + try: + manager.add_requests(inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens")) + finished_count = 0 + while finished_count < num_requests: + result = manager.get_result(timeout=1) + if result: + req_id = result.request_id + if result.is_finished(): + results[req_id] = result + finished_count += 1 + pbar.update(1) + else: + if not manager.is_running(): + logger.error("Generation thread terminated unexpectedly.") + break - except Exception as e: - logger.error(f"Error during batch generation: {e}", exc_info=True) - finally: - logger.debug("Generate batch is finished.") # a dummy log needed for the logs of stop to show. Won't show. - manager.stop(block=True, timeout=5.0) + except Exception as e: + logger.error(f"Error during batch generation: {e}", exc_info=True) return results