Skip to content

Commit 11cfbf6

Browse files
add e2e test for simple streaming server (#247)
* test: add e2e test for simple streaming server * test: remove redundant assertion and simplify output checks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c1b6ad4 commit 11cfbf6

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

tests/e2e/default_single_streaming.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from litserve import LitServer, LitAPI
2+
3+
4+
class SimpleStreamingAPI(LitAPI):
5+
def setup(self, device) -> None:
6+
self.model = lambda x, y: x * y
7+
8+
def decode_request(self, request):
9+
return request["input"]
10+
11+
def predict(self, x):
12+
for i in range(1, 4):
13+
yield self.model(i, x)
14+
15+
def encode_response(self, output_stream):
16+
for output in output_stream:
17+
yield {"output": output}
18+
19+
20+
if __name__ == "__main__":
21+
api = SimpleStreamingAPI()
22+
server = LitServer(api, stream=True)
23+
server.run(port=8000)

tests/e2e/test_e2e.py

+18
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,21 @@ def test_openai_parity_with_response_format():
295295
assert r.choices[0].delta.content == expected_out, (
296296
f"Server didn't return expected output.\n" f"OpenAI client output: {r}"
297297
)
298+
299+
300+
@e2e_from_file("tests/e2e/default_single_streaming.py")
301+
def test_e2e_single_streaming():
302+
resp = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0}, headers=None, stream=True)
303+
assert resp.status_code == 200, f"Expected response to be 200 but got {resp.status_code}"
304+
305+
outputs = []
306+
for line in resp.iter_lines():
307+
if line:
308+
outputs.append(json.loads(line.decode("utf-8")))
309+
310+
assert len(outputs) == 3, "Expected 3 streamed outputs"
311+
assert outputs[-1] == {"output": 12.0}, "Final output doesn't match expected value"
312+
313+
expected_values = [4.0, 8.0, 12.0]
314+
for i, output in enumerate(outputs):
315+
assert output["output"] == expected_values[i], f"Intermediate output {i} is not expected value"

0 commit comments

Comments
 (0)