Skip to content

Commit 49a6748

Browse files
committed
add tests
1 parent ae2911e commit 49a6748

File tree

1 file changed

+38
-2
lines changed

1 file changed

+38
-2
lines changed

tests/test_utils.py

+38-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import os
12
import pickle
2-
from unittest.mock import MagicMock
3+
import sys
4+
from unittest.mock import MagicMock, patch
35

46
import pytest
57
from fastapi import HTTPException
68

7-
from litserve.utils import call_after_stream, dump_exception
9+
from litserve.utils import call_after_stream, dump_exception, generate_random_zmq_address
810

911

1012
def test_dump_exception():
@@ -31,3 +33,37 @@ async def test_call_after_stream():
3133
pass
3234
callback.assert_called()
3335
callback.assert_called_with("first_arg", random_arg="second_arg")
36+
37+
38+
@pytest.mark.skipif(sys.platform == "win32", reason="This test is for non-Windows platforms only.")
39+
def test_generate_random_zmq_address_non_windows(tmpdir):
40+
"""Test generate_random_zmq_address on non-Windows platforms."""
41+
42+
temp_dir = str(tmpdir)
43+
address1 = generate_random_zmq_address(temp_dir=temp_dir)
44+
address2 = generate_random_zmq_address(temp_dir=temp_dir)
45+
46+
assert address1.startswith("ipc://"), "Address should start with 'ipc://'"
47+
assert address2.startswith("ipc://"), "Address should start with 'ipc://'"
48+
assert address1 != address2, "Addresses should be unique"
49+
50+
# Verify the path exists within the specified temp_dir
51+
assert os.path.commonpath([temp_dir, address1[6:]]) == temp_dir
52+
assert os.path.commonpath([temp_dir, address2[6:]]) == temp_dir
53+
54+
55+
@patch("sys.platform", "win32")
56+
@patch("zmq.Context")
57+
def test_generate_random_zmq_address_windows(mock_ctx):
58+
"""Test generate_random_zmq_address on Windows platforms."""
59+
mock_socket = mock_ctx.return_value.socket.return_value
60+
mock_socket.bind_to_random_port.return_value = 5555
61+
62+
address = generate_random_zmq_address()
63+
assert address == "tcp://localhost:5555"
64+
65+
# Verify socket and context were properly used
66+
mock_ctx.return_value.socket.assert_called_once()
67+
mock_socket.bind_to_random_port.assert_called_once_with("localhost")
68+
mock_socket.close.assert_called_once()
69+
mock_ctx.return_value.term.assert_called_once()

0 commit comments

Comments
 (0)