Make LitAPI.predict optional and validate API implementation #394

Merged · 4 commits · Dec 12, 2024
4 changes: 2 additions & 2 deletions src/litserve/api.py
@@ -56,10 +56,9 @@ def batch(self, inputs):

return inputs

@abstractmethod
def predict(self, x, **kwargs):
"""Run the model on the input and return or yield the output."""
pass
raise NotImplementedError("predict is not implemented")

def _unbatch_no_stream(self, output):
if isinstance(output, str):
@@ -121,6 +120,7 @@ def pre_setup(self, max_batch_size: int, spec: Optional[LitSpec]):

if spec:
self._spec = spec
spec.pre_setup(self)

def set_logger_queue(self, queue: Queue):
"""Set the queue for logging events."""
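Note (not part of the diff): with predict no longer abstract, a LitAPI subclass that delegates generation to a loop providing its own step method does not have to define predict; the inherited method now raises NotImplementedError if it is ever called. A minimal illustrative sketch, with hypothetical class and field names:

    import litserve as ls

    class StepBasedAPI(ls.LitAPI):
        # Relies on a loop that implements step(); no predict() override is needed.
        def setup(self, device):
            self.model = None  # placeholder; load a real model here

        def decode_request(self, request):
            return request["input"]

        def encode_response(self, output):
            return {"output": output}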
84 changes: 63 additions & 21 deletions src/litserve/loops.py
@@ -700,18 +700,62 @@ class Output:

class ContinuousBatchingLoop(LitLoop):
def __init__(self, max_sequence_length: int = 2048):
"""Runs continuous batching loop. This loop handles adding new requests, processing them in batches, and
managing the state of active sequences.

The loop requires the following methods to be implemented in the LitAPI:
- setup: sets up the model on the device
- decode_request: decodes the client request into a format that can be processed by the model
- step: generates a new token for each sequence
- encode_response: encodes the response into a format that can be sent to the client
- has_finished: checks if the sequence has finished generating

Args:
max_sequence_length (int): The maximum sequence length allowed for any active sequence.

"""
super().__init__()
self.active_sequences: Dict[str, Dict] = {} # uid -> {input, current_length, generated_tokens}
self.active_sequences: Dict[str, Dict] = {} # uid -> {input, current_length, generated_sequence}
self.max_sequence_length = max_sequence_length
self.response_queue_ids: Dict[str, int] = {} # uid -> response_queue_id

def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
if not lit_api.stream:
raise ValueError(
"Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
)

if not hasattr(lit_api, "step") and not hasattr(lit_api, "predict"):
raise ValueError("""Using the default step method with Continuous batching loop requires the lit_api to
have a `predict` method which accepts decoded request inputs and a list of generated_sequence.
Please implement the predict method in the lit_api.

class ExampleAPI(LitAPI):
...
def predict(self, inputs, generated_sequence):
# implement predict logic
# return list of new tokens
...
""")

if not hasattr(lit_api, "step") and not hasattr(lit_api, "has_finished"):
raise ValueError("""Using the default step method with Continuous batching loop
requires the lit_api to have a has_finished method. Please implement the has_finished method in the lit_api.

class ExampleAPI(LitAPI):
...
def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:
# implement has_finished logic
return False
""")

def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
"""Add a new sequence to active sequences."""
"""Add a new sequence to active sequences and perform any action before prediction such as filling the cache."""
decoded_request = lit_api.decode_request(request)
self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_tokens": []}
self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_sequence": []}

def mark_completed(self, uid: str) -> None:
"""Mark a sequence as completed."""
"""Mark a request as completed and remove it from the tracked state."""
logger.info(f"Marking sequence {uid} as completed")
del self.active_sequences[uid]
del self.response_queue_ids[uid]
@@ -725,36 +769,39 @@ def has_capacity(self, lit_api: LitAPI) -> bool:
)
return capacity

def step(
self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec]
) -> List[Tuple[str, Tuple[Any, LitAPIStatus]]]:
def step(self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> List[Output]:
"""Process one token generation step for all active sequences."""
if hasattr(lit_api, "step"):
return lit_api.step(prev_outputs)

if not self.active_sequences:
return []

# Batch forward pass for all active sequences
inputs = [seq["input"] for seq in self.active_sequences.values()]
generated = [seq["generated_tokens"] for seq in self.active_sequences.values()]
generated = [seq["generated_sequence"] for seq in self.active_sequences.values()]

try:
# Assume lit_api.predict handles batched token generation
new_tokens = lit_api.predict(inputs, generated)
new_tokens: List[Any] = lit_api.predict(inputs, generated)

responses = []
responses: List[Output] = []

# Process each sequence's new token
for uid, token in zip(self.active_sequences.keys(), new_tokens):
seq = self.active_sequences[uid]
@Borda (Member) commented on Dec 12, 2024:
A bit off-topic for this PR, but how about using a dataclass for seq instead of a dict? It would be easier to work with in an IDE, since dataclass fields can be tracked, unlike a dict where anything can go.

Author (Collaborator) replied:
Yes, I agree @Borda. A dataclass here would be better for the IDE and would also leave less room for mistakes.

seq["generated_tokens"].append(token)
seq["generated_sequence"].append(token)
seq["current_length"] += 1

step_output = Output(uid, token, LitAPIStatus.OK)
responses.append(step_output)

# Check completion conditions
is_finished = lit_api.is_finished(uid, token, self.max_sequence_length)
is_finished = lit_api.has_finished(uid, token, self.max_sequence_length)

if is_finished:
# Encode final response for completed sequence
response = lit_api.encode_response(seq["generated_tokens"])
step_output = Output(uid, response, LitAPIStatus.FINISH_STREAMING)
step_output = Output(uid, "", LitAPIStatus.FINISH_STREAMING)
responses.append(step_output)

return responses
@@ -815,11 +862,6 @@ def run(
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
):
if not lit_api.stream:
raise ValueError(
"Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
)

"""Main loop that processes batches of requests."""
pending_requests = self.prefill(
[],
@@ -846,7 +888,7 @@
for step_output in responses:
logger.debug(f"Processing response: {step_output}")
status = step_output.status
response_data = step_output.output
response_data = lit_api.encode_response(step_output.output)
uid = step_output.uid
response_queue_id = self.response_queue_ids[uid]

@@ -870,7 +912,7 @@
)

except Exception as e:
logger.exception("Error in continuous batching loop")
logger.exception(f"Error in continuous batching loop: {e}")
# Handle any errors by sending error responses for all tracked requests
for uid, response_queue_id in self.response_queue_ids.items():
self.put_error_response(response_queues, response_queue_id, uid, e)
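Example (not part of the diff): a minimal sketch of an API that satisfies the new pre_setup checks for the default step path, i.e. a predict that takes the decoded inputs plus the generated sequences, and a has_finished method. The token logic below is illustrative only; streaming must also be enabled, as the new pre_setup check enforces.

    from typing import Any, List

    import litserve as ls

    class ToyContinuousAPI(ls.LitAPI):
        def setup(self, device):
            self.stop_token = "<eos>"  # illustrative stop token

        def decode_request(self, request):
            return request["prompt"]

        def predict(self, inputs: List[Any], generated_sequence: List[List[Any]]) -> List[Any]:
            # Return one new token per active sequence; a real model would run a
            # batched forward pass over inputs and generated_sequence here.
            return [self.stop_token if len(gen) >= 3 else f"tok{len(gen)}" for gen in generated_sequence]

        def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:
            return token == self.stop_token

        def encode_response(self, output):
            return {"token": output}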
5 changes: 4 additions & 1 deletion src/litserve/specs/base.py
@@ -15,7 +15,7 @@
from typing import TYPE_CHECKING, Callable, List

if TYPE_CHECKING:
from litserve import LitServer
from litserve import LitAPI, LitServer


class LitSpec:
@@ -26,6 +26,9 @@ def __init__(self):

self._server: LitServer = None

def pre_setup(self, lit_api: "LitAPI"):
pass

def setup(self, server: "LitServer"):
self._server = server

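Example (not part of the diff): the new pre_setup hook gives a spec a place to validate the user's LitAPI before the server is wired up. A hypothetical custom spec might use it like this:

    from litserve.specs.base import LitSpec

    class MyCustomSpec(LitSpec):  # illustrative spec, not part of litserve
        def pre_setup(self, lit_api):
            # Fail fast if the API lacks a method this spec depends on.
            if not callable(getattr(lit_api, "predict", None)):
                raise ValueError("MyCustomSpec requires the LitAPI to implement predict")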
10 changes: 5 additions & 5 deletions src/litserve/specs/openai.py
@@ -30,7 +30,7 @@
from litserve.utils import LitAPIStatus, azip

if typing.TYPE_CHECKING:
from litserve import LitServer
from litserve import LitAPI, LitServer

logger = logging.getLogger(__name__)

@@ -262,18 +262,18 @@ def __init__(
self.add_endpoint("/v1/chat/completions", self.chat_completion, ["POST"])
self.add_endpoint("/v1/chat/completions", self.options_chat_completions, ["OPTIONS"])

def setup(self, server: "LitServer"):
def pre_setup(self, lit_api: "LitAPI"):
from litserve import LitAPI

super().setup(server)

lit_api = self._server.lit_api
if not inspect.isgeneratorfunction(lit_api.predict):
raise ValueError(LITAPI_VALIDATION_MSG.format("predict is not a generator"))

is_encode_response_original = lit_api.encode_response.__code__ is LitAPI.encode_response.__code__
if not is_encode_response_original and not inspect.isgeneratorfunction(lit_api.encode_response):
raise ValueError(LITAPI_VALIDATION_MSG.format("encode_response is not a generator"))

def setup(self, server: "LitServer"):
super().setup(server)
print("OpenAI spec setup complete")

def populate_context(self, context, request):
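Example (not part of the diff): because the generator checks now run in pre_setup, an incompatible API fails as soon as the spec is attached to the server (see the test changes below). A minimal API that passes the validation, with illustrative echo logic:

    import litserve as ls

    class EchoChatAPI(ls.LitAPI):
        def setup(self, device):
            self.prefix = "echo: "

        def predict(self, prompt, **kwargs):
            # A generator function, as the OpenAI spec validation requires.
            for word in str(prompt).split():
                yield self.prefix + word

    # Validation now runs while the server is constructed, not at startup.
    server = ls.LitServer(EchoChatAPI(), spec=ls.OpenAISpec())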
2 changes: 1 addition & 1 deletion tests/test_litapi.py
@@ -175,7 +175,7 @@ def predict():


def test_encode_response_with_custom_spec_api():
class CustomSpecAPI(ls.test_examples.TestAPI):
class CustomSpecAPI(ls.OpenAISpec):
def encode_response(self, output_stream):
for output in output_stream:
yield {"content": output}
14 changes: 5 additions & 9 deletions tests/test_specs.py
@@ -176,15 +176,11 @@ def encode_response(self, output):

@pytest.mark.asyncio
async def test_openai_spec_validation(openai_request_data):
server = ls.LitServer(IncorrectAPI1(), spec=OpenAISpec())
with pytest.raises(ValueError, match="predict is not a generator"), wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager:
await manager.shutdown()

server = ls.LitServer(IncorrectAPI2(), spec=OpenAISpec())
with pytest.raises(ValueError, match="encode_response is not a generator"), wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager:
await manager.shutdown()
with pytest.raises(ValueError, match="predict is not a generator"):
ls.LitServer(IncorrectAPI1(), spec=OpenAISpec())

with pytest.raises(ValueError, match="encode_response is not a generator"):
ls.LitServer(IncorrectAPI2(), spec=OpenAISpec())


class PrePopulatedAPI(ls.LitAPI):