diff --git a/python/xoscar/backends/communication/socket.py b/python/xoscar/backends/communication/socket.py index 246956e3..816859f2 100644 --- a/python/xoscar/backends/communication/socket.py +++ b/python/xoscar/backends/communication/socket.py @@ -203,7 +203,13 @@ def channel_type(self) -> int: @classmethod def parse_config(cls, config: dict) -> dict: - return config + if config is None or not config: + return dict() + # we only need the following config + keys = ["listen_elastic_ip"] + parsed_config = {key: config[key] for key in keys if key in config} + + return parsed_config @staticmethod @implements(Server.create) diff --git a/python/xoscar/backends/indigen/tests/test_pool.py b/python/xoscar/backends/indigen/tests/test_pool.py index 04d5c513..6b60481d 100644 --- a/python/xoscar/backends/indigen/tests/test_pool.py +++ b/python/xoscar/backends/indigen/tests/test_pool.py @@ -542,6 +542,104 @@ async def test_create_actor_pool(): assert len(global_router._mapping) == 0 +@pytest.mark.asyncio +async def test_create_actor_pool_extra_config(): + start_method = ( + os.environ.get("POOL_START_METHOD", "forkserver") + if sys.platform != "win32" + else None + ) + # create a actor pool based on socket rather than ucx + # pass `extra_conf` to check if we can filter out ucx config + pool = await create_actor_pool( + "127.0.0.1", + pool_cls=MainActorPool, + n_process=2, + subprocess_start_method=start_method, + extra_conf={ + "ucx": { + "tcp": None, + "nvlink": None, + "infiniband": None, + "rdmacm": None, + "cuda-copy": None, + "create-cuda-contex": None, + } + }, + ) + + async with pool: + # test global router + global_router = Router.get_instance() + # global router should not be the identical one with pool's router + assert global_router is not pool.router + assert pool.external_address in global_router._curr_external_addresses + assert pool.external_address in global_router._mapping + + ctx = get_context() + + # actor on main pool + actor_ref = await ctx.create_actor( + TestActor, uid="test-1", address=pool.external_address + ) + assert await actor_ref.add(3) == 3 + assert await actor_ref.add(1) == 4 + assert (await ctx.has_actor(actor_ref)) is True + assert (await ctx.actor_ref(actor_ref)) == actor_ref + # test cancel + task = asyncio.create_task(actor_ref.sleep(20)) + await asyncio.sleep(0) + task.cancel() + assert await task == 5 + await ctx.destroy_actor(actor_ref) + assert (await ctx.has_actor(actor_ref)) is False + for f in actor_ref.add, ctx.actor_ref, ctx.destroy_actor: + with pytest.raises(ActorNotExist): + await f(actor_ref) + + # actor on sub pool + actor_ref1 = await ctx.create_actor( + TestActor, uid="test-main", address=pool.external_address + ) + actor_ref2 = await ctx.create_actor( + TestActor, + uid="test-2", + address=pool.external_address, + allocate_strategy=RandomSubPool(), + ) + assert ( + await ctx.actor_ref(uid="test-2", address=actor_ref2.address) + ) == actor_ref2 + main_ref = await ctx.actor_ref(uid="test-main", address=actor_ref2.address) + assert main_ref.address == pool.external_address + main_ref = await ctx.actor_ref(actor_ref1) + assert main_ref.address == pool.external_address + assert actor_ref2.address != actor_ref.address + assert await actor_ref2.add(3) == 3 + assert await actor_ref2.add(1) == 4 + with pytest.raises(RuntimeError): + await actor_ref2.return_cannot_unpickle() + with pytest.raises(SendMessageFailed): + await actor_ref2.raise_cannot_pickle() + assert (await ctx.has_actor(actor_ref2)) is True + assert (await ctx.actor_ref(actor_ref2)) == actor_ref2 + # test cancel + task = asyncio.create_task(actor_ref2.sleep(20)) + start = time.time() + await asyncio.sleep(0) + task.cancel() + assert await task == 5 + assert time.time() - start < 3 + await ctx.destroy_actor(actor_ref2) + assert (await ctx.has_actor(actor_ref2)) is False + + assert pool.stopped + # after pool shutdown, global router must has been cleaned + global_router = Router.get_instance() + assert len(global_router._curr_external_addresses) == 0 + assert len(global_router._mapping) == 0 + + @pytest.mark.asyncio @require_unix async def test_create_actor_pool_elastic_ip():