Merge branch 'main' into bugfix/windows_multiple_workers
aniketmaurya authored Jan 5, 2025
2 parents 63b1963 + f28c816 commit bde6fdb
Showing 7 changed files with 221 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/litserve/__about__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.6.dev1"
__version__ = "0.2.6.dev2"
__author__ = "Lightning-AI et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
17 changes: 10 additions & 7 deletions src/litserve/loops.py
@@ -495,9 +495,6 @@ def __init__(self):
self._context = {}

def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
if max_batch_size <= 1:
raise ValueError("max_batch_size must be greater than 1")

batches, timed_out_uids = collate_requests(
lit_api,
request_queue,
@@ -507,8 +504,10 @@ def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_si
return batches, timed_out_uids

def get_request(self, request_queue: Queue, timeout: float = 1.0):
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
return response_queue_id, uid, timestamp, x_enc
try:
return request_queue.get(timeout=timeout)
except Empty:
return None

def populate_context(self, lit_spec: LitSpec, request: Any):
if lit_spec and hasattr(lit_spec, "populate_context"):
@@ -751,12 +750,14 @@ def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:

def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
"""Add a new sequence to active sequences and perform any action before prediction such as filling the cache."""
if hasattr(lit_api, "add_request"):
lit_api.add_request(uid, request)
decoded_request = lit_api.decode_request(request)
self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_sequence": []}

def mark_completed(self, uid: str) -> None:
"""Mark a request as completed and remove it from the tracked state."""
logger.info(f"Marking sequence {uid} as completed")
logger.debug(f"Marking sequence {uid} as completed")
del self.active_sequences[uid]
del self.response_queue_ids[uid]

Expand Down Expand Up @@ -839,7 +840,7 @@ def prefill(
if new_batches:
# Add new requests to pending_requests and try to process them
for response_queue_id, uid, input in new_batches:
logger.info(f"New request: {uid}, {input}")
logger.debug(f"New request: {uid}, {input}")
if self.has_capacity(lit_api):
self.add_request(uid, input, lit_api, lit_spec)
self.response_queue_ids[uid] = response_queue_id
Expand Down Expand Up @@ -892,6 +893,7 @@ def run(
uid = step_output.uid
response_queue_id = self.response_queue_ids[uid]

response_data = lit_api.format_encoded_response(response_data)
if status == LitAPIStatus.ERROR:
self.put_error_response(response_queues, response_queue_id, uid, response_data)
self.mark_completed(uid)
@@ -917,6 +919,7 @@ def run(
for uid, response_queue_id in self.response_queue_ids.items():
self.put_error_response(response_queues, response_queue_id, uid, e)
self.response_queue_ids.clear()
self.active_sequences.clear()


def inference_worker(
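
A note on the loops.py change above: get_request now returns None when the queue stays empty past the timeout instead of letting queue.Empty propagate to the caller. A minimal sketch of caller code written against that contract (a hypothetical helper, not the library's actual loop):

from queue import Queue

def drain_requests(lit_loop, request_queue: Queue, poll_timeout: float = 1.0):
    # Hypothetical helper: pull queued requests until a poll times out.
    items = []
    while True:
        item = lit_loop.get_request(request_queue, timeout=poll_timeout)
        if item is None:  # timed out; nothing queued right now
            break
        response_queue_id, uid, timestamp, x_enc = item  # same tuple shape as before
        items.append((response_queue_id, uid, timestamp, x_enc))
    return items
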
5 changes: 2 additions & 3 deletions src/litserve/server.py
@@ -46,7 +46,6 @@
from litserve.loops import LitLoop, get_default_loop, inference_worker
from litserve.middlewares import MaxSizeMiddleware, RequestCountMiddleware
from litserve.python_client import client_template
from litserve.specs import OpenAISpec
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus, WorkerSetupStatus, call_after_stream

@@ -149,8 +148,8 @@ def __init__(
raise ValueError("batch_timeout must be less than timeout")
if max_batch_size <= 0:
raise ValueError("max_batch_size must be greater than 0")
if isinstance(spec, OpenAISpec):
stream = True
if isinstance(spec, LitSpec):
stream = spec.stream

if loop is None:
loop = "auto"
4 changes: 4 additions & 0 deletions src/litserve/specs/base.py
@@ -26,6 +26,10 @@ def __init__(self):

self._server: LitServer = None

@property
def stream(self):
return False

def pre_setup(self, lit_api: "LitAPI"):
pass

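
With the base stream property above defaulting to False, and server.py now reading spec.stream for any LitSpec rather than special-casing OpenAISpec, a custom spec can opt into streaming simply by overriding the property. A rough sketch of such a spec (hypothetical, not part of the library):

from litserve.specs.base import LitSpec

class MyStreamingSpec(LitSpec):
    # Hypothetical spec: declares streaming purely through the new property,
    # which LitServer picks up via stream = spec.stream.
    @property
    def stream(self) -> bool:
        return True
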
6 changes: 5 additions & 1 deletion src/litserve/specs/openai.py
@@ -262,6 +262,10 @@ def __init__(
self.add_endpoint("/v1/chat/completions", self.chat_completion, ["POST"])
self.add_endpoint("/v1/chat/completions", self.options_chat_completions, ["OPTIONS"])

@property
def stream(self):
return True

def pre_setup(self, lit_api: "LitAPI"):
from litserve import LitAPI

@@ -439,6 +443,7 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat
logger.debug(encoded_response)
chat_msg = ChatMessage(**encoded_response)
usage = UsageInfo(**encoded_response)
usage_infos.append(usage) # Aggregate usage info across all choices
msgs.append(chat_msg.content)
if chat_msg.tool_calls:
tool_calls = chat_msg.tool_calls
@@ -447,6 +452,5 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat
msg = {"role": "assistant", "content": content, "tool_calls": tool_calls}
choice = ChatCompletionResponseChoice(index=i, message=msg, finish_reason="stop")
choices.append(choice)
usage_infos.append(usage) # Only use the last item from encode_response

return ChatCompletionResponse(model=model, choices=choices, usage=sum(usage_infos))
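
The non_streaming_completion change now collects one UsageInfo per choice and totals them with sum(usage_infos), which presumably relies on UsageInfo supporting addition (including with the 0 that sum() starts from). A self-contained stand-in illustrating that aggregation pattern (the _Usage class is hypothetical, not the library's model):

from dataclasses import dataclass

@dataclass
class _Usage:
    # Hypothetical stand-in for litserve's UsageInfo.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    def __add__(self, other):
        if isinstance(other, int):  # sum() seeds the fold with 0
            return self
        return _Usage(
            self.prompt_tokens + other.prompt_tokens,
            self.completion_tokens + other.completion_tokens,
            self.total_tokens + other.total_tokens,
        )

    __radd__ = __add__  # lets sum() evaluate 0 + _Usage(...)

# One usage record per generated choice, totalled for the response body.
usage_infos = [_Usage(prompt_tokens=0, completion_tokens=1, total_tokens=1) for _ in range(5)]
assert sum(usage_infos).completion_tokens == 5
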
156 changes: 156 additions & 0 deletions tests/test_loops.py
@@ -14,6 +14,7 @@
import inspect
import io
import json
import re
import threading
import time
from queue import Queue
@@ -28,8 +29,13 @@
from litserve import LitAPI
from litserve.callbacks import CallbackRunner
from litserve.loops import (
ContinuousBatchingLoop,
DefaultLoop,
LitLoop,
Output,
_BaseLoop,
inference_worker,
notify_timed_out_requests,
run_batched_loop,
run_batched_streaming_loop,
run_single_loop,
@@ -495,3 +501,153 @@ def test_get_default_loop():
assert isinstance(loop, ls.loops.BatchedStreamingLoop), (
"BatchedStreamingLoop must be returned when stream=True and max_batch_size>1"
)


@pytest.fixture
def lit_loop_setup():
lit_loop = LitLoop()
lit_api = MagicMock(request_timeout=0.1)
request_queue = Queue()
return lit_loop, lit_api, request_queue


def test_lit_loop_get_batch_requests(lit_loop_setup):
lit_loop, lit_api, request_queue = lit_loop_setup
request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0}))
request_queue.put((0, "UUID-002", time.monotonic(), {"input": 5.0}))
batches, timed_out_uids = lit_loop.get_batch_requests(lit_api, request_queue, 2, 0.001)
assert len(batches) == 2
assert batches == [(0, "UUID-001", {"input": 4.0}), (0, "UUID-002", {"input": 5.0})]
assert timed_out_uids == []


def test_lit_loop_get_request(lit_loop_setup):
lit_loop, _, request_queue = lit_loop_setup
t = time.monotonic()
request_queue.put((0, "UUID-001", t, {"input": 4.0}))
response_queue_id, uid, timestamp, x_enc = lit_loop.get_request(request_queue, timeout=1)
assert uid == "UUID-001"
assert response_queue_id == 0
assert timestamp == t
assert x_enc == {"input": 4.0}
assert lit_loop.get_request(request_queue, timeout=0.001) is None


def test_lit_loop_put_response(lit_loop_setup):
lit_loop, _, request_queue = lit_loop_setup
response_queues = [Queue()]
lit_loop.put_response(response_queues, 0, "UUID-001", {"output": 16.0}, LitAPIStatus.OK)
response = response_queues[0].get()
assert response == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK))


def test_notify_timed_out_requests():
response_queues = [Queue()]

# Simulate timed out requests
timed_out_uids = [(0, "UUID-001"), (0, "UUID-002")]

# Call the function to notify timed out requests
notify_timed_out_requests(response_queues, timed_out_uids)

# Check the responses in the response queue
response_1 = response_queues[0].get()
response_2 = response_queues[0].get()

assert response_1[0] == "UUID-001"
assert response_1[1][1] == LitAPIStatus.ERROR
assert isinstance(response_1[1][0], HTTPException)
assert response_2[0] == "UUID-002"
assert isinstance(response_2[1][0], HTTPException)
assert response_2[1][1] == LitAPIStatus.ERROR


class ContinuousBatchingAPI(ls.LitAPI):
def setup(self, spec: Optional[LitSpec]):
self.model = {}

def add_request(self, uid: str, request):
self.model[uid] = {"outputs": list(range(5))}

def decode_request(self, input: str):
return input

def encode_response(self, output: str):
return {"output": output}

def step(self, prev_outputs: Optional[List[Output]]) -> List[Output]:
outputs = []
for k in self.model:
v = self.model[k]
if v["outputs"]:
o = v["outputs"].pop(0)
outputs.append(Output(k, o, LitAPIStatus.OK))
keys = list(self.model.keys())
for k in keys:
if k not in [o.uid for o in outputs]:
outputs.append(Output(k, "", LitAPIStatus.FINISH_STREAMING))
del self.model[k]
return outputs


@pytest.mark.parametrize(
("stream", "max_batch_size", "error_msg"),
[
(True, 4, "`lit_api.unbatch` must generate values using `yield`."),
(True, 1, "`lit_api.encode_response` must generate values using `yield`."),
],
)
def test_default_loop_pre_setup_error(stream, max_batch_size, error_msg):
lit_api = ls.test_examples.SimpleLitAPI()
lit_api.stream = stream
lit_api.max_batch_size = max_batch_size
loop = DefaultLoop()
with pytest.raises(ValueError, match=error_msg):
loop.pre_setup(lit_api, None)


@pytest.fixture
def continuous_batching_setup():
lit_api = ContinuousBatchingAPI()
lit_api.stream = True
lit_api.request_timeout = 0.1
lit_api.pre_setup(2, None)
lit_api.setup(None)
request_queue = Queue()
response_queues = [Queue()]
loop = ContinuousBatchingLoop()
return lit_api, loop, request_queue, response_queues


def test_continuous_batching_pre_setup(continuous_batching_setup):
lit_api, loop, request_queue, response_queues = continuous_batching_setup
lit_api.stream = False
with pytest.raises(
ValueError,
match=re.escape(
"Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
),
):
loop.pre_setup(lit_api, None)


def test_continuous_batching_run(continuous_batching_setup):
lit_api, loop, request_queue, response_queues = continuous_batching_setup
request_queue.put((0, "UUID-001", time.monotonic(), {"input": "Hello"}))
loop.run(lit_api, None, "cpu", 0, request_queue, response_queues, 2, 0.1, True, {}, NOOP_CB_RUNNER)

results = []
for i in range(5):
response = response_queues[0].get()
uid, (response_data, status) = response
o = json.loads(response_data)["output"]
assert o == i
assert status == LitAPIStatus.OK
assert uid == "UUID-001"
results.append(o)
assert results == list(range(5)), "API must return a sequence of numbers from 0 to 4"
response = response_queues[0].get()
uid, (response_data, status) = response
o = json.loads(response_data)["output"]
assert o == ""
assert status == LitAPIStatus.FINISH_STREAMING
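
As a rough standalone illustration of the step() contract that the ContinuousBatchingAPI fixture above exercises (reusing that test class directly; no server, loop, or queues involved):

api = ContinuousBatchingAPI()
api.setup(None)                                  # builds the empty per-uid output table
api.add_request("UUID-001", {"input": "Hello"})  # buffers five outputs for this uid

first = api.step(None)                           # one Output per active sequence, status OK
assert len(first) == 1 and first[0].uid == "UUID-001"

# After five steps the buffered outputs are exhausted; the next step emits a
# FINISH_STREAMING sentinel for the uid and removes it from the model state.
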
43 changes: 43 additions & 0 deletions tests/test_specs.py
@@ -89,6 +89,49 @@ async def test_openai_token_usage(api, batch_size, openai_request_data, openai_r
assert result["usage"] == openai_response_data["usage"]


class OpenAIWithUsagePerToken(ls.LitAPI):
def setup(self, device):
self.model = None

def predict(self, x):
for i in range(1, 6):
yield {
"role": "assistant",
"content": f"{i}",
"prompt_tokens": 0,
"completion_tokens": 1,
"total_tokens": 1,
}


# OpenAIWithUsagePerToken
@pytest.mark.asyncio
@pytest.mark.parametrize(
("api", "batch_size"),
[
(OpenAIWithUsagePerToken(), 1),
],
)
async def test_openai_per_token_usage(api, batch_size, openai_request_data, openai_response_data):
server = ls.LitServer(api, spec=ls.OpenAISpec(), max_batch_size=batch_size, batch_timeout=0.01)
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(
transport=ASGITransport(app=manager.app), base_url="http://test"
) as ac:
resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
assert resp.status_code == 200, "Status code should be 200"
result = resp.json()
content = result["choices"][0]["message"]["content"]
assert content == "12345", "LitAPI predict response should match with the generated output"
assert result["usage"]["completion_tokens"] == 5, "API yields 5 tokens"

# with streaming
openai_request_data["stream"] = True
resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
assert resp.status_code == 200, "Status code should be 200"
assert result["usage"]["completion_tokens"] == 5, "API yields 5 tokens"


@pytest.mark.asyncio
async def test_openai_spec_with_image(openai_request_data_with_image):
server = ls.LitServer(TestAPI(), spec=OpenAISpec())
