From 171df239bbc2e4078f6cb0152faa0ea79845a49d Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 21 Jun 2024 22:25:24 +0000 Subject: [PATCH 01/22] Add docker run options --- sky/backends/backend_utils.py | 8 ++++++++ sky/templates/aws-ray.yml.j2 | 3 +++ sky/templates/azure-ray.yml.j2 | 3 +++ sky/templates/gcp-ray.yml.j2 | 3 +++ sky/templates/paperspace-ray.yml.j2 | 3 +++ sky/utils/schemas.py | 18 ++++++++++++++++++ 6 files changed, 38 insertions(+) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 03f644930f4..6ddb2c8ef0d 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -873,6 +873,11 @@ def write_cluster_config( f'open(os.path.expanduser("{constants.SKY_REMOTE_RAY_PORT_FILE}"), "w", encoding="utf-8"))\'' ) + # Docker run options + docker_run_options = skypilot_config.get_nested(('docker', 'run_options'), []) + if isinstance(docker_run_options, str): + docker_run_options = [docker_run_options] + # Use a tmp file path to avoid incomplete YAML file being re-used in the # future. tmp_yaml_path = yaml_path + '.tmp' @@ -917,6 +922,9 @@ def write_cluster_config( wheel_hash).replace('{cloud}', str(cloud).lower())), + # Docker + 'docker_run_options': docker_run_options, + # Port of Ray (GCS server). # Ray's default port 6379 is conflicted with Redis. 
'ray_port': constants.SKY_REMOTE_RAY_PORT, diff --git a/sky/templates/aws-ray.yml.j2 b/sky/templates/aws-ray.yml.j2 index 66c01f53617..26870ad4b04 100644 --- a/sky/templates/aws-ray.yml.j2 +++ b/sky/templates/aws-ray.yml.j2 @@ -14,6 +14,9 @@ docker: {%- if custom_resources is not none %} --gpus all {%- endif %} + {%- for run_option in docker_run_options %} + - {{run_option}} + {%- endfor %} {%- if docker_login_config is not none %} docker_login_config: username: |- diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index 803327f1032..66eac439453 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -14,6 +14,9 @@ docker: {%- if custom_resources is not none %} --gpus all {%- endif %} + {%- for run_option in docker_run_options %} + - {{run_option}} + {%- endfor %} {%- endif %} provider: diff --git a/sky/templates/gcp-ray.yml.j2 b/sky/templates/gcp-ray.yml.j2 index f4ec10a697d..cb305d59a84 100644 --- a/sky/templates/gcp-ray.yml.j2 +++ b/sky/templates/gcp-ray.yml.j2 @@ -15,6 +15,9 @@ docker: {%- if gpu is not none %} --gpus all {%- endif %} + {%- for run_option in docker_run_options %} + - {{run_option}} + {%- endfor %} {%- if docker_login_config is not none %} docker_login_config: username: |- diff --git a/sky/templates/paperspace-ray.yml.j2 b/sky/templates/paperspace-ray.yml.j2 index ba0886ee679..a2430a5a3e6 100644 --- a/sky/templates/paperspace-ray.yml.j2 +++ b/sky/templates/paperspace-ray.yml.j2 @@ -14,6 +14,9 @@ docker: {%- if custom_resources is not none %} --gpus all {%- endif %} + {%- for run_option in docker_run_options %} + - {{run_option}} + {%- endfor %} {%- if docker_login_config is not none %} docker_login_config: username: |- diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 932f2075d21..2ef0069db07 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -757,6 +757,23 @@ def get_config_schema(): } } + docker_configs = { + 'type': 'object', + 'required': [], + 
'additionalProperties': False, + 'properties': { + 'run_options': { + 'anyOf': [{ + 'type': 'string', + }, { + 'type': 'array', + 'items': { + 'type': 'string', + } + }] + } + } + } for cloud, config in cloud_configs.items(): if cloud == 'aws': config['properties'].update(_REMOTE_IDENTITY_SCHEMA_AWS) @@ -774,6 +791,7 @@ def get_config_schema(): 'spot': controller_resources_schema, 'serve': controller_resources_schema, 'allowed_clouds': allowed_clouds, + 'docker': docker_configs, **cloud_configs, }, # Avoid spot and jobs being present at the same time. From fbb73b1963b963d5a372440845102b94b0083e88 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 21 Jun 2024 22:33:39 +0000 Subject: [PATCH 02/22] Add docs --- docs/source/reference/config.rst | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 74cd2c01092..6b08585fce0 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -40,6 +40,24 @@ Available fields and semantics: - gcp - kubernetes + docker: + # Additional Docker run options (optional). + # + # When image_id: docker:<docker_image> is used in a task YAML, additional + # run options for starting the Docker container can be specified here. + # The default run options are: + # --net=host + # --cap-add=SYS_ADMIN + # --device=/dev/fuse + # --security-opt=apparmor:unconfined + # + # This field can be useful for mounting volumes and other advanced Docker + # configurations. The following is an example option for allowing running + # Docker inside Docker. + # sky launch --cloud aws --image-id docker:continuumio/miniconda3 "apt update; apt install -y docker.io; docker run hello-world" + run_options: + - -v /var/run/docker.sock:/var/run/docker.sock + # Advanced AWS configurations (optional). # Apply to all new instances but not existing ones. 
aws: From 0e09389dd8cb0df4792246823da9afdf14fdbe1e Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Fri, 21 Jun 2024 23:09:13 +0000 Subject: [PATCH 03/22] Add warning for docker run options in kubernetes --- sky/backends/backend_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index ae98b241d49..028f81ec588 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -874,9 +874,15 @@ def write_cluster_config( ) # Docker run options - docker_run_options = skypilot_config.get_nested(('docker', 'run_options'), []) + docker_run_options = skypilot_config.get_nested(('docker', 'run_options'), + []) if isinstance(docker_run_options, str): docker_run_options = [docker_run_options] + if docker_run_options and isinstance(to_provision.cloud, clouds.Kubernetes): + logger.warning(f'{colorama.Style.DIM}Docker run options are specified, ' + 'but ignored for Kubernetes: ' + f'{" ".join(docker_run_options)}' + f'{colorama.Style.RESET_ALL}') # Use a tmp file path to avoid incomplete YAML file being re-used in the # future. 
From d793eeaab293b295240be336279fea93f112060f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 25 Jun 2024 01:47:08 +0000 Subject: [PATCH 04/22] Add experimental config --- sky/backends/backend_utils.py | 24 ++---------- sky/clouds/gcp.py | 10 +++-- sky/clouds/kubernetes.py | 22 ++++++----- sky/provision/kubernetes/utils.py | 9 +++-- sky/resources.py | 59 +++++++++++++++++++++++++++-- sky/skypilot_config.py | 62 ++++++++++++++++++++----------- sky/task.py | 17 ++++++++- sky/utils/schemas.py | 37 ++++++++++++++++++ 8 files changed, 177 insertions(+), 63 deletions(-) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 028f81ec588..3f22f264024 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -873,23 +873,8 @@ def write_cluster_config( f'open(os.path.expanduser("{constants.SKY_REMOTE_RAY_PORT_FILE}"), "w", encoding="utf-8"))\'' ) - # Docker run options - docker_run_options = skypilot_config.get_nested(('docker', 'run_options'), - []) - if isinstance(docker_run_options, str): - docker_run_options = [docker_run_options] - if docker_run_options and isinstance(to_provision.cloud, clouds.Kubernetes): - logger.warning(f'{colorama.Style.DIM}Docker run options are specified, ' - 'but ignored for Kubernetes: ' - f'{" ".join(docker_run_options)}' - f'{colorama.Style.RESET_ALL}') - # Use a tmp file path to avoid incomplete YAML file being re-used in the # future. - initial_setup_commands = [] - if (skypilot_config.get_nested(('nvidia_gpus', 'disable_ecc'), False) and - to_provision.accelerators is not None): - initial_setup_commands.append(constants.DISABLE_GPU_ECC_COMMAND) tmp_yaml_path = yaml_path + '.tmp' common_utils.fill_template( cluster_config_template, @@ -921,8 +906,6 @@ def write_cluster_config( # currently only used by GCP. 'specific_reservations': specific_reservations, - # Initial setup commands. 
- 'initial_setup_commands': initial_setup_commands, # Conda setup 'conda_installation_commands': constants.CONDA_INSTALLATION_COMMANDS, @@ -934,9 +917,6 @@ def write_cluster_config( wheel_hash).replace('{cloud}', str(cloud).lower())), - # Docker - 'docker_run_options': docker_run_options, - # Port of Ray (GCS server). # Ray's default port 6379 is conflicted with Redis. 'ray_port': constants.SKY_REMOTE_RAY_PORT, @@ -983,7 +963,9 @@ def write_cluster_config( # Add kubernetes config fields from ~/.sky/config if isinstance(cloud, clouds.Kubernetes): - kubernetes_utils.combine_pod_config_fields(tmp_yaml_path) + kubernetes_utils.combine_pod_config_fields( + tmp_yaml_path, + skypilot_override_configs=to_provision.skypilot_config_override) kubernetes_utils.combine_metadata_fields(tmp_yaml_path) # Restore the old yaml content for backward compatibility. diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 94add7fce7d..afe5cea19c6 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -197,8 +197,10 @@ def _unsupported_features_for_resources( # because `skypilot_config` may change for an existing cluster. # Clusters created with MIG (only GPU clusters) cannot be stopped. 
if (skypilot_config.get_nested( - ('gcp', 'managed_instance_group'), None) is not None and - resources.accelerators): + ('gcp', 'managed_instance_group'), + None, + override_configs=resources.skypilot_config_override) is not None + and resources.accelerators): unsupported[clouds.CloudImplementationFeatures.STOP] = ( 'Managed Instance Group (MIG) does not support stopping yet.') unsupported[clouds.CloudImplementationFeatures.SPOT_INSTANCE] = ( @@ -506,7 +508,9 @@ def make_deploy_resources_variables( resources_vars['tpu_node_name'] = tpu_node_name managed_instance_group_config = skypilot_config.get_nested( - ('gcp', 'managed_instance_group'), None) + ('gcp', 'managed_instance_group'), + None, + override_configs=resources.skypilot_config_override) use_mig = managed_instance_group_config is not None resources_vars['gcp_use_managed_instance_group'] = use_mig # Convert boolean to 0 or 1 in string, as GCP does not support boolean diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 5d9e57568b9..6210e82c4e3 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -41,15 +41,6 @@ class Kubernetes(clouds.Cloud): PORT_FORWARD_PROXY_CMD_TEMPLATE = \ 'kubernetes-port-forward-proxy-command.sh.j2' PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/port-forward-proxy-cmd.sh' - # Timeout for resource provisioning. This timeout determines how long to - # wait for pod to be in pending status before giving up. - # Larger timeout may be required for autoscaling clusters, since autoscaler - # may take some time to provision new nodes. - # Note that this timeout includes time taken by the Kubernetes scheduler - # itself, which can be upto 2-3 seconds. - # For non-autoscaling clusters, we conservatively set this to 10s. - timeout = skypilot_config.get_nested(['kubernetes', 'provision_timeout'], - 10) # Limit the length of the cluster name to avoid exceeding the limit of 63 # characters for Kubernetes resources. 
We limit to 42 characters (63-21) to @@ -312,6 +303,17 @@ def make_deploy_resources_variables( if resources.use_spot: spot_label_key, spot_label_value = kubernetes_utils.get_spot_label() + # Timeout for resource provisioning. This timeout determines how long to + # wait for pod to be in pending status before giving up. + # Larger timeout may be required for autoscaling clusters, since + # autoscaler may take some time to provision new nodes. + # Note that this timeout includes time taken by the Kubernetes scheduler + # itself, which can be up to 2-3 seconds. + # For non-autoscaling clusters, we conservatively set this to 10s. + timeout = skypilot_config.get_nested( + ['kubernetes', 'provision_timeout'], + 10, + override_configs=resources.skypilot_config_override) deploy_vars = { 'instance_type': resources.instance_type, 'custom_resources': custom_resources, @@ -319,7 +321,7 @@ def make_deploy_resources_variables( 'cpus': str(cpus), 'memory': str(mem), 'accelerator_count': str(acc_count), - 'timeout': str(self.timeout), + 'timeout': str(timeout), 'k8s_namespace': kubernetes_utils.get_current_kube_config_context_namespace(), 'k8s_port_mode': port_mode.value, diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index c599a5738d0..6bcccb2436a 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -1353,7 +1353,9 @@ def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]): destination[key] = value -def combine_pod_config_fields(cluster_yaml_path: str) -> None: +def combine_pod_config_fields( + cluster_yaml_path: str, skypilot_override_configs: Dict[str, + Any]) -> None: """Adds or updates fields in the YAML with fields from the ~/.sky/config's kubernetes.pod_spec dict. 
This can be used to add fields to the YAML that are not supported by @@ -1395,8 +1397,9 @@ def combine_pod_config_fields(cluster_yaml_path: str) -> None: with open(cluster_yaml_path, 'r', encoding='utf-8') as f: yaml_content = f.read() yaml_obj = yaml.safe_load(yaml_content) - kubernetes_config = skypilot_config.get_nested(('kubernetes', 'pod_config'), - {}) + kubernetes_config = skypilot_config.get_nested( + ('kubernetes', 'pod_config'), {}, + override_configs=skypilot_override_configs) # Merge the kubernetes config into the YAML for both head and worker nodes. merge_dicts( diff --git a/sky/resources.py b/sky/resources.py index 252edff5da6..9b26fa90c90 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -26,6 +26,14 @@ _DEFAULT_DISK_SIZE_GB = 256 +OVERRIDEABLE_CONFIG_KEYS = [ + ('docker',), + ('nvidia_gpus',), + ('kubernetes', 'pod_config'), + ('kubernetes', 'provision_timeout'), + ('gcp', 'managed_instance_group'), +] + class Resources: """Resources: compute requirements of Tasks. @@ -44,7 +52,7 @@ class Resources: """ # If any fields changed, increment the version. For backward compatibility, # modify the __setstate__ method to handle the old version. - _VERSION = 18 + _VERSION = 19 def __init__( self, @@ -68,6 +76,7 @@ def __init__( _docker_login_config: Optional[docker_utils.DockerLoginConfig] = None, _is_image_managed: Optional[bool] = None, _requires_fuse: Optional[bool] = None, + _skypilot_config_override: Optional[Dict[str, Any]] = None, ): """Initialize a Resources object. 
@@ -218,6 +227,8 @@ def __init__( self._requires_fuse = _requires_fuse + self._skypilot_config_override = _skypilot_config_override + self._set_cpus(cpus) self._set_memory(memory) self._set_accelerators(accelerators, accelerator_args) @@ -448,6 +459,12 @@ def requires_fuse(self) -> bool: return False return self._requires_fuse + @property + def skypilot_config_override(self) -> Dict[str, Any]: + if self._skypilot_config_override is None: + return {} + return self._skypilot_config_override + @requires_fuse.setter def requires_fuse(self, value: Optional[bool]) -> None: self._requires_fuse = value @@ -1011,13 +1028,39 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, cloud.make_deploy_resources_variables() method, and the cloud-agnostic variables are generated by this method. """ + # Initial setup commands + initial_setup_commands = [] + if (skypilot_config.get_nested( + ('nvidia_gpus', 'disable_ecc'), + False, + override_configs=self._skypilot_config_override) and + self.accelerators is not None): + initial_setup_commands = [constants.DISABLE_GPU_ECC_COMMAND] + + # Docker run options + docker_run_options = skypilot_config.get_nested( + ('docker', 'run_options'), + default_value=[], + override_configs=self._skypilot_config_override) + if isinstance(docker_run_options, str): + docker_run_options = [docker_run_options] + if docker_run_options and isinstance(self.cloud, clouds.Kubernetes): + logger.warning( + f'{colorama.Style.DIM}Docker run options are specified, ' + 'but ignored for Kubernetes: ' + f'{" ".join(docker_run_options)}' + f'{colorama.Style.RESET_ALL}') + + docker_image = self.extract_docker_image() + + # Cloud specific variables cloud_specific_variables = self.cloud.make_deploy_resources_variables( self, cluster_name_on_cloud, region, zones, dryrun) - docker_image = self.extract_docker_image() return dict( cloud_specific_variables, **{ # Docker config + 'docker_run_options': docker_run_options, # Docker image. 
The image name used to pull the image, e.g. # ubuntu:latest. 'docker_image': docker_image, @@ -1027,7 +1070,9 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, constants.DEFAULT_DOCKER_CONTAINER_NAME, # Docker login config (if any). This helps pull the image from # private registries. - 'docker_login_config': self._docker_login_config + 'docker_login_config': self._docker_login_config, + # Initial setup commands. + 'initial_setup_commands': initial_setup_commands, }) def get_reservations_available_resources(self) -> Dict[str, int]: @@ -1367,6 +1412,8 @@ def _from_yaml_config_single(cls, config: Dict[str, str]) -> 'Resources': resources_fields['_is_image_managed'] = config.pop( '_is_image_managed', None) resources_fields['_requires_fuse'] = config.pop('_requires_fuse', None) + resources_fields['_skypilot_config_override'] = config.pop( + '_skypilot_config_override', None) if resources_fields['cpus'] is not None: resources_fields['cpus'] = str(resources_fields['cpus']) @@ -1410,6 +1457,8 @@ def add_if_not_none(key, value): if self._docker_login_config is not None: config['_docker_login_config'] = dataclasses.asdict( self._docker_login_config) + add_if_not_none('_skypilot_config_override', + self._skypilot_config_override) if self._is_image_managed is not None: config['_is_image_managed'] = self._is_image_managed if self._requires_fuse is not None: @@ -1525,4 +1574,8 @@ def __setstate__(self, state): if version < 18: self._job_recovery = state.pop('_spot_recovery', None) + if version < 19: + self._skypilot_config_override = state.pop( + '_skypilot_config_override', None) + self.__dict__.update(state) diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 5b205e2692a..9ad24292948 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -77,15 +77,11 @@ _loaded_config_path = None -def get_nested(keys: Iterable[str], default_value: Any) -> Any: - """Gets a nested key. 
- - If any key is not found, or any intermediate key does not point to a dict - value, returns 'default_value'. - """ - if _dict is None: +def _get_nested(configs: Optional[Dict[str, Any]], keys: Iterable[str], + default_value: Any) -> Any: + if configs is None: return default_value - curr = _dict + curr = configs for key in keys: if isinstance(curr, dict) and key in curr: curr = curr[key] @@ -95,6 +91,35 @@ def get_nested(keys: Iterable[str], default_value: Any) -> Any: return curr +def get_nested(keys: Iterable[str], + default_value: Any, + override_configs: Optional[Dict[str, Any]] = None) -> Any: + """Gets a nested key. + + If any key is not found, or any intermediate key does not point to a dict + value, returns 'default_value'. + """ + # TODO (zhwu): Verify that the override_configs is provided when keys is + # within resources.OVERRIDEABLE_CONFIG_KEYS. + if _dict is None: + if override_configs is not None: + return _get_nested(override_configs, keys, default_value) + return default_value + return _get_nested(_dict, keys, default_value) + + +def _recursive_update(base_config: Dict[str, Any], + override_config: Dict[str, Any]) -> Dict[str, Any]: + """Recursively updates base configuration with override configuration""" + for key, value in override_config.items(): + if (isinstance(value, dict) and key in base_config and + isinstance(base_config[key], dict)): + _recursive_update(base_config[key], value) + else: + base_config[key] = value + return base_config + + def set_nested(keys: Iterable[str], value: Any) -> Dict[str, Any]: """Returns a deep-copied config with the nested key set to value. 
@@ -102,20 +127,13 @@ def set_nested(keys: Iterable[str], value: Any) -> Dict[str, Any]: """ _check_loaded_or_die() assert _dict is not None - curr = copy.deepcopy(_dict) - to_return = curr - prev = None - for i, key in enumerate(keys): - if key not in curr: - curr[key] = {} - prev = curr - curr = curr[key] - if i == len(keys) - 1: - prev_value = prev[key] - prev[key] = value - logger.debug(f'Set the value of {keys} to {value} (previous: ' - f'{prev_value}). Returning conf: {to_return}') - return to_return + override = {} + for i, key in enumerate(reversed(keys)): + if i == 0: + override = {key: value} + else: + override = {key: override} + return _recursive_update(copy.deepcopy(_dict), override) def to_dict() -> Dict[str, Any]: diff --git a/sky/task.py b/sky/task.py index 3dd254838f0..a1252beb362 100644 --- a/sky/task.py +++ b/sky/task.py @@ -456,8 +456,23 @@ def from_yaml_config( task.set_outputs(outputs=outputs, estimated_size_gigabytes=estimated_size_gigabytes) + # Experimental configs. + experimnetal_configs = config.pop('experimental', None) + skypilot_config_override = None + if experimnetal_configs is not None: + skypilot_config_override = experimnetal_configs.pop( + 'config_overrides', None) + assert not experimnetal_configs, ('Invalid task args: ' + f'{experimnetal_configs.keys()}') + # Parse resources field. 
- resources_config = config.pop('resources', None) + resources_config = config.pop('resources', {}) + if skypilot_config_override is not None: + assert resources_config.get('_skypilot_config_override') is None, ( + 'Cannot set _skypilot_config_override in both resources and ' + 'experimental.config_overrides') + resources_config[ + '_skypilot_config_override'] = skypilot_config_override task.set_resources(sky.Resources.from_yaml_config(resources_config)) service = config.pop('service', None) diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 2f1dd649ade..7936cfccc23 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -4,6 +4,7 @@ https://json-schema.org/ """ import enum +from typing import List, Tuple def _check_not_both_fields_present(field1: str, field2: str): @@ -370,6 +371,41 @@ def get_service_schema(): } +def _filter_configs(configs: dict, filter_keys: List[Tuple[str, ...]]) -> dict: + new_config = { + k: v for k, v in configs.items() if k not in ['properties', '$schema'] + } + + def _get_value(config: dict, keys: Tuple[str, ...]) -> dict: + if len(keys) == 1: + return config[keys[0]] + return _get_value(config[keys[0]], keys[1:]) + + for keys in filter_keys: + value = _get_value(new_config, keys) + for key in keys: + if key not in new_config: + new_config[key] = {} + if key == keys[-1]: + new_config[key] = value + return new_config + + +def _experimental_task_schema() -> dict: + from sky import resources # pylint: disable=import-outside-toplevel + return { + 'experimental': { + 'type': 'object', + 'required': [], + 'additionalProperties': False, + 'properties': { + 'configs': _filter_configs(get_config_schema(), + resources.OVERRIDEABLE_CONFIG_KEYS), + } + } + } + + def get_task_schema(): return { '$schema': 'https://json-schema.org/draft/2020-12/schema', @@ -435,6 +471,7 @@ def get_task_schema(): 'type': 'number' } }, + **_experimental_task_schema(), } } From 787109df627e7015ca275fae06cb3895cd28fc59 Mon Sep 17 00:00:00 2001 From: 
Zhanghao Wu Date: Tue, 25 Jun 2024 18:13:48 +0000 Subject: [PATCH 05/22] fix --- sky/provision/kubernetes/utils.py | 5 ++- sky/resources.py | 2 + sky/skypilot_config.py | 3 +- sky/task.py | 2 + sky/utils/schemas.py | 71 ++++++++++++++++++++++--------- 5 files changed, 61 insertions(+), 22 deletions(-) diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index 6bcccb2436a..a08e0155adb 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -1354,8 +1354,9 @@ def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]): def combine_pod_config_fields( - cluster_yaml_path: str, skypilot_override_configs: Dict[str, - Any]) -> None: + cluster_yaml_path: str, + skypilot_override_configs: Dict[str, Any], +) -> None: """Adds or updates fields in the YAML with fields from the ~/.sky/config's kubernetes.pod_spec dict. This can be used to add fields to the YAML that are not supported by diff --git a/sky/resources.py b/sky/resources.py index 9b26fa90c90..bf88742d1a0 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -1253,6 +1253,8 @@ def copy(self, **override) -> 'Resources': _is_image_managed=override.pop('_is_image_managed', self._is_image_managed), _requires_fuse=override.pop('_requires_fuse', self._requires_fuse), + _skypilot_config_override=override.pop( + '_skypilot_config_override', self._skypilot_config_override), ) assert len(override) == 0 return resources diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 9ad24292948..edf85cc9e68 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -105,7 +105,8 @@ def get_nested(keys: Iterable[str], if override_configs is not None: return _get_nested(override_configs, keys, default_value) return default_value - return _get_nested(_dict, keys, default_value) + config = _recursive_update(copy.deepcopy(_dict), override_configs or {}) + return _get_nested(config, keys, default_value) def _recursive_update(base_config: Dict[str, 
Any], diff --git a/sky/task.py b/sky/task.py index a1252beb362..fc7c022e427 100644 --- a/sky/task.py +++ b/sky/task.py @@ -462,6 +462,8 @@ def from_yaml_config( if experimnetal_configs is not None: skypilot_config_override = experimnetal_configs.pop( 'config_overrides', None) + logger.debug('Overriding skypilot config with task-level config: ' + f'{skypilot_config_override}') assert not experimnetal_configs, ('Invalid task args: ' f'{experimnetal_configs.keys()}') diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 7936cfccc23..e25c9251deb 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -146,7 +146,8 @@ def _get_single_resources_schema(): 'type': 'null', }] }, - # The following fields are for internal use only. + # The following fields are for internal use only. Should not be + # specified in the task config. '_docker_login_config': { 'type': 'object', 'required': ['username', 'password', 'server'], @@ -169,6 +170,9 @@ def _get_single_resources_schema(): '_requires_fuse': { 'type': 'boolean', }, + '_skypilot_config_override': { + 'type': 'object', + }, } } @@ -371,36 +375,65 @@ def get_service_schema(): } -def _filter_configs(configs: dict, filter_keys: List[Tuple[str, ...]]) -> dict: - new_config = { - k: v for k, v in configs.items() if k not in ['properties', '$schema'] - } +def _filter_schema(schema, keys_to_keep): + """ + Recursively filter a schema to include only certain keys. + :param schema: The original schema dictionary. + :param keys_to_keep: List of tuples with the path of keys to retain. + :return: The filtered schema. 
+ """ + if not isinstance(schema, dict): + return schema # Return as is if it's not a dictionary + + # Convert list of tuples to a dictionary for easier access + paths_dict = {} + for path in keys_to_keep: + current = paths_dict + for step in path: + if step not in current: + current[step] = {} + current = current[step] + + def keep_keys(current_schema: dict, current_path_dict: dict, + new_schema: dict): + # Base case: if we reach a leaf in the path_dict, we stop. + if (not current_path_dict or not isinstance(current_schema, dict) or + not current_schema.get('properties')): + return current_schema + + if 'properties' not in new_schema: + new_schema = { + key: current_schema[key] + for key in current_schema + if key != 'properties' + } + new_schema['properties'] = {} + + for key, sub_schema in current_schema['properties'].items(): + if key in current_path_dict: + # Recursively keep keys if further path dict exists + new_schema['properties'][key] = {} + new_schema['properties'][key] = keep_keys( + sub_schema, current_path_dict[key], + new_schema['properties'][key]) - def _get_value(config: dict, keys: Tuple[str, ...]) -> dict: - if len(keys) == 1: - return config[keys[0]] - return _get_value(config[keys[0]], keys[1:]) + return new_schema - for keys in filter_keys: - value = _get_value(new_config, keys) - for key in keys: - if key not in new_config: - new_config[key] = {} - if key == keys[-1]: - new_config[key] = value - return new_config + # Start the recursive filtering + return keep_keys(schema, paths_dict, {}) def _experimental_task_schema() -> dict: from sky import resources # pylint: disable=import-outside-toplevel + config_override_schema = _filter_schema(get_config_schema(), + resources.OVERRIDEABLE_CONFIG_KEYS) return { 'experimental': { 'type': 'object', 'required': [], 'additionalProperties': False, 'properties': { - 'configs': _filter_configs(get_config_schema(), - resources.OVERRIDEABLE_CONFIG_KEYS), + 'config_overrides': config_override_schema, } } } 
From f9ae5d0421640caf126f110709b56a1b7aa62915 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 27 Jun 2024 05:24:43 +0000 Subject: [PATCH 06/22] rename vars --- sky/backends/backend_utils.py | 2 +- sky/clouds/gcp.py | 4 ++-- sky/clouds/kubernetes.py | 2 +- sky/resources.py | 30 +++++++++++++++--------------- sky/task.py | 14 +++++++------- sky/utils/schemas.py | 21 ++++++++++----------- 6 files changed, 36 insertions(+), 37 deletions(-) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 3f22f264024..f72561c44f4 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -965,7 +965,7 @@ def write_cluster_config( if isinstance(cloud, clouds.Kubernetes): kubernetes_utils.combine_pod_config_fields( tmp_yaml_path, - skypilot_override_configs=to_provision.skypilot_config_override) + skypilot_override_configs=to_provision.cluster_config_override) kubernetes_utils.combine_metadata_fields(tmp_yaml_path) # Restore the old yaml content for backward compatibility. 
diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index afe5cea19c6..550acbb5259 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -199,7 +199,7 @@ def _unsupported_features_for_resources( if (skypilot_config.get_nested( ('gcp', 'managed_instance_group'), None, - override_configs=resources.skypilot_config_override) is not None + override_configs=resources.cluster_config_override) is not None and resources.accelerators): unsupported[clouds.CloudImplementationFeatures.STOP] = ( 'Managed Instance Group (MIG) does not support stopping yet.') @@ -510,7 +510,7 @@ def make_deploy_resources_variables( managed_instance_group_config = skypilot_config.get_nested( ('gcp', 'managed_instance_group'), None, - override_configs=resources.skypilot_config_override) + override_configs=resources.cluster_config_override) use_mig = managed_instance_group_config is not None resources_vars['gcp_use_managed_instance_group'] = use_mig # Convert boolean to 0 or 1 in string, as GCP does not support boolean diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 6210e82c4e3..5680d9c461e 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -313,7 +313,7 @@ def make_deploy_resources_variables( timeout = skypilot_config.get_nested( ['kubernetes', 'provision_timeout'], 10, - override_configs=resources.skypilot_config_override) + override_configs=resources.cluster_config_override) deploy_vars = { 'instance_type': resources.instance_type, 'custom_resources': custom_resources, diff --git a/sky/resources.py b/sky/resources.py index bf88742d1a0..047117ec29d 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -76,7 +76,7 @@ def __init__( _docker_login_config: Optional[docker_utils.DockerLoginConfig] = None, _is_image_managed: Optional[bool] = None, _requires_fuse: Optional[bool] = None, - _skypilot_config_override: Optional[Dict[str, Any]] = None, + _cluster_config_override: Optional[Dict[str, Any]] = None, ): """Initialize a Resources object. 
@@ -227,7 +227,7 @@ def __init__( self._requires_fuse = _requires_fuse - self._skypilot_config_override = _skypilot_config_override + self._cluster_config_override = _cluster_config_override self._set_cpus(cpus) self._set_memory(memory) @@ -460,10 +460,10 @@ def requires_fuse(self) -> bool: return self._requires_fuse @property - def skypilot_config_override(self) -> Dict[str, Any]: - if self._skypilot_config_override is None: + def cluster_config_override(self) -> Dict[str, Any]: + if self._cluster_config_override is None: return {} - return self._skypilot_config_override + return self._cluster_config_override @requires_fuse.setter def requires_fuse(self, value: Optional[bool]) -> None: @@ -1033,7 +1033,7 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, if (skypilot_config.get_nested( ('nvidia_gpus', 'disable_ecc'), False, - override_configs=self._skypilot_config_override) and + override_configs=self._cluster_config_override) and self.accelerators is not None): initial_setup_commands = [constants.DISABLE_GPU_ECC_COMMAND] @@ -1041,7 +1041,7 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, docker_run_options = skypilot_config.get_nested( ('docker', 'run_options'), default_value=[], - override_configs=self._skypilot_config_override) + override_configs=self._cluster_config_override) if isinstance(docker_run_options, str): docker_run_options = [docker_run_options] if docker_run_options and isinstance(self.cloud, clouds.Kubernetes): @@ -1253,8 +1253,8 @@ def copy(self, **override) -> 'Resources': _is_image_managed=override.pop('_is_image_managed', self._is_image_managed), _requires_fuse=override.pop('_requires_fuse', self._requires_fuse), - _skypilot_config_override=override.pop( - '_skypilot_config_override', self._skypilot_config_override), + _cluster_config_override=override.pop( + '_cluster_config_override', self._cluster_config_override), ) assert len(override) == 0 return resources @@ -1414,8 +1414,8 @@ def _from_yaml_config_single(cls, 
config: Dict[str, str]) -> 'Resources': resources_fields['_is_image_managed'] = config.pop( '_is_image_managed', None) resources_fields['_requires_fuse'] = config.pop('_requires_fuse', None) - resources_fields['_skypilot_config_override'] = config.pop( - '_skypilot_config_override', None) + resources_fields['_cluster_config_override'] = config.pop( + '_cluster_config_override', None) if resources_fields['cpus'] is not None: resources_fields['cpus'] = str(resources_fields['cpus']) @@ -1459,8 +1459,8 @@ def add_if_not_none(key, value): if self._docker_login_config is not None: config['_docker_login_config'] = dataclasses.asdict( self._docker_login_config) - add_if_not_none('_skypilot_config_override', - self._skypilot_config_override) + add_if_not_none('_cluster_config_override', + self._cluster_config_override) if self._is_image_managed is not None: config['_is_image_managed'] = self._is_image_managed if self._requires_fuse is not None: @@ -1577,7 +1577,7 @@ def __setstate__(self, state): self._job_recovery = state.pop('_spot_recovery', None) if version < 19: - self._skypilot_config_override = state.pop( - '_skypilot_config_override', None) + self._cluster_config_override = state.pop( + '_cluster_config_override', None) self.__dict__.update(state) diff --git a/sky/task.py b/sky/task.py index fc7c022e427..062af91fde7 100644 --- a/sky/task.py +++ b/sky/task.py @@ -458,23 +458,23 @@ def from_yaml_config( # Experimental configs. experimnetal_configs = config.pop('experimental', None) - skypilot_config_override = None + cluster_config_override = None if experimnetal_configs is not None: - skypilot_config_override = experimnetal_configs.pop( + cluster_config_override = experimnetal_configs.pop( 'config_overrides', None) logger.debug('Overriding skypilot config with task-level config: ' - f'{skypilot_config_override}') + f'{cluster_config_override}') assert not experimnetal_configs, ('Invalid task args: ' f'{experimnetal_configs.keys()}') # Parse resources field. 
resources_config = config.pop('resources', {}) - if skypilot_config_override is not None: - assert resources_config.get('_skypilot_config_override') is None, ( - 'Cannot set _skypilot_config_override in both resources and ' + if cluster_config_override is not None: + assert resources_config.get('_cluster_config_override') is None, ( + 'Cannot set _cluster_config_override in both resources and ' 'experimental.config_overrides') resources_config[ - '_skypilot_config_override'] = skypilot_config_override + '_cluster_config_override'] = cluster_config_override task.set_resources(sky.Resources.from_yaml_config(resources_config)) service = config.pop('service', None) diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index e25c9251deb..89763ae4137 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -4,7 +4,6 @@ https://json-schema.org/ """ import enum -from typing import List, Tuple def _check_not_both_fields_present(field1: str, field2: str): @@ -170,7 +169,7 @@ def _get_single_resources_schema(): '_requires_fuse': { 'type': 'boolean', }, - '_skypilot_config_override': { + '_cluster_config_override': { 'type': 'object', }, } @@ -375,16 +374,16 @@ def get_service_schema(): } -def _filter_schema(schema, keys_to_keep): +def _filter_schema(schema: dict, keys_to_keep: dict): + """Recursively filter a schema to include only certain keys. + + Args: + schema: The original schema dictionary. + keys_to_keep: List of tuples with the path of keys to retain. + + Returns: + The filtered schema. """ - Recursively filter a schema to include only certain keys. - :param schema: The original schema dictionary. - :param keys_to_keep: List of tuples with the path of keys to retain. - :return: The filtered schema. 
- """ - if not isinstance(schema, dict): - return schema # Return as is if it's not a dictionary - # Convert list of tuples to a dictionary for easier access paths_dict = {} for path in keys_to_keep: From a0a682ce0c20c3db5bc58c19e2a8211c308e3107 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 27 Jun 2024 05:25:45 +0000 Subject: [PATCH 07/22] type --- sky/utils/schemas.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 89763ae4137..057ef9db380 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -4,6 +4,7 @@ https://json-schema.org/ """ import enum +from typing import Any, Dict def _check_not_both_fields_present(field1: str, field2: str): @@ -374,7 +375,7 @@ def get_service_schema(): } -def _filter_schema(schema: dict, keys_to_keep: dict): +def _filter_schema(schema: dict, keys_to_keep: dict) -> dict: """Recursively filter a schema to include only certain keys. Args: @@ -385,7 +386,7 @@ def _filter_schema(schema: dict, keys_to_keep: dict): The filtered schema. """ # Convert list of tuples to a dictionary for easier access - paths_dict = {} + paths_dict: Dict[str, Any] = {} for path in keys_to_keep: current = paths_dict for step in path: @@ -394,7 +395,7 @@ def _filter_schema(schema: dict, keys_to_keep: dict): current = current[step] def keep_keys(current_schema: dict, current_path_dict: dict, - new_schema: dict): + new_schema: dict) -> dict: # Base case: if we reach a leaf in the path_dict, we stop. 
if (not current_path_dict or not isinstance(current_schema, dict) or not current_schema.get('properties')): From 13ba7eefa7cda85b91f147ef41be6b4906e442af Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Thu, 27 Jun 2024 05:27:18 +0000 Subject: [PATCH 08/22] format --- sky/utils/schemas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 057ef9db380..50316bca0a5 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -377,11 +377,11 @@ def get_service_schema(): def _filter_schema(schema: dict, keys_to_keep: dict) -> dict: """Recursively filter a schema to include only certain keys. - + Args: schema: The original schema dictionary. keys_to_keep: List of tuples with the path of keys to retain. - + Returns: The filtered schema. """ From 16b6dd957c46b8c2d096832a7183e8213c272d7e Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 2 Jul 2024 03:16:58 +0000 Subject: [PATCH 09/22] wip --- tests/test_config.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_config.py b/tests/test_config.py index 44154d7348d..a461c6605c2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -230,3 +230,13 @@ def test_config_with_env(monkeypatch, tmp_path) -> None: None) == PROXY_COMMAND assert skypilot_config.get_nested(('gcp', 'vpc_name'), None) == VPC_NAME assert skypilot_config.get_nested(('gcp', 'use_internal_ips'), None) + + +def test_config_with_override(monkeypatch, tmp_path) -> None: + config_path = tmp_path / 'config.yaml' + _create_config_file(config_path) + monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) + + _reload_config() + + \ No newline at end of file From a0177c7abc8886e07dd564322059e1fd0b57cb0c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 2 Jul 2024 17:03:53 +0000 Subject: [PATCH 10/22] rename and add tests --- sky/backends/backend_utils.py | 2 +- sky/clouds/gcp.py | 4 +- sky/clouds/kubernetes.py | 2 +- 
sky/provision/kubernetes/utils.py | 12 ++-- sky/resources.py | 30 ++++---- sky/task.py | 6 +- sky/utils/schemas.py | 2 +- tests/test_config.py | 109 +++++++++++++++++++++++++++++- 8 files changed, 138 insertions(+), 29 deletions(-) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index ee0249297a7..b4e9a335f11 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -980,7 +980,7 @@ def write_cluster_config( if isinstance(cloud, clouds.Kubernetes): kubernetes_utils.combine_pod_config_fields( tmp_yaml_path, - skypilot_override_configs=to_provision.cluster_config_override) + cluster_config_overrides=to_provision.cluster_config_overrides) kubernetes_utils.combine_metadata_fields(tmp_yaml_path) # Restore the old yaml content for backward compatibility. diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 550acbb5259..367300b1daf 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -199,7 +199,7 @@ def _unsupported_features_for_resources( if (skypilot_config.get_nested( ('gcp', 'managed_instance_group'), None, - override_configs=resources.cluster_config_override) is not None + override_configs=resources.cluster_config_overrides) is not None and resources.accelerators): unsupported[clouds.CloudImplementationFeatures.STOP] = ( 'Managed Instance Group (MIG) does not support stopping yet.') @@ -510,7 +510,7 @@ def make_deploy_resources_variables( managed_instance_group_config = skypilot_config.get_nested( ('gcp', 'managed_instance_group'), None, - override_configs=resources.cluster_config_override) + override_configs=resources.cluster_config_overrides) use_mig = managed_instance_group_config is not None resources_vars['gcp_use_managed_instance_group'] = use_mig # Convert boolean to 0 or 1 in string, as GCP does not support boolean diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index b2a4dbc62b8..64570de78e8 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -310,7 +310,7 @@ 
def make_deploy_resources_variables( timeout = skypilot_config.get_nested( ['kubernetes', 'provision_timeout'], 10, - override_configs=resources.cluster_config_override) + override_configs=resources.cluster_config_overrides) deploy_vars = { 'instance_type': resources.instance_type, 'custom_resources': custom_resources, diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index 63dd9b43100..ebd743f55ad 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -1389,7 +1389,7 @@ def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]): def combine_pod_config_fields( cluster_yaml_path: str, - skypilot_override_configs: Dict[str, Any], + cluster_config_overrides: Dict[str, Any], ) -> None: """Adds or updates fields in the YAML with fields from the ~/.sky/config's kubernetes.pod_spec dict. @@ -1432,9 +1432,13 @@ def combine_pod_config_fields( with open(cluster_yaml_path, 'r', encoding='utf-8') as f: yaml_content = f.read() yaml_obj = yaml.safe_load(yaml_content) - kubernetes_config = skypilot_config.get_nested( - ('kubernetes', 'pod_config'), {}, - override_configs=skypilot_override_configs) + # We don't use override_config in `skypilot_config.get_nested`, as merging + # the pod config requires special handling. + kubernetes_config = skypilot_config.get_nested(('kubernetes', 'pod_config'), + default_value={}) + override_pod_config = (cluster_config_overrides.get('kubernetes', {}).get( + 'pod_config', {})) + merge_dicts(override_pod_config, kubernetes_config) # Merge the kubernetes config into the YAML for both head and worker nodes. 
merge_dicts( diff --git a/sky/resources.py b/sky/resources.py index 047117ec29d..1f426646dfc 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -76,7 +76,7 @@ def __init__( _docker_login_config: Optional[docker_utils.DockerLoginConfig] = None, _is_image_managed: Optional[bool] = None, _requires_fuse: Optional[bool] = None, - _cluster_config_override: Optional[Dict[str, Any]] = None, + _cluster_config_overrides: Optional[Dict[str, Any]] = None, ): """Initialize a Resources object. @@ -227,7 +227,7 @@ def __init__( self._requires_fuse = _requires_fuse - self._cluster_config_override = _cluster_config_override + self._cluster_config_overrides = _cluster_config_overrides self._set_cpus(cpus) self._set_memory(memory) @@ -460,10 +460,10 @@ def requires_fuse(self) -> bool: return self._requires_fuse @property - def cluster_config_override(self) -> Dict[str, Any]: - if self._cluster_config_override is None: + def cluster_config_overrides(self) -> Dict[str, Any]: + if self._cluster_config_overrides is None: return {} - return self._cluster_config_override + return self._cluster_config_overrides @requires_fuse.setter def requires_fuse(self, value: Optional[bool]) -> None: @@ -1033,7 +1033,7 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, if (skypilot_config.get_nested( ('nvidia_gpus', 'disable_ecc'), False, - override_configs=self._cluster_config_override) and + override_configs=self._cluster_config_overrides) and self.accelerators is not None): initial_setup_commands = [constants.DISABLE_GPU_ECC_COMMAND] @@ -1041,7 +1041,7 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, docker_run_options = skypilot_config.get_nested( ('docker', 'run_options'), default_value=[], - override_configs=self._cluster_config_override) + override_configs=self._cluster_config_overrides) if isinstance(docker_run_options, str): docker_run_options = [docker_run_options] if docker_run_options and isinstance(self.cloud, clouds.Kubernetes): @@ -1253,8 +1253,8 @@ def 
copy(self, **override) -> 'Resources': _is_image_managed=override.pop('_is_image_managed', self._is_image_managed), _requires_fuse=override.pop('_requires_fuse', self._requires_fuse), - _cluster_config_override=override.pop( - '_cluster_config_override', self._cluster_config_override), + _cluster_config_overrides=override.pop( + '_cluster_config_overrides', self._cluster_config_overrides), ) assert len(override) == 0 return resources @@ -1414,8 +1414,8 @@ def _from_yaml_config_single(cls, config: Dict[str, str]) -> 'Resources': resources_fields['_is_image_managed'] = config.pop( '_is_image_managed', None) resources_fields['_requires_fuse'] = config.pop('_requires_fuse', None) - resources_fields['_cluster_config_override'] = config.pop( - '_cluster_config_override', None) + resources_fields['_cluster_config_overrides'] = config.pop( + '_cluster_config_overrides', None) if resources_fields['cpus'] is not None: resources_fields['cpus'] = str(resources_fields['cpus']) @@ -1459,8 +1459,8 @@ def add_if_not_none(key, value): if self._docker_login_config is not None: config['_docker_login_config'] = dataclasses.asdict( self._docker_login_config) - add_if_not_none('_cluster_config_override', - self._cluster_config_override) + add_if_not_none('_cluster_config_overrides', + self._cluster_config_overrides) if self._is_image_managed is not None: config['_is_image_managed'] = self._is_image_managed if self._requires_fuse is not None: @@ -1577,7 +1577,7 @@ def __setstate__(self, state): self._job_recovery = state.pop('_spot_recovery', None) if version < 19: - self._cluster_config_override = state.pop( - '_cluster_config_override', None) + self._cluster_config_overrides = state.pop( + '_cluster_config_overrides', None) self.__dict__.update(state) diff --git a/sky/task.py b/sky/task.py index 062af91fde7..b11f1428cd3 100644 --- a/sky/task.py +++ b/sky/task.py @@ -470,11 +470,11 @@ def from_yaml_config( # Parse resources field. 
resources_config = config.pop('resources', {}) if cluster_config_override is not None: - assert resources_config.get('_cluster_config_override') is None, ( - 'Cannot set _cluster_config_override in both resources and ' + assert resources_config.get('_cluster_config_overrides') is None, ( + 'Cannot set _cluster_config_overrides in both resources and ' 'experimental.config_overrides') resources_config[ - '_cluster_config_override'] = cluster_config_override + '_cluster_config_overrides'] = cluster_config_override task.set_resources(sky.Resources.from_yaml_config(resources_config)) service = config.pop('service', None) diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 50316bca0a5..2139571796a 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -170,7 +170,7 @@ def _get_single_resources_schema(): '_requires_fuse': { 'type': 'boolean', }, - '_cluster_config_override': { + '_cluster_config_overrides': { 'type': 'object', }, } diff --git a/tests/test_config.py b/tests/test_config.py index a461c6605c2..6791cd0d11e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,7 +4,9 @@ import pytest +import sky from sky import skypilot_config +from sky.skylet import constants from sky.utils import common_utils from sky.utils import kubernetes_enums @@ -12,6 +14,9 @@ PROXY_COMMAND = 'ssh -W %h:%p -i ~/.ssh/id_rsa -o StrictHostKeyChecking=no' NODEPORT_MODE_NAME = kubernetes_enums.KubernetesNetworkingMode.NODEPORT.value PORT_FORWARD_MODE_NAME = kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD.value +RUN_DURATION = 30 +OVERRIDE_RUN_DURATION = 10 +PROVISION_TIMEOUT = 600 def _reload_config() -> None: @@ -41,9 +46,22 @@ def _create_config_file(config_file_path: pathlib.Path) -> None: gcp: vpc_name: {VPC_NAME} use_internal_ips: true + managed_instance_group: + run_duration: {RUN_DURATION} + provision_timeout: {PROVISION_TIMEOUT} kubernetes: networking: {NODEPORT_MODE_NAME} + pod_config: + spec: + metadata: + annotations: + my_annotation: my_value 
+ runtimeClassName: nvidia # Custom runtimeClassName for GPU pods. + imagePullSecrets: + - name: my-secret # Pull images from a private registry using a secret + + allowed_clouds: ['aws', 'gcp', 'kubernetes'] """)) @@ -232,11 +250,98 @@ def test_config_with_env(monkeypatch, tmp_path) -> None: assert skypilot_config.get_nested(('gcp', 'use_internal_ips'), None) -def test_config_with_override(monkeypatch, tmp_path) -> None: +def test_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> None: config_path = tmp_path / 'config.yaml' _create_config_file(config_path) monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) _reload_config() - \ No newline at end of file + task_config_yaml = textwrap.dedent(f"""\ + experimental: + config_overrides: + docker: + run_options: + - -v /tmp:/tmp + kubernetes: + pod_config: + metadata: + labels: + test-key: test-value + annotations: + abc: def + spec: + imagePullSecrets: + - name: my-secret-2 + gcp: + managed_instance_group: + run_duration: {OVERRIDE_RUN_DURATION} + nvidia_gpus: + disable_ecc: true + resources: + image_id: docker:ubuntu:latest + + setup: echo 'Setting up...' + run: echo 'Running...' 
+ """) + task_path = tmp_path / 'task.yaml' + task_path.write_text(task_config_yaml) + task = sky.Task.from_yaml(task_path) + + # Test Kubernetes overrides + # Get cluster YAML + cluster_name = 'test-kubernetes-config-with-override' + task.set_resources_override({'cloud': sky.Kubernetes()}) + sky.launch(task, cluster_name=cluster_name, dryrun=True) + cluster_yaml = pathlib.Path(f'~/.sky/generated/{cluster_name}.yml').expanduser().rename(tmp_path / cluster_name + '.yml') + + # Load the cluster YAML + cluster_config = common_utils.read_yaml(cluster_yaml) + head_node_type = cluster_config['head_node_type'] + cluster_pod_config = cluster_config['available_node_types'][head_node_type]['node_config'] + assert cluster_pod_config['metadata']['labels']['test-key'] == 'test-value' + assert cluster_pod_config['metadata']['labels']['parent'] == 'skypilot' + assert cluster_pod_config['metadata']['annotations']['abc'] == 'def' + assert len(cluster_pod_config['spec']['imagePullSecrets']) == 1 and cluster_pod_config['spec']['imagePullSecrets'][0]['name'] == 'my-secret-2' + assert cluster_pod_config['spec']['runtimeClassName'] == 'nvidia' + + + # Test GCP overrides + cluster_name = 'test-gcp-config-with-override' + task.set_resources_override({'cloud': sky.GCP()}) + sky.launch(task, cluster_name=cluster_name, dryrun=True) + cluster_yaml = pathlib.Path(f'~/.sky/generated/{cluster_name}.yml').expanduser().rename(tmp_path / cluster_name + '.yml') + + # Load the cluster YAML + cluster_config = common_utils.read_yaml(cluster_yaml) + assert cluster_config['provider']['vpc_name'] == VPC_NAME + assert '-v /tmp:/tmp' in cluster_config['docker']['run_options'] + assert constants.DISABLE_GPU_ECC_COMMAND in cluster_config['setup_commands'] + head_node_type = cluster_config['head_node_type'] + cluster_node_config = cluster_config['available_node_types'][head_node_type]['node_config'] + assert cluster_node_config['managed-instance-group']['run_duration'] == RUN_DURATION + assert 
cluster_node_config['managed-instance-group']['provision-timeout'] == PROVISION_TIMEOUT + +def test_config_with_invalid_override(monkeypatch, tmp_path, enable_all_clouds) -> None: + config_path = tmp_path / 'config.yaml' + _create_config_file(config_path) + monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) + + _reload_config() + + task_config_yaml = textwrap.dedent(f"""\ + experimental: + config_overrides: + gcp: + vpc_name: abc + resources: + image_id: docker:ubuntu:latest + + setup: echo 'Setting up...' + run: echo 'Running...' + """) + + with pytest.raises(ValueError, match='Found unsupported') as e: + task_path = tmp_path / 'task.yaml' + task_path.write_text(task_config_yaml) + sky.Task.from_yaml(task_path) From 96abfd4fdb41ad09a241b0ac63cd0e9cf3ad00c7 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 2 Jul 2024 18:07:31 +0000 Subject: [PATCH 11/22] Fixes and add tests --- sky/backends/backend_utils.py | 25 ++------- sky/provision/kubernetes/utils.py | 7 ++- tests/conftest.py | 2 +- tests/test_config.py | 91 +++++++++++++++++-------------- 4 files changed, 62 insertions(+), 63 deletions(-) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index b4e9a335f11..b80cf667413 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -874,17 +874,6 @@ def write_cluster_config( f'open(os.path.expanduser("{constants.SKY_REMOTE_RAY_PORT_FILE}"), "w", encoding="utf-8"))\'' ) - # Docker run options - docker_run_options = skypilot_config.get_nested(('docker', 'run_options'), - []) - if isinstance(docker_run_options, str): - docker_run_options = [docker_run_options] - if docker_run_options and isinstance(to_provision.cloud, clouds.Kubernetes): - logger.warning(f'{colorama.Style.DIM}Docker run options are specified, ' - 'but ignored for Kubernetes: ' - f'{" ".join(docker_run_options)}' - f'{colorama.Style.RESET_ALL}') - # Use a tmp file path to avoid incomplete YAML file being re-used in the # future. 
tmp_yaml_path = yaml_path + '.tmp' @@ -929,9 +918,6 @@ def write_cluster_config( wheel_hash).replace('{cloud}', str(cloud).lower())), - # Docker - 'docker_run_options': docker_run_options, - # Port of Ray (GCS server). # Ray's default port 6379 is conflicted with Redis. 'ray_port': constants.SKY_REMOTE_RAY_PORT, @@ -970,11 +956,6 @@ def write_cluster_config( output_path=tmp_yaml_path) config_dict['cluster_name'] = cluster_name config_dict['ray'] = yaml_path - if dryrun: - # If dryrun, return the unfinished tmp yaml path. - config_dict['ray'] = tmp_yaml_path - return config_dict - _add_auth_to_cluster_config(cloud, tmp_yaml_path) # Add kubernetes config fields from ~/.sky/config if isinstance(cloud, clouds.Kubernetes): @@ -983,6 +964,12 @@ def write_cluster_config( cluster_config_overrides=to_provision.cluster_config_overrides) kubernetes_utils.combine_metadata_fields(tmp_yaml_path) + if dryrun: + # If dryrun, return the unfinished tmp yaml path. + config_dict['ray'] = tmp_yaml_path + return config_dict + _add_auth_to_cluster_config(cloud, tmp_yaml_path) + # Restore the old yaml content for backward compatibility. if os.path.exists(yaml_path) and keep_launch_fields_in_existing_config: with open(yaml_path, 'r', encoding='utf-8') as f: diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index ebd743f55ad..e3d4132e3ab 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -1362,9 +1362,10 @@ def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]): elif isinstance(value, list) and key in destination: assert isinstance(destination[key], list), \ f'Expected {key} to be a list, found {destination[key]}' - if key == 'containers': - # If the key is 'containers', we take the first and only - # container in the list and merge it. 
+ if key in ['containers', 'imagePullSecrets']: + # If the key is 'containers' or 'imagePullSecrets, we take the + # first and only container/secret in the list and merge it, as + # we only support one container per pod. assert len(value) == 1, \ f'Expected only one container, found {value}' merge_dicts(value[0], destination[key][0]) diff --git a/tests/conftest.py b/tests/conftest.py index ce92afd88c7..b4e025a8f2d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,7 +199,7 @@ def generic_cloud(request) -> str: @pytest.fixture -def enable_all_clouds(monkeypatch: pytest.MonkeyPatch): +def enable_all_clouds(monkeypatch: pytest.MonkeyPatch) -> None: common.enable_all_clouds_in_monkeypatch(monkeypatch) diff --git a/tests/test_config.py b/tests/test_config.py index 6791cd0d11e..595ad420a4a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -15,7 +15,7 @@ NODEPORT_MODE_NAME = kubernetes_enums.KubernetesNetworkingMode.NODEPORT.value PORT_FORWARD_MODE_NAME = kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD.value RUN_DURATION = 30 -OVERRIDE_RUN_DURATION = 10 +RUN_DURATION_OVERRIDE = 10 PROVISION_TIMEOUT = 600 @@ -36,7 +36,7 @@ def _check_empty_config() -> None: def _create_config_file(config_file_path: pathlib.Path) -> None: - config_file_path.open('w', encoding='utf-8').write( + config_file_path.write_text( textwrap.dedent(f"""\ aws: vpc_name: {VPC_NAME} @@ -61,9 +61,38 @@ def _create_config_file(config_file_path: pathlib.Path) -> None: imagePullSecrets: - name: my-secret # Pull images from a private registry using a secret - allowed_clouds: ['aws', 'gcp', 'kubernetes'] """)) +def _create_task_yaml_file(task_file_path: pathlib.Path) -> None: + task_file_path.write_text(textwrap.dedent(f"""\ + experimental: + config_overrides: + docker: + run_options: + - -v /tmp:/tmp + kubernetes: + pod_config: + metadata: + labels: + test-key: test-value + annotations: + abc: def + spec: + imagePullSecrets: + - name: my-secret-2 + gcp: + 
managed_instance_group: + run_duration: {RUN_DURATION_OVERRIDE} + nvidia_gpus: + disable_ecc: true + resources: + image_id: docker:ubuntu:latest + + setup: echo 'Setting up...' + run: echo 'Running...' + """)) + + def test_no_config(monkeypatch) -> None: """Test that the config is not loaded if the config file does not exist.""" @@ -250,42 +279,14 @@ def test_config_with_env(monkeypatch, tmp_path) -> None: assert skypilot_config.get_nested(('gcp', 'use_internal_ips'), None) -def test_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> None: +def test_k8s_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> None: config_path = tmp_path / 'config.yaml' _create_config_file(config_path) monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) _reload_config() - - task_config_yaml = textwrap.dedent(f"""\ - experimental: - config_overrides: - docker: - run_options: - - -v /tmp:/tmp - kubernetes: - pod_config: - metadata: - labels: - test-key: test-value - annotations: - abc: def - spec: - imagePullSecrets: - - name: my-secret-2 - gcp: - managed_instance_group: - run_duration: {OVERRIDE_RUN_DURATION} - nvidia_gpus: - disable_ecc: true - resources: - image_id: docker:ubuntu:latest - - setup: echo 'Setting up...' - run: echo 'Running...' 
- """) task_path = tmp_path / 'task.yaml' - task_path.write_text(task_config_yaml) + _create_task_yaml_file(task_path) task = sky.Task.from_yaml(task_path) # Test Kubernetes overrides @@ -293,7 +294,7 @@ def test_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> None: cluster_name = 'test-kubernetes-config-with-override' task.set_resources_override({'cloud': sky.Kubernetes()}) sky.launch(task, cluster_name=cluster_name, dryrun=True) - cluster_yaml = pathlib.Path(f'~/.sky/generated/{cluster_name}.yml').expanduser().rename(tmp_path / cluster_name + '.yml') + cluster_yaml = pathlib.Path(f'~/.sky/generated/{cluster_name}.yml.tmp').expanduser().rename(tmp_path / (cluster_name + '.yml')) # Load the cluster YAML cluster_config = common_utils.read_yaml(cluster_yaml) @@ -306,21 +307,31 @@ def test_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> None: assert cluster_pod_config['spec']['runtimeClassName'] == 'nvidia' +def test_gcp_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> None: + config_path = tmp_path / 'config.yaml' + _create_config_file(config_path) + monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) + + _reload_config() + task_path = tmp_path / 'task.yaml' + _create_task_yaml_file(task_path) + task = sky.Task.from_yaml(task_path) + # Test GCP overrides cluster_name = 'test-gcp-config-with-override' - task.set_resources_override({'cloud': sky.GCP()}) + task.set_resources_override({'cloud': sky.GCP(), 'accelerators': 'L4'}) sky.launch(task, cluster_name=cluster_name, dryrun=True) - cluster_yaml = pathlib.Path(f'~/.sky/generated/{cluster_name}.yml').expanduser().rename(tmp_path / cluster_name + '.yml') + cluster_yaml = pathlib.Path(f'~/.sky/generated/{cluster_name}.yml.tmp').expanduser().rename(tmp_path / (cluster_name + '.yml')) # Load the cluster YAML cluster_config = common_utils.read_yaml(cluster_yaml) assert cluster_config['provider']['vpc_name'] == VPC_NAME - assert '-v /tmp:/tmp' in 
cluster_config['docker']['run_options'] - assert constants.DISABLE_GPU_ECC_COMMAND in cluster_config['setup_commands'] + assert '-v /tmp:/tmp' in cluster_config['docker']['run_options'], cluster_config + assert constants.DISABLE_GPU_ECC_COMMAND in cluster_config['setup_commands'][0] head_node_type = cluster_config['head_node_type'] cluster_node_config = cluster_config['available_node_types'][head_node_type]['node_config'] - assert cluster_node_config['managed-instance-group']['run_duration'] == RUN_DURATION - assert cluster_node_config['managed-instance-group']['provision-timeout'] == PROVISION_TIMEOUT + assert cluster_node_config['managed-instance-group']['run_duration'] == RUN_DURATION_OVERRIDE + assert cluster_node_config['managed-instance-group']['provision_timeout'] == PROVISION_TIMEOUT def test_config_with_invalid_override(monkeypatch, tmp_path, enable_all_clouds) -> None: config_path = tmp_path / 'config.yaml' From 70f02814de91e5a307559c3fc49d41a87d39e428 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Tue, 2 Jul 2024 18:10:57 +0000 Subject: [PATCH 12/22] format --- tests/test_config.py | 53 +++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 595ad420a4a..e47d31b045b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -63,8 +63,10 @@ def _create_config_file(config_file_path: pathlib.Path) -> None: """)) + def _create_task_yaml_file(task_file_path: pathlib.Path) -> None: - task_file_path.write_text(textwrap.dedent(f"""\ + task_file_path.write_text( + textwrap.dedent(f"""\ experimental: config_overrides: docker: @@ -91,7 +93,6 @@ def _create_task_yaml_file(task_file_path: pathlib.Path) -> None: setup: echo 'Setting up...' run: echo 'Running...' 
""")) - def test_no_config(monkeypatch) -> None: @@ -279,7 +280,8 @@ def test_config_with_env(monkeypatch, tmp_path) -> None: assert skypilot_config.get_nested(('gcp', 'use_internal_ips'), None) -def test_k8s_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> None: +def test_k8s_config_with_override(monkeypatch, tmp_path, + enable_all_clouds) -> None: config_path = tmp_path / 'config.yaml' _create_config_file(config_path) monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) @@ -294,20 +296,26 @@ def test_k8s_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> N cluster_name = 'test-kubernetes-config-with-override' task.set_resources_override({'cloud': sky.Kubernetes()}) sky.launch(task, cluster_name=cluster_name, dryrun=True) - cluster_yaml = pathlib.Path(f'~/.sky/generated/{cluster_name}.yml.tmp').expanduser().rename(tmp_path / (cluster_name + '.yml')) - + cluster_yaml = pathlib.Path( + f'~/.sky/generated/{cluster_name}.yml.tmp').expanduser().rename( + tmp_path / (cluster_name + '.yml')) + # Load the cluster YAML cluster_config = common_utils.read_yaml(cluster_yaml) head_node_type = cluster_config['head_node_type'] - cluster_pod_config = cluster_config['available_node_types'][head_node_type]['node_config'] - assert cluster_pod_config['metadata']['labels']['test-key'] == 'test-value' - assert cluster_pod_config['metadata']['labels']['parent'] == 'skypilot' + cluster_pod_config = cluster_config['available_node_types'][head_node_type][ + 'node_config'] + assert cluster_pod_config['metadata']['labels']['test-key'] == 'test-value' + assert cluster_pod_config['metadata']['labels']['parent'] == 'skypilot' assert cluster_pod_config['metadata']['annotations']['abc'] == 'def' - assert len(cluster_pod_config['spec']['imagePullSecrets']) == 1 and cluster_pod_config['spec']['imagePullSecrets'][0]['name'] == 'my-secret-2' + assert len(cluster_pod_config['spec'] + ['imagePullSecrets']) == 1 and cluster_pod_config['spec'][ + 
'imagePullSecrets'][0]['name'] == 'my-secret-2' assert cluster_pod_config['spec']['runtimeClassName'] == 'nvidia' -def test_gcp_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> None: +def test_gcp_config_with_override(monkeypatch, tmp_path, + enable_all_clouds) -> None: config_path = tmp_path / 'config.yaml' _create_config_file(config_path) monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) @@ -316,24 +324,33 @@ def test_gcp_config_with_override(monkeypatch, tmp_path, enable_all_clouds) -> N task_path = tmp_path / 'task.yaml' _create_task_yaml_file(task_path) task = sky.Task.from_yaml(task_path) - + # Test GCP overrides cluster_name = 'test-gcp-config-with-override' task.set_resources_override({'cloud': sky.GCP(), 'accelerators': 'L4'}) sky.launch(task, cluster_name=cluster_name, dryrun=True) - cluster_yaml = pathlib.Path(f'~/.sky/generated/{cluster_name}.yml.tmp').expanduser().rename(tmp_path / (cluster_name + '.yml')) + cluster_yaml = pathlib.Path( + f'~/.sky/generated/{cluster_name}.yml.tmp').expanduser().rename( + tmp_path / (cluster_name + '.yml')) # Load the cluster YAML cluster_config = common_utils.read_yaml(cluster_yaml) assert cluster_config['provider']['vpc_name'] == VPC_NAME - assert '-v /tmp:/tmp' in cluster_config['docker']['run_options'], cluster_config - assert constants.DISABLE_GPU_ECC_COMMAND in cluster_config['setup_commands'][0] + assert '-v /tmp:/tmp' in cluster_config['docker'][ + 'run_options'], cluster_config + assert constants.DISABLE_GPU_ECC_COMMAND in cluster_config[ + 'setup_commands'][0] head_node_type = cluster_config['head_node_type'] - cluster_node_config = cluster_config['available_node_types'][head_node_type]['node_config'] - assert cluster_node_config['managed-instance-group']['run_duration'] == RUN_DURATION_OVERRIDE - assert cluster_node_config['managed-instance-group']['provision_timeout'] == PROVISION_TIMEOUT + cluster_node_config = cluster_config['available_node_types'][ + 
head_node_type]['node_config'] + assert cluster_node_config['managed-instance-group'][ + 'run_duration'] == RUN_DURATION_OVERRIDE + assert cluster_node_config['managed-instance-group'][ + 'provision_timeout'] == PROVISION_TIMEOUT + -def test_config_with_invalid_override(monkeypatch, tmp_path, enable_all_clouds) -> None: +def test_config_with_invalid_override(monkeypatch, tmp_path, + enable_all_clouds) -> None: config_path = tmp_path / 'config.yaml' _create_config_file(config_path) monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) From d0503283c75131d9cb10cdc503d2ad26cc9c6caf Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jul 2024 01:43:30 +0000 Subject: [PATCH 13/22] Assert for override configs specification --- sky/resources.py | 12 ++--------- sky/skylet/constants.py | 12 +++++++++++ sky/skypilot_config.py | 45 ++++++++++++++++++++++++++++++++--------- sky/utils/schemas.py | 9 +++++---- 4 files changed, 54 insertions(+), 24 deletions(-) diff --git a/sky/resources.py b/sky/resources.py index 1f426646dfc..38f7a9784e6 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -26,14 +26,6 @@ _DEFAULT_DISK_SIZE_GB = 256 -OVERRIDEABLE_CONFIG_KEYS = [ - ('docker',), - ('nvidia_gpus',), - ('kubernetes', 'pod_config'), - ('kubernetes', 'provision_timeout'), - ('gcp', 'managed_instance_group'), -] - class Resources: """Resources: compute requirements of Tasks. 
@@ -1033,7 +1025,7 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, if (skypilot_config.get_nested( ('nvidia_gpus', 'disable_ecc'), False, - override_configs=self._cluster_config_overrides) and + override_configs=self.cluster_config_overrides) and self.accelerators is not None): initial_setup_commands = [constants.DISABLE_GPU_ECC_COMMAND] @@ -1041,7 +1033,7 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, docker_run_options = skypilot_config.get_nested( ('docker', 'run_options'), default_value=[], - override_configs=self._cluster_config_overrides) + override_configs=self.cluster_config_overrides) if isinstance(docker_run_options, str): docker_run_options = [docker_run_options] if docker_run_options and isinstance(self.cloud, clouds.Kubernetes): diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index c456b48b306..359914b51f9 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -1,4 +1,6 @@ """Constants for SkyPilot.""" +from typing import List, Tuple + from packaging import version import sky @@ -261,3 +263,13 @@ # Placeholder for the SSH user in proxy command, replaced when the ssh_user is # known after provisioning. SKY_SSH_USER_PLACEHOLDER = 'skypilot:ssh_user' + +# The keys that can be overridden in the `~/.sky/config.yaml` file. The +# overrides are specified in task YAMLs. 
+OVERRIDEABLE_CONFIG_KEYS: List[Tuple[str, ...]] = [ + ('docker', 'run_options'), + ('nvidia_gpus', 'disable_ecc'), + ('kubernetes', 'pod_config'), + ('kubernetes', 'provision_timeout'), + ('gcp', 'managed_instance_group'), +] diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index edf85cc9e68..53ddc0edc21 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -44,11 +44,12 @@ import copy import os import pprint -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, Optional, Tuple import yaml from sky import sky_logging +from sky.skylet import constants from sky.utils import common_utils from sky.utils import schemas from sky.utils import ux_utils @@ -73,7 +74,7 @@ logger = sky_logging.init_logger(__name__) # The loaded config. -_dict = None +_dict: Optional[Dict[str, Any]] = None _loaded_config_path = None @@ -98,14 +99,38 @@ def get_nested(keys: Iterable[str], If any key is not found, or any intermediate key does not point to a dict value, returns 'default_value'. + + When 'keys' is within OVERRIDEABLE_CONFIG_KEYS, 'override_configs' must be + provided (can be empty). Otherwise, 'override_configs' must not be provided. + + Args: + keys: A tuple of strings representing the nested keys. + default_value: The default value to return if the key is not found. + override_configs: A dict of override configs with the same schema as + the config file, but only containing the keys to override. + + Returns: + The value of the nested key, or 'default_value' if not found. """ - # TODO (zhwu): Verify that the override_configs is provided when keys is - # within resources.OVERRIDEABLE_CONFIG_KEYS. 
- if _dict is None: - if override_configs is not None: - return _get_nested(override_configs, keys, default_value) - return default_value - config = _recursive_update(copy.deepcopy(_dict), override_configs or {}) + assert ( + keys in constants.OVERRIDEABLE_CONFIG_KEYS or + override_configs is not None + ), (f'Override configs must be provided when keys {keys} is within ' + 'constants.OVERRIDEABLE_CONFIG_KEYS: ' + f'{constants.OVERRIDEABLE_CONFIG_KEYS}' + ) + assert ( + keys in constants.OVERRIDEABLE_CONFIG_KEYS or override_configs is None + ), (f'Override configs must not be provided when keys {keys} is not within ' + 'constants.OVERRIDEABLE_CONFIG_KEYS: ' + f'{constants.OVERRIDEABLE_CONFIG_KEYS}' + ) + config: Dict[str, Any] = {} + if _dict is not None: + config = copy.deepcopy(_dict) + if override_configs is None: + override_configs = {} + config = _recursive_update(config, override_configs) return _get_nested(config, keys, default_value) @@ -121,7 +146,7 @@ def _recursive_update(base_config: Dict[str, Any], return base_config -def set_nested(keys: Iterable[str], value: Any) -> Dict[str, Any]: +def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]: """Returns a deep-copied config with the nested key set to value. Like get_nested(), if any key is not found, this will not raise an error. diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 2139571796a..1193c71ac2a 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -4,7 +4,9 @@ https://json-schema.org/ """ import enum -from typing import Any, Dict +from typing import Any, Dict, List, Tuple + +from sky.skylet import constants def _check_not_both_fields_present(field1: str, field2: str): @@ -375,7 +377,7 @@ def get_service_schema(): } -def _filter_schema(schema: dict, keys_to_keep: dict) -> dict: +def _filter_schema(schema: dict, keys_to_keep: List[Tuple[str, ...]]) -> dict: """Recursively filter a schema to include only certain keys. 
Args: @@ -424,9 +426,8 @@ def keep_keys(current_schema: dict, current_path_dict: dict, def _experimental_task_schema() -> dict: - from sky import resources # pylint: disable=import-outside-toplevel config_override_schema = _filter_schema(get_config_schema(), - resources.OVERRIDEABLE_CONFIG_KEYS) + constants.OVERRIDEABLE_CONFIG_KEYS) return { 'experimental': { 'type': 'object', From dd5565fa5b77b68d04f670aa08fc2d902e367b17 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jul 2024 01:45:27 +0000 Subject: [PATCH 14/22] format --- sky/skypilot_config.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 53ddc0edc21..8525635fad1 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -112,19 +112,16 @@ def get_nested(keys: Iterable[str], Returns: The value of the nested key, or 'default_value' if not found. """ - assert ( - keys in constants.OVERRIDEABLE_CONFIG_KEYS or - override_configs is not None - ), (f'Override configs must be provided when keys {keys} is within ' - 'constants.OVERRIDEABLE_CONFIG_KEYS: ' - f'{constants.OVERRIDEABLE_CONFIG_KEYS}' - ) + assert (keys in constants.OVERRIDEABLE_CONFIG_KEYS or + override_configs is not None), ( + f'Override configs must be provided when keys {keys} is within ' + 'constants.OVERRIDEABLE_CONFIG_KEYS: ' + f'{constants.OVERRIDEABLE_CONFIG_KEYS}') assert ( keys in constants.OVERRIDEABLE_CONFIG_KEYS or override_configs is None ), (f'Override configs must not be provided when keys {keys} is not within ' 'constants.OVERRIDEABLE_CONFIG_KEYS: ' - f'{constants.OVERRIDEABLE_CONFIG_KEYS}' - ) + f'{constants.OVERRIDEABLE_CONFIG_KEYS}') config: Dict[str, Any] = {} if _dict is not None: config = copy.deepcopy(_dict) From 9c7b27b2d392738d8b25f7a2e3871a009d7a06c0 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jul 2024 02:01:52 +0000 Subject: [PATCH 15/22] Add comments --- sky/skypilot_config.py | 11 +++++++++-- 1 file changed, 
9 insertions(+), 2 deletions(-) diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 8525635fad1..48b98bbf3b3 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -1,7 +1,7 @@ """Immutable user configurations (EXPERIMENTAL). -On module import, we attempt to parse the config located at CONFIG_PATH. Caller -can then use +On module import, we attempt to parse the config located at CONFIG_PATH +(default: ~/.sky/config.yaml). Caller can then use >> skypilot_config.loaded() @@ -11,6 +11,13 @@ >> skypilot_config.get_nested(('auth', 'some_auth_config'), default_value) +The config can be overridden by the configs in task YAMLs. Callers are +responsible for providing the override_configs. If the nested key is part of +OVERRIDEABLE_CONFIG_KEYS, override_configs must be provided (can be empty): + + >> skypilot_config.get_nested(('docker', 'run_options'), default_value, + override_configs={'docker': {'run_options': 'value'}}) + To set a value in the nested-key config: >> config_dict = skypilot_config.set_nested(('auth', 'some_key'), value) From 826bab3027f74ca24020747d97c344828d19cf68 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jul 2024 02:18:37 +0000 Subject: [PATCH 16/22] fix --- sky/skypilot_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 48b98bbf3b3..d051590cafb 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -119,7 +119,7 @@ def get_nested(keys: Iterable[str], Returns: The value of the nested key, or 'default_value' if not found.
""" - assert (keys in constants.OVERRIDEABLE_CONFIG_KEYS or + assert (keys not in constants.OVERRIDEABLE_CONFIG_KEYS or override_configs is not None), ( f'Override configs must be provided when keys {keys} is within ' 'constants.OVERRIDEABLE_CONFIG_KEYS: ' From d858f3430828bdd83cf6123d4a54fde0935f1e8e Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jul 2024 02:21:02 +0000 Subject: [PATCH 17/22] fix assertions --- sky/skypilot_config.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index d051590cafb..b7d53375f23 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -119,13 +119,15 @@ def get_nested(keys: Iterable[str], Returns: The value of the nested key, or 'default_value' if not found. """ - assert (keys not in constants.OVERRIDEABLE_CONFIG_KEYS or - override_configs is not None), ( - f'Override configs must be provided when keys {keys} is within ' - 'constants.OVERRIDEABLE_CONFIG_KEYS: ' - f'{constants.OVERRIDEABLE_CONFIG_KEYS}') - assert ( - keys in constants.OVERRIDEABLE_CONFIG_KEYS or override_configs is None + assert not ( + keys in constants.OVERRIDEABLE_CONFIG_KEYS and + override_configs is None), ( + f'Override configs must be provided when keys {keys} is within ' + 'constants.OVERRIDEABLE_CONFIG_KEYS: ' + f'{constants.OVERRIDEABLE_CONFIG_KEYS}') + assert not ( + keys not in constants.OVERRIDEABLE_CONFIG_KEYS and + override_configs is not None ), (f'Override configs must not be provided when keys {keys} is not within ' 'constants.OVERRIDEABLE_CONFIG_KEYS: ' f'{constants.OVERRIDEABLE_CONFIG_KEYS}') From eee288187136c9df124d15a8f9663c1a59d3b854 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jul 2024 02:25:51 +0000 Subject: [PATCH 18/22] fix assertions --- sky/check.py | 4 ++-- sky/clouds/kubernetes.py | 2 +- sky/provision/kubernetes/utils.py | 2 +- sky/skypilot_config.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git 
a/sky/check.py b/sky/check.py index e8a61317d63..c361c962c94 100644 --- a/sky/check.py +++ b/sky/check.py @@ -77,8 +77,8 @@ def get_all_clouds(): # Use allowed_clouds from config if it exists, otherwise check all clouds. # Also validate names with get_cloud_tuple. config_allowed_cloud_names = [ - get_cloud_tuple(c)[0] for c in skypilot_config.get_nested( - ['allowed_clouds'], get_all_clouds()) + get_cloud_tuple(c)[0] for c in skypilot_config.get_nested(( + 'allowed_clouds',), get_all_clouds()) ] # Use disallowed_cloud_names for logging the clouds that will be disabled # because they are not included in allowed_clouds in config.yaml. diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 64570de78e8..78471e0de9f 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -308,7 +308,7 @@ def make_deploy_resources_variables( # itself, which can be upto 2-3 seconds. # For non-autoscaling clusters, we conservatively set this to 10s. timeout = skypilot_config.get_nested( - ['kubernetes', 'provision_timeout'], + ('kubernetes', 'provision_timeout'), 10, override_configs=resources.cluster_config_overrides) deploy_vars = { diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index e3d4132e3ab..17d2d4039e1 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -1563,7 +1563,7 @@ def get_head_pod_name(cluster_name_on_cloud: str): def get_autoscaler_type( ) -> Optional[kubernetes_enums.KubernetesAutoscalerType]: """Returns the autoscaler type by reading from config""" - autoscaler_type = skypilot_config.get_nested(['kubernetes', 'autoscaler'], + autoscaler_type = skypilot_config.get_nested(('kubernetes', 'autoscaler'), None) if autoscaler_type is not None: autoscaler_type = kubernetes_enums.KubernetesAutoscalerType( diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index b7d53375f23..52e1d0ae3d9 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -99,7 +99,7 
@@ def _get_nested(configs: Optional[Dict[str, Any]], keys: Iterable[str], return curr -def get_nested(keys: Iterable[str], +def get_nested(keys: Tuple[str, ...], default_value: Any, override_configs: Optional[Dict[str, Any]] = None) -> Any: """Gets a nested key. From ebd9586f95ce3a84c412fc6be28a1b338da0efb6 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jul 2024 02:38:33 +0000 Subject: [PATCH 19/22] Fix test --- sky/provision/kubernetes/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index 17d2d4039e1..62cc738e9a6 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -1433,10 +1433,11 @@ def combine_pod_config_fields( with open(cluster_yaml_path, 'r', encoding='utf-8') as f: yaml_content = f.read() yaml_obj = yaml.safe_load(yaml_content) - # We don't use override_config in `skypilot_config.get_nested`, as merging + # We don't use override_configs in `skypilot_config.get_nested`, as merging # the pod config requires special handling. 
kubernetes_config = skypilot_config.get_nested(('kubernetes', 'pod_config'), - default_value={}) + default_value={}, + override_configs={}) override_pod_config = (cluster_config_overrides.get('kubernetes', {}).get( 'pod_config', {})) merge_dicts(override_pod_config, kubernetes_config) From 32eabf512891b939238cd4fb0c6f780ffb3444db Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jul 2024 06:51:27 +0000 Subject: [PATCH 20/22] fix --- sky/utils/schemas.py | 5 ++++- tests/test_config.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 1193c71ac2a..3c03cae9847 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -410,8 +410,11 @@ def keep_keys(current_schema: dict, current_path_dict: dict, if key != 'properties' } new_schema['properties'] = {} - for key, sub_schema in current_schema['properties'].items(): + assert key not in { + 'oneOf', 'anyOf', 'allOf' + }, ('Schema filtering does not work with oneOf, anyOf, allOf. 
' + f'Key: {key}, Schema: {current_schema}') if key in current_path_dict: # Recursively keep keys if further path dict exists new_schema['properties'][key] = {} diff --git a/tests/test_config.py b/tests/test_config.py index e47d31b045b..c01f06d6fca 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -82,6 +82,7 @@ def _create_task_yaml_file(task_file_path: pathlib.Path) -> None: spec: imagePullSecrets: - name: my-secret-2 + provision_timeout: 100 gcp: managed_instance_group: run_duration: {RUN_DURATION_OVERRIDE} From 2fd40ea0b002cbd33328d4de8790a417e084226e Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jul 2024 07:07:26 +0000 Subject: [PATCH 21/22] remove unsupported keys --- sky/utils/schemas.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 3c03cae9847..de6703d8cf7 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -407,25 +407,26 @@ def keep_keys(current_schema: dict, current_path_dict: dict, new_schema = { key: current_schema[key] for key in current_schema - if key != 'properties' + # We do not support the handling of `oneOf`, `anyOf`, `allOf`, + # `required` for now. + if key not in {'properties', 'oneOf', 'anyOf', 'allOf', 'required'} } new_schema['properties'] = {} for key, sub_schema in current_schema['properties'].items(): - assert key not in { - 'oneOf', 'anyOf', 'allOf' - }, ('Schema filtering does not work with oneOf, anyOf, allOf. 
' - f'Key: {key}, Schema: {current_schema}') if key in current_path_dict: # Recursively keep keys if further path dict exists new_schema['properties'][key] = {} + current_path_value = current_path_dict.pop(key) new_schema['properties'][key] = keep_keys( - sub_schema, current_path_dict[key], + sub_schema, current_path_value, new_schema['properties'][key]) return new_schema # Start the recursive filtering - return keep_keys(schema, paths_dict, {}) + new_schema = keep_keys(schema, paths_dict, {}) + assert not paths_dict, f'Unprocessed keys: {paths_dict}' + return new_schema def _experimental_task_schema() -> dict: From 2f75e25d938767ae856ad6e354faf64aa4bcddd2 Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 8 Jul 2024 07:07:32 +0000 Subject: [PATCH 22/22] format --- sky/utils/schemas.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index de6703d8cf7..d936625a618 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -409,7 +409,8 @@ def keep_keys(current_schema: dict, current_path_dict: dict, for key in current_schema # We do not support the handling of `oneOf`, `anyOf`, `allOf`, # `required` for now. - if key not in {'properties', 'oneOf', 'anyOf', 'allOf', 'required'} + if key not in + {'properties', 'oneOf', 'anyOf', 'allOf', 'required'} } new_schema['properties'] = {} for key, sub_schema in current_schema['properties'].items():