diff --git a/tests/test_litapi.py b/tests/test_litapi.py
index 38e08afd..2fc44e39 100644
--- a/tests/test_litapi.py
+++ b/tests/test_litapi.py
@@ -11,9 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
+
 import numpy as np
 import pytest
-
+import torch
+from pydantic import BaseModel
 from fastapi import HTTPException
 from litserve.specs.openai import ChatCompletionRequest
 import litserve as ls
@@ -68,6 +71,25 @@ def test_default_batch_unbatch():
     assert api.unbatch(output) == inputs, "Default unbatch should not change input"
 
 
+class TestStreamAPIBatched(TestStreamAPI):
+    def predict(self, x):
+        for i in range(4):
+            yield np.asarray(x) * i
+
+
+def test_default_batch_unbatch_stream():
+    api = TestStreamAPIBatched()
+    api.stream = True
+    api._sanitize(max_batch_size=4, spec=None)
+    inputs = [1, 2, 3, 4]
+    expected_output = [[0, 0, 0, 0], [1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]]
+    output = api.batch(inputs)
+    output = api.predict(output)
+    for out in api.unbatch(output):
+        expected = expected_output.pop(0)
+        assert np.all(out == expected), f"Default unbatch should not change input {out} != {expected}"
+
+
 def test_custom_batch_unbatch():
     api = TestCustomBatchedAPI()
     api._sanitize(max_batch_size=4, spec=None)
@@ -182,3 +204,39 @@ def predict():
     api._sanitize(max_batch_size=1, spec=ls.OpenAISpec())
     with pytest.raises(HTTPException, match=r"Malformed output from LitAPI.predict"):
         next(api.encode_response(predict()))
+
+
+def test_format_encoded_response():
+    api = ls.examples.SimpleLitAPI()
+    sample = {"output": 4.0}
+    msg = "Format encoded response should return the encoded response as a string"
+    assert api.format_encoded_response(sample) == '{"output": 4.0}\n', msg
+
+    class Sample(BaseModel):
+        output: float
+        name: str
+
+    sample = Sample(output=4.0, name="test")
+    msg = "Format encoded response should return the encoded response as a json string"
+    assert json.loads(api.format_encoded_response(sample)) == {"output": 4.0, "name": "test"}, msg
+
+    msg = "non dict and non Pydantic objects are returned as it is."
+    assert api.format_encoded_response([1, 2, 3, 4]) == [1, 2, 3, 4], msg
+
+
+def test_batch_torch():
+    api = ls.examples.SimpleLitAPI()
+    x = [torch.Tensor([1, 2, 3, 4]), torch.Tensor([5, 6, 7, 8])]
+    assert torch.all(api.batch(x) == torch.stack(x)), "Batch should stack torch tensors"
+
+
+def test_batch_numpy():
+    api = ls.examples.SimpleLitAPI()
+    x = [np.asarray([1, 2, 3, 4]), np.asarray([5, 6, 7, 8])]
+    assert np.all(api.batch(x) == np.stack(x)), "Batch should stack Numpy array"
+
+
+def test_device_property():
+    api = ls.examples.SimpleLitAPI()
+    api.device = "cpu"
+    assert api.device == "cpu"