[Jobs] Limit number of concurrent jobs & launches. #4248

Open · wants to merge 1 commit into master
3 changes: 3 additions & 0 deletions sky/jobs/constants.py
@@ -15,6 +15,9 @@
# We use 50 GB disk size to reduce the cost.
CONTROLLER_RESOURCES = {'cpus': '8+', 'memory': '3x', 'disk_size': 50}

# Accordingly, we reserve 0.75 GB memory for each controller process.
CONTROLLER_MEMORY_USAGE_GB = 0.75

# Max length of the cluster name for GCP is 35, the user hash to be attached is
# 4+1 chars, and we assume the maximum length of the job id is 4+1, so the max
# length of the cluster name prefix is 25 to avoid the cluster name being too
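To make the sizing concrete: with the 0.75 GB reservation above, a controller with, say, 16 GB of memory could host roughly 21 concurrent controller processes. A minimal sketch of that arithmetic (the 16 GB figure is an assumption for illustration, not a value from this PR):

```python
# Illustrative sketch: derive a concurrent-job cap from total controller memory
# and the per-process reservation introduced in this hunk.
CONTROLLER_MEMORY_USAGE_GB = 0.75   # reserved per controller process (constants.py)

controller_memory_gb = 16           # assumed controller memory for this example
max_concurrent_jobs = int(controller_memory_gb // CONTROLLER_MEMORY_USAGE_GB)
print(max_concurrent_jobs)          # 21
```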
18 changes: 18 additions & 0 deletions sky/jobs/controller.py
@@ -9,6 +9,7 @@
from typing import Tuple

import filelock
import psutil

from sky import exceptions
from sky import sky_logging
@@ -35,6 +36,10 @@
# to inherit the setup from the `sky` logger.
logger = sky_logging.init_logger('sky.jobs.controller')

# Since sky.launch is very resource demanding, we limit the number of
# concurrent sky.launch processes to avoid overloading the machine.
_MAX_NUM_LAUNCH = psutil.cpu_count() * 2


def _get_dag_and_name(dag_yaml: str) -> Tuple['sky.Dag', str]:
dag = dag_utils.load_chain_dag_from_yaml(dag_yaml)
@@ -182,6 +187,14 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool:
f'{task.name!r}); {constants.TASK_ID_ENV_VAR}: {task_id_env_var}')

logger.info('Started monitoring.')

while len(managed_job_state.get_starting_job_ids()) >= _MAX_NUM_LAUNCH:
logger.info('Number of concurrent launches reached the limit '
f'({_MAX_NUM_LAUNCH}). Waiting for some launch process '
'to finish...')
time.sleep(managed_job_utils.JOB_STARTING_STATUS_CHECK_GAP_SECONDS)

logger.info(f'Starting the job {self._job_id} (task: {task_id}).')
managed_job_state.set_starting(job_id=self._job_id,
task_id=task_id,
callback_func=callback_func)
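Before the next hunk, here is a minimal, self-contained sketch of the gating pattern the lines above add: poll the number of STARTING jobs against a CPU-derived cap before proceeding with a launch. The helper below is a stub standing in for the PR's state query, not the real SkyPilot API:

```python
import time

import psutil

# Stub standing in for managed_job_state.get_starting_job_ids() in this PR.
def get_starting_job_ids() -> list:
    """Return ids of jobs currently in the STARTING state (stubbed out here)."""
    return []

MAX_NUM_LAUNCH = psutil.cpu_count() * 2  # the PR's cap: two launches per CPU
STARTING_CHECK_GAP_SECONDS = 5           # the PR's polling interval

def wait_for_launch_slot() -> None:
    # Block until fewer than MAX_NUM_LAUNCH jobs are launching, so concurrent
    # sky.launch processes do not overload the controller machine.
    while len(get_starting_job_ids()) >= MAX_NUM_LAUNCH:
        time.sleep(STARTING_CHECK_GAP_SECONDS)
```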
@@ -474,6 +487,11 @@ def start(job_id, dag_yaml, retry_until_up):
controller_process = None
cancelling = False
try:
if (len(managed_job_state.get_nonterminal_job_ids_by_name(None)) >
managed_job_utils.NUM_JOBS_THRESHOLD):
raise exceptions.ManagedJobUserCancelledError(
'Too many concurrent managed jobs are running. '
'Please try again later or cancel some jobs.')
Comment on lines +490 to +494 (Collaborator):

Hmm, this does not make sense: users should expect managed jobs to be queued on the controller when it does not have enough resources, instead of erroring out. Should we instead change the task's CPU requirement based on the memory on the controller?

_handle_signal(job_id)
# TODO(suquark): In theory, we should make controller process a
# daemon process so it will be killed after this process exits,
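For contrast with raising an error, here is a hedged sketch of the queueing behaviour the reviewer suggests: hold a newly submitted job until the controller has capacity, rather than failing it. The helper and threshold below are illustrative stand-ins, not code from this PR:

```python
import time

# Stub standing in for managed_job_state.get_nonterminal_job_ids_by_name(None).
def get_nonterminal_job_ids() -> list:
    """Return ids of jobs that have not reached a terminal state (stubbed)."""
    return []

NUM_JOBS_THRESHOLD = 21   # e.g. a 16 GB controller divided by 0.75 GB per job
QUEUE_POLL_SECONDS = 5

def wait_for_job_capacity() -> None:
    # Alternative to erroring out: keep the new job queued on the controller
    # until enough jobs finish to free memory for another controller process.
    while len(get_nonterminal_job_ids()) > NUM_JOBS_THRESHOLD:
        time.sleep(QUEUE_POLL_SECONDS)
```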
45 changes: 29 additions & 16 deletions sky/jobs/state.py
@@ -498,12 +498,11 @@ def set_cancelled(job_id: int, callback_func: CallbackType):


# ======== utility functions ========
def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
"""Get non-terminal job ids by name."""
statuses = ', '.join(['?'] * len(ManagedJobStatus.terminal_statuses()))
field_values = [
status.value for status in ManagedJobStatus.terminal_statuses()
]
def _get_query_jobs_by_status_and_name(
statuses: List[ManagedJobStatus],
name: Optional[str]) -> Tuple[str, List[str]]:
statuses_str = ', '.join(['?'] * len(statuses))
field_values = [status.value for status in statuses]

name_filter = ''
if name is not None:
@@ -516,17 +515,31 @@ def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:

# Left outer join is used here instead of join, because the job_info does
# not contain the managed jobs submitted before #1982.
return f"""\
SELECT DISTINCT spot.spot_job_id
FROM spot
LEFT OUTER JOIN job_info
ON spot.spot_job_id=job_info.spot_job_id
WHERE status NOT IN
({statuses_str})
{name_filter}
ORDER BY spot.spot_job_id DESC""", field_values


def get_starting_job_ids() -> List[int]:
"""Get job ids that are starting."""
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
f"""\
SELECT DISTINCT spot.spot_job_id
FROM spot
LEFT OUTER JOIN job_info
ON spot.spot_job_id=job_info.spot_job_id
WHERE status NOT IN
({statuses})
{name_filter}
ORDER BY spot.spot_job_id DESC""", field_values).fetchall()
rows = cursor.execute(*_get_query_jobs_by_status_and_name(
[ManagedJobStatus.STARTING], None)).fetchall()
job_ids = [row[0] for row in rows if row[0] is not None]
return job_ids


def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
"""Get non-terminal job ids by name."""
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(*_get_query_jobs_by_status_and_name(
ManagedJobStatus.terminal_statuses(), name)).fetchall()
job_ids = [row[0] for row in rows if row[0] is not None]
return job_ids

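As a usage note on the refactor above: the new helper returns a (query, parameters) tuple, so call sites unpack it straight into cursor.execute(*...). A minimal sqlite3 sketch of that pattern, with a simplified table and column names standing in for the spot schema:

```python
import sqlite3
from typing import List, Optional, Tuple

def build_query(statuses: List[str], name: Optional[str]) -> Tuple[str, List[str]]:
    # Build a parameterised query plus its bound values, as the new helper does.
    placeholders = ', '.join(['?'] * len(statuses))
    query = f'SELECT job_id FROM jobs WHERE status NOT IN ({placeholders})'
    values = list(statuses)
    if name is not None:
        query += ' AND name = ?'
        values.append(name)
    return query, values

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE jobs (job_id INTEGER, status TEXT, name TEXT)')
conn.execute("INSERT INTO jobs VALUES (1, 'STARTING', 'demo')")
# Unpack the (query, values) tuple directly, mirroring cursor.execute(*...) above.
rows = conn.execute(*build_query(['SUCCEEDED', 'CANCELLED'], None)).fetchall()
print(rows)  # [(1,)] -- job 1 is not in a terminal status
```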
10 changes: 10 additions & 0 deletions sky/jobs/utils.py
@@ -18,6 +18,7 @@

import colorama
import filelock
import psutil
from typing_extensions import Literal

from sky import backends
@@ -48,11 +49,20 @@
f'sky-jobs-controller-{common_utils.get_user_hash()}')
LEGACY_JOB_CONTROLLER_NAME: str = (
f'sky-spot-controller-{common_utils.get_user_hash()}')

_SYSTEM_MEMORY_GB = psutil.virtual_memory().total // (1024**3)
NUM_JOBS_THRESHOLD = (_SYSTEM_MEMORY_GB //
managed_job_constants.CONTROLLER_MEMORY_USAGE_GB)

SIGNAL_FILE_PREFIX = '/tmp/sky_jobs_controller_signal_{}'
LEGACY_SIGNAL_FILE_PREFIX = '/tmp/sky_spot_controller_signal_{}'
# Controller checks its job's status every this many seconds.
JOB_STATUS_CHECK_GAP_SECONDS = 20

# Controller checks if the job is valid to start every this many seconds.
# Jobs will be started only if the controller has enough resources.
JOB_STARTING_STATUS_CHECK_GAP_SECONDS = 5

# Controller checks if its job has started every this many seconds.
JOB_STARTED_STATUS_CHECK_GAP_SECONDS = 5

Expand Down