diff --git a/src/litserve/api.py b/src/litserve/api.py
index e2dd9eea..a1f0b8fa 100644
--- a/src/litserve/api.py
+++ b/src/litserve/api.py
@@ -56,10 +56,9 @@ def batch(self, inputs):
         return inputs
 
-    @abstractmethod
     def predict(self, x, **kwargs):
         """Run the model on the input and return or yield the output."""
-        pass
+        raise NotImplementedError("predict is not implemented")
 
     def _unbatch_no_stream(self, output):
         if isinstance(output, str):
@@ -121,6 +120,7 @@ def pre_setup(self, max_batch_size: int, spec: Optional[LitSpec]):
 
         if spec:
             self._spec = spec
+            spec.pre_setup(self)
 
     def set_logger_queue(self, queue: Queue):
         """Set the queue for logging events."""
diff --git a/src/litserve/loops.py b/src/litserve/loops.py
index 46a309c8..521c6783 100644
--- a/src/litserve/loops.py
+++ b/src/litserve/loops.py
@@ -700,18 +700,62 @@ class Output:
 
 class ContinuousBatchingLoop(LitLoop):
     def __init__(self, max_sequence_length: int = 2048):
+        """Runs the continuous batching loop. This loop handles adding new requests, processing them in batches, and
+        managing the state of active sequences.
+
+        The loop requires the following methods to be implemented in the LitAPI:
+        - setup: sets up the model on the device
+        - decode_request: decodes the client request into a format that can be processed by the model
+        - step: generates a new token for each sequence
+        - encode_response: encodes the response into a format that can be sent to the client
+        - has_finished: checks if the sequence has finished generating
+
+        Args:
+            max_sequence_length (int): The maximum sequence length allowed for any active sequence.
+
+        """
         super().__init__()
-        self.active_sequences: Dict[str, Dict] = {}  # uid -> {input, current_length, generated_tokens}
+        self.active_sequences: Dict[str, Dict] = {}  # uid -> {input, current_length, generated_sequence}
         self.max_sequence_length = max_sequence_length
         self.response_queue_ids: Dict[str, int] = {}  # uid -> response_queue_id
 
+    def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
+        if not lit_api.stream:
+            raise ValueError(
+                "Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
+            )
+
+        if not hasattr(lit_api, "step") and not hasattr(lit_api, "predict"):
+            raise ValueError("""Using the default step method with Continuous batching loop requires the lit_api to
+have a `predict` method which accepts the decoded request inputs and a list of generated sequences.
+Please implement the predict method in the lit_api.
+
+    class ExampleAPI(LitAPI):
+        ...
+        def predict(self, inputs, generated_sequence):
+            # implement predict logic
+            # return list of new tokens
+            ...
+    """)
+
+        if not hasattr(lit_api, "step") and not hasattr(lit_api, "has_finished"):
+            raise ValueError("""Using the default step method with Continuous batching loop
+requires the lit_api to have a has_finished method. Please implement the has_finished method in the lit_api.
+
+    class ExampleAPI(LitAPI):
+        ...
+        def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:
+            # implement has_finished logic
+            return False
+    """)
+
     def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
-        """Add a new sequence to active sequences."""
+        """Add a new sequence to active sequences and perform any action before prediction such as filling the cache."""
         decoded_request = lit_api.decode_request(request)
-        self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_tokens": []}
+        self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_sequence": []}
 
     def mark_completed(self, uid: str) -> None:
-        """Mark a sequence as completed."""
+        """Mark a request as completed and remove it from the tracked state."""
         logger.info(f"Marking sequence {uid} as completed")
         del self.active_sequences[uid]
         del self.response_queue_ids[uid]
@@ -725,36 +769,39 @@ def has_capacity(self, lit_api: LitAPI) -> bool:
         )
         return capacity
 
-    def step(
-        self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec]
-    ) -> List[Tuple[str, Tuple[Any, LitAPIStatus]]]:
+    def step(self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> List[Output]:
         """Process one token generation step for all active sequences."""
+        if hasattr(lit_api, "step"):
+            return lit_api.step(prev_outputs)
+
         if not self.active_sequences:
             return []
 
         # Batch forward pass for all active sequences
         inputs = [seq["input"] for seq in self.active_sequences.values()]
-        generated = [seq["generated_tokens"] for seq in self.active_sequences.values()]
+        generated = [seq["generated_sequence"] for seq in self.active_sequences.values()]
 
         try:
             # Assume lit_api.predict handles batched token generation
-            new_tokens = lit_api.predict(inputs, generated)
+            new_tokens: List[Any] = lit_api.predict(inputs, generated)
 
-            responses = []
+            responses: List[Output] = []
 
             # Process each sequence's new token
             for uid, token in zip(self.active_sequences.keys(), new_tokens):
                 seq = self.active_sequences[uid]
-                seq["generated_tokens"].append(token)
+                seq["generated_sequence"].append(token)
                 seq["current_length"] += 1
 
+                step_output = Output(uid, token, LitAPIStatus.OK)
+                responses.append(step_output)
+
                 # Check completion conditions
-                is_finished = lit_api.is_finished(uid, token, self.max_sequence_length)
+                is_finished = lit_api.has_finished(uid, token, self.max_sequence_length)
 
                 if is_finished:
                     # Encode final response for completed sequence
-                    response = lit_api.encode_response(seq["generated_tokens"])
-                    step_output = Output(uid, response, LitAPIStatus.FINISH_STREAMING)
+                    step_output = Output(uid, "", LitAPIStatus.FINISH_STREAMING)
                     responses.append(step_output)
 
             return responses
@@ -815,11 +862,6 @@ def run(
         workers_setup_status: Dict[int, str],
         callback_runner: CallbackRunner,
     ):
-        if not lit_api.stream:
-            raise ValueError(
-                "Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
-            )
-
         """Main loop that processes batches of requests."""
         pending_requests = self.prefill(
             [],
@@ -846,7 +888,7 @@ def run(
             for step_output in responses:
                 logger.debug(f"Processing response: {step_output}")
                 status = step_output.status
-                response_data = step_output.output
+                response_data = lit_api.encode_response(step_output.output)
                 uid = step_output.uid
                 response_queue_id = self.response_queue_ids[uid]
 
@@ -870,7 +912,7 @@
                 )
 
         except Exception as e:
-            logger.exception("Error in continuous batching loop")
+            logger.exception(f"Error in continuous batching loop: {e}")
             # Handle any errors by sending error responses for all tracked requests
             for uid, response_queue_id in self.response_queue_ids.items():
                 self.put_error_response(response_queues, response_queue_id, uid, e)
diff --git a/src/litserve/specs/base.py b/src/litserve/specs/base.py
index 5691920d..6f5afc7b 100644
--- a/src/litserve/specs/base.py
+++ b/src/litserve/specs/base.py
@@ -15,7 +15,7 @@
 from typing import TYPE_CHECKING, Callable, List
 
 if TYPE_CHECKING:
-    from litserve import LitServer
+    from litserve import LitAPI, LitServer
 
 
 class LitSpec:
@@ -26,6 +26,9 @@ def __init__(self):
 
         self._server: LitServer = None
 
+    def pre_setup(self, lit_api: "LitAPI"):
+        pass
+
     def setup(self, server: "LitServer"):
         self._server = server
diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py
index 40e340ad..f7447d88 100644
--- a/src/litserve/specs/openai.py
+++ b/src/litserve/specs/openai.py
@@ -30,7 +30,7 @@
 from litserve.utils import LitAPIStatus, azip
 
 if typing.TYPE_CHECKING:
-    from litserve import LitServer
+    from litserve import LitAPI, LitServer
 
 logger = logging.getLogger(__name__)
 
@@ -262,18 +262,18 @@ def __init__(
         self.add_endpoint("/v1/chat/completions", self.chat_completion, ["POST"])
         self.add_endpoint("/v1/chat/completions", self.options_chat_completions, ["OPTIONS"])
 
-    def setup(self, server: "LitServer"):
+    def pre_setup(self, lit_api: "LitAPI"):
         from litserve import LitAPI
 
-        super().setup(server)
-
-        lit_api = self._server.lit_api
         if not inspect.isgeneratorfunction(lit_api.predict):
             raise ValueError(LITAPI_VALIDATION_MSG.format("predict is not a generator"))
 
         is_encode_response_original = lit_api.encode_response.__code__ is LitAPI.encode_response.__code__
         if not is_encode_response_original and not inspect.isgeneratorfunction(lit_api.encode_response):
             raise ValueError(LITAPI_VALIDATION_MSG.format("encode_response is not a generator"))
+
+    def setup(self, server: "LitServer"):
+        super().setup(server)
         print("OpenAI spec setup complete")
diff --git a/tests/test_litapi.py b/tests/test_litapi.py
index 797ecbdf..6c92c018 100644
--- a/tests/test_litapi.py
+++ b/tests/test_litapi.py
@@ -175,7 +175,7 @@ def predict():
 
 
 def test_encode_response_with_custom_spec_api():
-    class CustomSpecAPI(ls.test_examples.TestAPI):
+    class CustomSpecAPI(ls.OpenAISpec):
         def encode_response(self, output_stream):
             for output in output_stream:
                 yield {"content": output}
diff --git a/tests/test_specs.py b/tests/test_specs.py
index 360b376c..78b71980 100644
--- a/tests/test_specs.py
+++ b/tests/test_specs.py
@@ -176,15 +176,11 @@ def encode_response(self, output):
 
 @pytest.mark.asyncio
 async def test_openai_spec_validation(openai_request_data):
-    server = ls.LitServer(IncorrectAPI1(), spec=OpenAISpec())
-    with pytest.raises(ValueError, match="predict is not a generator"), wrap_litserve_start(server) as server:
-        async with LifespanManager(server.app) as manager:
-            await manager.shutdown()
-
-    server = ls.LitServer(IncorrectAPI2(), spec=OpenAISpec())
-    with pytest.raises(ValueError, match="encode_response is not a generator"), wrap_litserve_start(server) as server:
-        async with LifespanManager(server.app) as manager:
-            await manager.shutdown()
+    with pytest.raises(ValueError, match="predict is not a generator"):
+        ls.LitServer(IncorrectAPI1(), spec=OpenAISpec())
+
+    with pytest.raises(ValueError, match="encode_response is not a generator"):
+        ls.LitServer(IncorrectAPI2(), spec=OpenAISpec())
 
 
 class PrePopulatedAPI(ls.LitAPI):
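
Note (illustration only, not part of the patch): the sketch below shows a LitAPI that satisfies the default `step` path of `ContinuousBatchingLoop` above. It implements the `decode_request`, batched `predict(inputs, generated_sequence)`, `has_finished`, and per-token `encode_response` hooks that the new `pre_setup` checks describe, and it assumes the server is started with `stream=True` as that validation requires. The class name and the toy word-echo "model" are invented for this example; a real API would generate tokens from a model loaded in `setup`.

    import litserve as ls

    class ToyContinuousBatchingAPI(ls.LitAPI):
        def setup(self, device):
            # A real implementation would load a model onto `device` here.
            self.eos = "<eos>"

        def decode_request(self, request):
            # The loop stores this value as the sequence "input".
            return request["prompt"]

        def predict(self, inputs, generated_sequence):
            # Batched single-token step: return one new token per active sequence,
            # in the same order as `inputs`.
            new_tokens = []
            for prompt, generated in zip(inputs, generated_sequence):
                words = prompt.split()
                next_token = words[len(generated)] if len(generated) < len(words) else self.eos
                new_tokens.append(next_token)
            return new_tokens

        def has_finished(self, uid, token, max_sequence_length):
            # Stop when the toy model emits its EOS marker; a real API could also
            # track per-uid lengths and compare them against max_sequence_length.
            return token == self.eos

        def encode_response(self, output):
            # Called once per streamed token by the patched run() above.
            return {"token": output}

How the loop is attached to the server is outside this patch, so that wiring is omitted here.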