Skip to content

Commit

Permalink
create_actor_pool: Add extra_conf option listen_elastic_ip
Browse files Browse the repository at this point in the history
Usage:

    create_actor_pool(elastic_address,
	n_process=0,
	extra_conf={'listen_elastic_ip': True},
    )

When an xinference worker serves on a cloud elastic IP,
the address passed to create_actor_pool() and create_actor() must both be the elastic IP,
so that an ActorRef passed to clients through RPC methods carries a reachable address; however, the server itself can only listen on 0.0.0.0.
(This is because the elastic IP is only valid outside the host.)
  • Loading branch information
frostyplanet committed Jul 14, 2024
1 parent c11c9d5 commit 127fab2
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
17 changes: 15 additions & 2 deletions python/xoscar/backends/communication/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ..._utils import to_binary
from ...constants import XOSCAR_UNIX_SOCKET_DIR
from ...serialization import AioDeserializer, AioSerializer, deserialize
from ...utils import classproperty, implements
from ...utils import classproperty, implements, is_v6_ip
from .base import Channel, ChannelType, Client, Server
from .core import register_client, register_server
from .utils import read_buffers, write_buffers
Expand Down Expand Up @@ -201,6 +201,10 @@ def client_type(self) -> Type["Client"]:
def channel_type(self) -> int:
return ChannelType.remote

@classmethod
def parse_config(cls, config: dict) -> dict:
    """Return the server configuration exactly as it was passed in.

    No keys are added, removed or rewritten; the very same dict object
    is handed back to the caller.
    """
    parsed = config
    return parsed

@staticmethod
@implements(Server.create)
async def create(config: Dict) -> "Server":
Expand All @@ -212,6 +216,15 @@ async def create(config: Dict) -> "Server":
else:
host = config.pop("host")
port = int(config.pop("port"))
# The Actor.address is not on our host, so we cannot actually listen on it.
# But we have to keep it for announcement to clients.
_host = host
if config.pop("listen_elastic_ip", False):
if is_v6_ip(host):
_host = "::"
else:
_host = "0.0.0.0"

handle_channel = config.pop("handle_channel")
if "start_serving" not in config:
config["start_serving"] = False
Expand All @@ -224,7 +237,7 @@ async def handle_connection(reader: StreamReader, writer: StreamWriter):

port = port if port != 0 else None
aio_server = await asyncio.start_server(
handle_connection, host=host, port=port, **config
handle_connection, host=_host, port=port, **config
)

# get port of the socket if not specified
Expand Down
13 changes: 11 additions & 2 deletions python/xoscar/backends/communication/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ...nvutils import get_cuda_context, get_index_and_uuid
from ...serialization import deserialize
from ...serialization.aio import BUFFER_SIZES_NAME, AioSerializer, get_header_length
from ...utils import classproperty, implements, is_cuda_buffer, lazy_import
from ...utils import classproperty, implements, is_cuda_buffer, is_v6_ip, lazy_import
from ..message import _MessageBase
from .base import Channel, ChannelType, Client, Server
from .core import register_client, register_server
Expand Down Expand Up @@ -406,6 +406,15 @@ async def create(config: Dict) -> "Server":
else:
host = config.pop("host")
port = int(config.pop("port"))
# The Actor.address is not on our host, so we cannot actually listen on it.
# But we have to keep it for announcement to clients.
_host = host
if config.pop("listen_elastic_ip", False):
if is_v6_ip(host):
_host = "::"
else:
_host = "0.0.0.0"

handle_channel = config.pop("handle_channel")

# init
Expand All @@ -414,7 +423,7 @@ async def create(config: Dict) -> "Server":
async def serve_forever(client_ucp_endpoint: "ucp.Endpoint"): # type: ignore
try:
await server.on_connected(
client_ucp_endpoint, local_address=server.address
client_ucp_endpoint, local_address="%s:%d" % (_host, port)
)
except ChannelClosed: # pragma: no cover
logger.exception("Connection closed before handshake completed")
Expand Down
5 changes: 5 additions & 0 deletions python/xoscar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,11 @@ def is_v6_zero_ip(ip_port_addr: str) -> bool:
return True


def is_v6_ip(ip_port_addr: str) -> bool:
    """Tell whether an address string contains an IPv6 literal.

    An optional ``scheme://`` prefix is stripped first.  The remainder may
    be a bare host (``"::1"``, ``"127.0.0.1"``) or a host followed by a
    port (``"[::1]:8080"``, ``"127.0.0.1:8080"``).

    Args:
        ip_port_addr: Address to inspect, with or without scheme and port.

    Returns:
        True if the host part is an IPv6 literal, False otherwise.
    """
    arr = ip_port_addr.split("://", 1)[-1].split(":")
    # Every IPv6 literal holds at least two colons (the shortest is "::"),
    # i.e. at least three parts.  An IPv4 "host:port" pair holds exactly one
    # colon and must not be mistaken for IPv6, so two parts is not enough.
    return len(arr) > 2


def fix_all_zero_ip(remote_addr: str, connect_addr: str) -> str:
"""
Use connect_addr to fix ActorRef.address return by remote server.
Expand Down

0 comments on commit 127fab2

Please sign in to comment.