[Serve] Clean up service replicas when CONTROLLER_FAILED happens #3470

Closed
39 changes: 3 additions & 36 deletions sky/serve/replica_managers.py
@@ -123,7 +123,7 @@ def launch_cluster(replica_id: int,
else: # No exception, the launch succeeds.
return

terminate_cluster(cluster_name)
serve_utils.terminate_cluster(cluster_name)
if retry_cnt >= max_retry:
raise RuntimeError('Failed to launch the sky serve replica cluster '
f'{cluster_name} after {max_retry} retries.')
@@ -133,39 +133,6 @@ def launch_cluster(replica_id: int,
time.sleep(gap_seconds)


# TODO(tian): Combine this with
# sky/spot/recovery_strategy.py::terminate_cluster
def terminate_cluster(cluster_name: str,
replica_drain_delay_seconds: int = 0,
max_retry: int = 3) -> None:
"""Terminate the sky serve replica cluster."""
time.sleep(replica_drain_delay_seconds)
retry_cnt = 0
backoff = common_utils.Backoff()
while True:
retry_cnt += 1
try:
usage_lib.messages.usage.set_internal()
sky.down(cluster_name)
return
except ValueError:
# The cluster is already terminated.
logger.info(
f'Replica cluster {cluster_name} is already terminated.')
return
except Exception as e: # pylint: disable=broad-except
if retry_cnt >= max_retry:
raise RuntimeError('Failed to terminate the sky serve replica '
f'cluster {cluster_name}.') from e
gap_seconds = backoff.current_backoff()
logger.error(
'Failed to terminate the sky serve replica cluster '
f'{cluster_name}. Retrying after {gap_seconds} seconds.'
f'Details: {common_utils.format_exception(e)}')
logger.error(f' Traceback: {traceback.format_exc()}')
time.sleep(gap_seconds)


def _get_resources_ports(task_yaml: str) -> str:
"""Get the resources ports used by the task."""
task = sky.Task.from_yaml(task_yaml)
@@ -730,8 +697,8 @@ def _download_and_stream_logs(info: ReplicaInfo):
logger.info(f'preempted: {info.status_property.preempted}, '
f'replica_id: {replica_id}')
p = multiprocessing.Process(
target=ux_utils.RedirectOutputForProcess(terminate_cluster,
log_file_name, 'a').run,
target=ux_utils.RedirectOutputForProcess(
serve_utils.terminate_cluster, log_file_name, 'a').run,
args=(info.cluster_name, replica_drain_delay_seconds),
)
info.status_property.sky_down_status = ProcessStatus.RUNNING
63 changes: 61 additions & 2 deletions sky/serve/serve_utils.py
@@ -10,6 +10,7 @@
import shutil
import threading
import time
import traceback
import typing
from typing import (Any, Callable, DefaultDict, Dict, Generic, Iterator, List,
Optional, TextIO, Type, TypeVar)
@@ -23,7 +24,9 @@
from sky import backends
from sky import exceptions
from sky import global_user_state
from sky import sky_logging
from sky import status_lib
from sky.backends import backend_utils
from sky.serve import constants
from sky.serve import serve_state
from sky.skylet import constants as skylet_constants
@@ -38,6 +41,8 @@

from sky.serve import replica_managers

logger = sky_logging.init_logger(__name__)

SKY_SERVE_CONTROLLER_NAME: str = (
f'sky-serve-controller-{common_utils.get_user_hash()}')
_SYSTEM_MEMORY_GB = psutil.virtual_memory().total // (1024**3)
@@ -270,19 +275,69 @@ def set_service_status_and_active_versions_from_replica(
active_versions=active_versions)


def update_service_status() -> None:
# TODO(tian): Combine this with
# sky/spot/recovery_strategy.py::terminate_cluster
def terminate_cluster(cluster_name: str,
replica_drain_delay_seconds: int = 0,
max_retry: int = 3) -> None:
"""Terminate the sky serve replica cluster."""
time.sleep(replica_drain_delay_seconds)
retry_cnt = 0
backoff = common_utils.Backoff()
while True:
retry_cnt += 1
handle = global_user_state.get_handle_from_cluster_name(cluster_name)
try:
if handle is not None:
backend = backend_utils.get_backend_from_handle(handle)
backend.teardown(handle, terminate=True)
return
except ValueError:
# The cluster is already terminated.
logger.info(
f'Replica cluster {cluster_name} is already terminated.')
return
except Exception as e: # pylint: disable=broad-except
if retry_cnt >= max_retry:
raise RuntimeError('Failed to terminate the sky serve replica '
f'cluster {cluster_name}.') from e
gap_seconds = backoff.current_backoff()
logger.error(
'Failed to terminate the sky serve replica cluster '
f'{cluster_name}. Retrying after {gap_seconds} seconds. '
f'Details: {common_utils.format_exception(e)}')
logger.error(f' Traceback: {traceback.format_exc()}')
time.sleep(gap_seconds)


def update_service_status(service_names: Optional[List[str]] = None) -> None:
services = serve_state.get_services()
for record in services:
if service_names is not None and record['name'] not in service_names:
continue
if record['status'] == serve_state.ServiceStatus.SHUTTING_DOWN:
# Skip services that are shutting down.
continue
controller_job_id = record['controller_job_id']
assert controller_job_id is not None
controller_status = job_lib.get_status(controller_job_id)
if controller_status is None or controller_status.is_terminal():
service_name = record['name']
# If the controller job is not running, mark the service as CONTROLLER_FAILED.
serve_state.set_service_status_and_active_versions(
record['name'], serve_state.ServiceStatus.CONTROLLER_FAILED)
service_name, serve_state.ServiceStatus.CONTROLLER_FAILED)

# Find all service replicas and terminate them when the controller
# fails, to avoid resource leaks.
# TODO(zhwu): this may need to be reconsidered for the fault-tolerance
# case, since we may want to keep the replicas accessible even if the
# controller fails. If this is too aggressive, we can instead terminate
# the replicas when the user calls `sky serve down` for a service in
# CONTROLLER_FAILED.
replica_infos = serve_state.get_replica_infos(service_name)
for replica in replica_infos:
terminate_cluster(replica.cluster_name)
serve_state.remove_replica(service_name, replica.replica_id)


def update_service_encoded(service_name: str, version: int, mode: str) -> str:
@@ -387,6 +442,7 @@ def _terminate_failed_services(
# replicas, so we don't need to try again here.
for replica_info in serve_state.get_replica_infos(service_name):
# TODO(tian): Refresh latest status of the cluster.
terminate_cluster(replica_info.cluster_name)
if global_user_state.get_cluster_from_name(
replica_info.cluster_name) is not None:
remaining_replica_clusters.append(f'{replica_info.cluster_name!r}')
@@ -410,6 +466,9 @@ def terminate_services(service_names: Optional[List[str]], purge: bool) -> str:
service_names = serve_state.get_glob_service_names(service_names)
terminated_service_names = []
messages = []
# Update the service status before terminating services, so that we do not
# act on a stale status and handle the service incorrectly.
update_service_status(service_names)
for service_name in service_names:
service_status = _get_service_status(service_name,
with_replica_info=False)
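A condensed sketch of the cleanup path introduced above: once the controller job is found to be terminal, the service is marked CONTROLLER_FAILED and every replica cluster is torn down through the relocated serve_utils.terminate_cluster. The wrapper function below is illustrative only and is not part of the diff; the module, function, and parameter names are the ones that appear in this file.

    from sky.serve import serve_state
    from sky.serve import serve_utils

    def cleanup_replicas_of_failed_controller(service_name: str) -> None:
        # Illustrative mirror of the new CONTROLLER_FAILED branch in
        # update_service_status(): mark the service as failed, then terminate
        # each replica cluster and drop its record from the serve state.
        serve_state.set_service_status_and_active_versions(
            service_name, serve_state.ServiceStatus.CONTROLLER_FAILED)
        for replica in serve_state.get_replica_infos(service_name):
            # terminate_cluster() sleeps for replica_drain_delay_seconds,
            # resolves the cluster handle, and tears the cluster down via its
            # backend, retrying with exponential backoff up to max_retry times.
            serve_utils.terminate_cluster(replica.cluster_name,
                                          replica_drain_delay_seconds=0,
                                          max_retry=3)
            serve_state.remove_replica(service_name, replica.replica_id)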
2 changes: 1 addition & 1 deletion sky/serve/service.py
@@ -90,7 +90,7 @@ def _cleanup(service_name: str) -> bool:
info2proc: Dict[replica_managers.ReplicaInfo,
multiprocessing.Process] = dict()
for info in replica_infos:
p = multiprocessing.Process(target=replica_managers.terminate_cluster,
p = multiprocessing.Process(target=serve_utils.terminate_cluster,
args=(info.cluster_name,))
p.start()
info2proc[info] = p
2 changes: 1 addition & 1 deletion sky/skylet/events.py
@@ -64,7 +64,7 @@ class JobSchedulerEvent(SkyletEvent):
EVENT_INTERVAL_SECONDS = 300

def _run(self):
job_lib.scheduler.schedule_step()
job_lib.scheduler.schedule_step(force_update_jobs=True)


class SpotJobUpdateEvent(SkyletEvent):
4 changes: 2 additions & 2 deletions sky/skylet/job_lib.py
@@ -165,9 +165,9 @@ def _run_job(self, job_id: int, run_cmd: str):
_CONN.commit()
subprocess.Popen(run_cmd, shell=True, stdout=subprocess.DEVNULL)

def schedule_step(self) -> None:
def schedule_step(self, force_update_jobs: bool = False) -> None:
jobs = self._get_jobs()
if len(jobs) > 0:
if len(jobs) > 0 or force_update_jobs:
update_status()
# TODO(zhwu, mraheja): One optimization can be allowing more than one
# job staying in the pending state after ray job submit, so that to be
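The two skylet changes above are what make the controller-failure detection reliable: previously, schedule_step() called update_status() only when self._get_jobs() returned jobs, so a controller job that died while nothing was queued could keep a stale, non-terminal status that update_service_status() would trust. A short illustrative snippet of the new behavior (the names are from this diff; the snippet itself is a sketch, not code in the PR):

    from sky.skylet import job_lib

    # Illustrative: what JobSchedulerEvent now effectively does on the
    # controller every EVENT_INTERVAL_SECONDS (300 seconds).
    job_lib.scheduler.schedule_step(force_update_jobs=True)
    # Inside schedule_step(), update_status() now runs even when
    # self._get_jobs() returns nothing, so a terminal controller job is
    # recorded and update_service_status() can react to it.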