Skip to content

Commit

Permalink
FEAT: Delete cluster (#1719)
Browse files Browse the repository at this point in the history
Co-authored-by: wuzhaoxin <[email protected]>
  • Loading branch information
hainaweiben and wuzhaoxin authored Jun 28, 2024
1 parent 8feac94 commit 3d9c261
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 0 deletions.
67 changes: 67 additions & 0 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,36 @@ def serve(self, logging_conf: Optional[dict] = None):
else None
),
)
self._router.add_api_route(
"/v1/workers",
self.get_workers_info,
methods=["GET"],
dependencies=(
[Security(self._auth_service, scopes=["admin"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/supervisor",
self.get_supervisor_info,
methods=["GET"],
dependencies=(
[Security(self._auth_service, scopes=["admin"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/clusters",
self.abort_cluster,
methods=["DELETE"],
dependencies=(
[Security(self._auth_service, scopes=["admin"])]
if self.is_authenticated()
else None
),
)

if XINFERENCE_DISABLE_METRICS:
logger.info(
Expand Down Expand Up @@ -1730,6 +1760,43 @@ async def confirm_and_remove_model(
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def get_workers_info(self) -> JSONResponse:
try:
res = await (await self._get_supervisor_ref()).get_workers_info()
return JSONResponse(content=res)
except ValueError as re:
logger.error(re, exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def get_supervisor_info(self) -> JSONResponse:
try:
res = await (await self._get_supervisor_ref()).get_supervisor_info()
return res
except ValueError as re:
logger.error(re, exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def abort_cluster(self) -> JSONResponse:
import os
import signal

try:
res = await (await self._get_supervisor_ref()).abort_cluster()
os.kill(os.getpid(), signal.SIGINT)
return JSONResponse(content={"result": res})
except ValueError as re:
logger.error(re, exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))


def run(
supervisor_address: str,
Expand Down
30 changes: 30 additions & 0 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,3 +1324,33 @@ def abort_request(self, model_uid: str, request_id: str):

response_data = response.json()
return response_data

def get_workers_info(self):
url = f"{self.base_url}/v1/workers"
response = requests.get(url, headers=self._headers)
if response.status_code != 200:
raise RuntimeError(
f"Failed to get workers info, detail: {_get_error_string(response)}"
)
response_data = response.json()
return response_data

def get_supervisor_info(self):
url = f"{self.base_url}/v1/supervisor"
response = requests.get(url, headers=self._headers)
if response.status_code != 200:
raise RuntimeError(
f"Failed to get supervisor info, detail: {_get_error_string(response)}"
)
response_json = response.json()
return response_json

def abort_cluster(self):
url = f"{self.base_url}/v1/clusters"
response = requests.delete(url, headers=self._headers)
if response.status_code != 200:
raise RuntimeError(
f"Failed to abort cluster, detail: {_get_error_string(response)}"
)
response_json = response.json()
return response_json
41 changes: 41 additions & 0 deletions xinference/core/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import asyncio
import itertools
import os
import signal
import time
import typing
from dataclasses import dataclass
Expand Down Expand Up @@ -217,6 +219,17 @@ async def __post_create__(self):
model_version_infos, self.address
)

# Windows does not have signal handler
if os.name != "nt":

async def signal_handler():
os._exit(0)

loop = asyncio.get_running_loop()
loop.add_signal_handler(
signal.SIGTERM, lambda: asyncio.create_task(signal_handler())
)

@typing.no_type_check
async def get_cluster_device_info(self, detailed: bool = False) -> List:
import psutil
Expand Down Expand Up @@ -1153,6 +1166,34 @@ async def confirm_and_remove_model(
)
return ret

async def get_workers_info(self) -> List[Dict[str, Any]]:
ret = []
for worker in self._worker_address_to_worker.values():
ret.append(await worker.get_workers_info())
return ret

async def get_supervisor_info(self) -> Dict[str, Any]:
ret = {
"supervisor_ip": self.address,
}
return ret

async def trigger_exit(self) -> bool:
try:
os.kill(os.getpid(), signal.SIGTERM)
except Exception as e:
logger.info(f"trigger exit error: {e}")
return False
return True

async def abort_cluster(self) -> bool:
ret = True
for worker in self._worker_address_to_worker.values():
ret = ret and await worker.trigger_exit()

ret = ret and await self.trigger_exit()
return ret

@staticmethod
def record_metrics(name, op, kwargs):
record_metrics(name, op, kwargs)
15 changes: 15 additions & 0 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,14 @@ async def signal_handler():
async def __pre_destroy__(self):
self._isolation.stop()

async def trigger_exit(self) -> bool:
try:
os.kill(os.getpid(), signal.SIGINT)
except Exception as e:
logger.info(f"trigger exit error: {e}")
return False
return True

@staticmethod
def get_devices_count():
from ..device_utils import gpu_count
Expand Down Expand Up @@ -863,6 +871,13 @@ async def confirm_and_remove_model(self, model_version: str) -> bool:
)
return True

async def get_workers_info(self) -> Dict[str, Any]:
ret = {
"work-ip": self.address,
"models": await self.list_models(),
}
return ret

@staticmethod
def record_metrics(name, op, kwargs):
record_metrics(name, op, kwargs)
46 changes: 46 additions & 0 deletions xinference/deploy/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,5 +1578,51 @@ def cal_model_mem(
print(" total: %d MB (%d GB)" % (mem_info.total, total_mem_g))


@cli.command(
"stop-cluster",
help="Stop a cluster using the Xinference framework with the given parameters.",
)
@click.option(
"--endpoint",
"-e",
type=str,
required=True,
help="Xinference endpoint.",
)
@click.option(
"--api-key",
"-ak",
default=None,
type=str,
help="API key for accessing the Xinference API with authorization.",
)
@click.option("--check", is_flag=True, help="Confirm the deletion of the cache.")
def stop_cluster(endpoint: str, api_key: Optional[str], check: bool):
endpoint = get_endpoint(endpoint)
client = RESTfulClient(base_url=endpoint, api_key=api_key)
if api_key is None:
client._set_token(get_stored_token(endpoint, client))

if not check:
click.echo(
f"This command will stop Xinference cluster in {endpoint}.", err=True
)
supervisor_info = client.get_supervisor_info()
click.echo("Supervisor information: ")
click.echo(supervisor_info)

workers_info = client.get_workers_info()
click.echo("Workers information:")
click.echo(workers_info)

click.confirm("Continue?", abort=True)
try:
result = client.abort_cluster()
result = result.get("result")
click.echo(f"Cluster stopped: {result}")
except Exception as e:
click.echo(e)


if __name__ == "__main__":
cli()

0 comments on commit 3d9c261

Please sign in to comment.