From 9c71b9747246a915d954807534a8dc4b0f0bc9d5 Mon Sep 17 00:00:00 2001 From: Jeremy Goodsitt Date: Sun, 12 May 2024 20:45:40 -0500 Subject: [PATCH] refactor: use ClusterName data class --- sky/backends/backend_utils.py | 11 +++- sky/backends/cloud_vm_ray_backend.py | 16 ++--- sky/cli.py | 2 +- sky/clouds/aws.py | 33 ++++++----- sky/clouds/azure.py | 4 +- sky/clouds/cloud.py | 6 +- sky/clouds/cudo.py | 4 +- sky/clouds/fluidstack.py | 5 +- sky/clouds/gcp.py | 25 ++++---- sky/clouds/ibm.py | 4 +- sky/clouds/kubernetes.py | 4 +- sky/clouds/lambda_cloud.py | 4 +- sky/clouds/oci.py | 4 +- sky/clouds/paperspace.py | 4 +- sky/clouds/runpod.py | 4 +- sky/clouds/scp.py | 4 +- sky/clouds/vsphere.py | 4 +- sky/provision/provisioner.py | 27 +++------ sky/resources.py | 18 +++++- sky/utils/resources_utils.py | 13 ++++ sky/utils/schemas.py | 88 +++++++++++----------------- 21 files changed, 148 insertions(+), 136 deletions(-) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index cf43cfdf2ed..e76d29b2ed1 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -154,7 +154,11 @@ # we need to take this field from the new yaml. ('provider', 'tpu_node'), ('provider', 'security_group', 'GroupName'), + ('available_node_types', 'ray.head.default', 'node_config', + 'IamInstanceProfile'), ('available_node_types', 'ray.head.default', 'node_config', 'UserData'), + ('available_node_types', 'ray.worker.default', 'node_config', + 'IamInstanceProfile'), ('available_node_types', 'ray.worker.default', 'node_config', 'UserData'), ] @@ -792,8 +796,11 @@ def write_cluster_config( # move the check out of this function, i.e. the caller should be responsible # for the validation. # TODO(tian): Move more cloud agnostic vars to resources.py. - resources_vars = to_provision.make_deploy_variables(cluster_name_on_cloud, - region, zones, dryrun) + resources_vars = to_provision.make_deploy_variables( + resources_utils.ClusterName( + cluster_name, + cluster_name_on_cloud, + ), region, zones, dryrun) config_dict = {} specific_reservations = set( diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index a0f746a7098..9996d489d39 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -1545,8 +1545,8 @@ def _retry_zones( to_provision.cloud, region, zones, - provisioner.ClusterName(cluster_name, - handle.cluster_name_on_cloud), + resources_utils.ClusterName( + cluster_name, handle.cluster_name_on_cloud), num_nodes=num_nodes, cluster_yaml=handle.cluster_yaml, prev_cluster_ever_up=prev_cluster_ever_up, @@ -1556,8 +1556,10 @@ def _retry_zones( # caller. resources_vars = ( to_provision.cloud.make_deploy_resources_variables( - to_provision, handle.cluster_name_on_cloud, region, - zones)) + to_provision, + resources_utils.ClusterName( + cluster_name, handle.cluster_name_on_cloud), + region, zones)) config_dict['provision_record'] = provision_record config_dict['resources_vars'] = resources_vars config_dict['handle'] = handle @@ -2869,8 +2871,8 @@ def _provision( # 4. Starting ray cluster and skylet. cluster_info = provisioner.post_provision_runtime_setup( repr(handle.launched_resources.cloud), - provisioner.ClusterName(handle.cluster_name, - handle.cluster_name_on_cloud), + resources_utils.ClusterName(handle.cluster_name, + handle.cluster_name_on_cloud), handle.cluster_yaml, provision_record=provision_record, custom_resource=resources_vars.get('custom_resources'), @@ -3838,7 +3840,7 @@ def teardown_no_lock(self, try: provisioner.teardown_cluster(repr(cloud), - provisioner.ClusterName( + resources_utils.ClusterName( cluster_name, cluster_name_on_cloud), terminate=terminate, diff --git a/sky/cli.py b/sky/cli.py index 5b180d25dc8..0ff13574bd3 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -3695,7 +3695,7 @@ def _generate_task_with_service( env: List[Tuple[str, str]], gpus: Optional[str], instance_type: Optional[str], - ports: Tuple[str], + ports: Optional[Tuple[str]], cpus: Optional[str], memory: Optional[str], disk_size: Optional[int], diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 029ab928f62..72cb6937884 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -374,12 +374,13 @@ def get_vcpus_mem_from_instance_type( return service_catalog.get_vcpus_mem_from_instance_type(instance_type, clouds='aws') - def make_deploy_resources_variables(self, - resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, - region: 'clouds.Region', - zones: Optional[List['clouds.Zone']], - dryrun: bool = False) -> Dict[str, Any]: + def make_deploy_resources_variables( + self, + resources: 'resources_lib.Resources', + cluster_name: resources_utils.ClusterName, + region: 'clouds.Region', + zones: Optional[List['clouds.Zone']], + dryrun: bool = False) -> Dict[str, Any]: del dryrun # unused assert zones is not None, (region, zones) @@ -406,7 +407,7 @@ def make_deploy_resources_variables(self, if user_security_group is not None and not isinstance( user_security_group, str): for profile in user_security_group: - if fnmatch.fnmatchcase(cluster_name_on_cloud, + if fnmatch.fnmatchcase(cluster_name.name_on_cloud, list(profile.keys())[0]): user_security_group = list(profile.values())[0] break @@ -414,7 +415,7 @@ def make_deploy_resources_variables(self, if user_security_group is None and resources.ports is not None: # Already checked in Resources._try_validate_ports security_group = USER_PORTS_SECURITY_GROUP_NAME.format( - cluster_name_on_cloud) + cluster_name.name_on_cloud) elif user_security_group is None: security_group = DEFAULT_SECURITY_GROUP_NAME @@ -848,22 +849,24 @@ def query_status(cls, name: str, tag_filters: Dict[str, str], assert False, 'This code path should not be used.' @classmethod - def create_image_from_cluster(cls, cluster_name: str, - cluster_name_on_cloud: str, + def create_image_from_cluster(cls, + cluster_name: resources_utils.ClusterName, region: Optional[str], zone: Optional[str]) -> str: - assert region is not None, (cluster_name, cluster_name_on_cloud, region) + assert region is not None, (cluster_name.display_name, + cluster_name.name_on_cloud, region) del zone # unused - image_name = f'skypilot-{cluster_name}-{int(time.time())}' + image_name = f'skypilot-{cluster_name.display_name}-{int(time.time())}' - status = provision_lib.query_instances('AWS', cluster_name_on_cloud, + status = provision_lib.query_instances('AWS', + cluster_name.name_on_cloud, {'region': region}) instance_ids = list(status.keys()) if not instance_ids: with ux_utils.print_exception_no_traceback(): raise RuntimeError( - f'Failed to find the source cluster {cluster_name!r} on ' + f'Failed to find the source cluster {cluster_name.display_name!r} on ' 'AWS.') if len(instance_ids) != 1: @@ -890,7 +893,7 @@ def create_image_from_cluster(cls, cluster_name: str, stream_logs=True) rich_utils.force_update_status( - f'Waiting for the source image {cluster_name!r} from {region} to be available on AWS.' + f'Waiting for the source image {cluster_name.display_name!r} from {region} to be available on AWS.' ) # Wait for the image to be available wait_image_cmd = ( diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index edf7eb1a060..b132e716b4a 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -276,7 +276,7 @@ def get_zone_shell_cmd(cls) -> Optional[str]: def make_deploy_resources_variables( self, resources: 'resources.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], dryrun: bool = False) -> Dict[str, Optional[str]]: @@ -374,7 +374,7 @@ def _failover_disk_tier() -> Optional[resources_utils.DiskTier]: 'disk_tier': Azure._get_disk_type(_failover_disk_tier()), 'cloud_init_setup_commands': cloud_init_setup_commands, 'azure_subscription_id': self.get_project_id(dryrun), - 'resource_group': f'{cluster_name_on_cloud}-{region_name}', + 'resource_group': f'{cluster_name.name_on_cloud}-{region_name}', } def _get_feasible_launchable_resources( diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index 08045e28ab9..046f0f24948 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -253,7 +253,7 @@ def is_same_cloud(self, other: 'Cloud'): def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: 'Region', zones: Optional[List['Zone']], dryrun: bool = False, @@ -726,8 +726,8 @@ def query_status(cls, name: str, tag_filters: Dict[str, str], # cloud._cloud_unsupported_features(). @classmethod - def create_image_from_cluster(cls, cluster_name: str, - cluster_name_on_cloud: str, + def create_image_from_cluster(cls, + cluster_name: resources_utils.ClusterName, region: Optional[str], zone: Optional[str]) -> str: """Creates an image from the cluster. diff --git a/sky/clouds/cudo.py b/sky/clouds/cudo.py index ad7a22e6e03..3cbc72b9162 100644 --- a/sky/clouds/cudo.py +++ b/sky/clouds/cudo.py @@ -185,12 +185,12 @@ def get_zone_shell_cmd(cls) -> Optional[str]: def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], dryrun: bool = False, ) -> Dict[str, Optional[str]]: - del zones + del zones, cluster_name # unused r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) if acc_dict is not None: diff --git a/sky/clouds/fluidstack.py b/sky/clouds/fluidstack.py index 4d6b7f1a2ec..b6fdd31be14 100644 --- a/sky/clouds/fluidstack.py +++ b/sky/clouds/fluidstack.py @@ -10,6 +10,7 @@ from sky import status_lib from sky.clouds import service_catalog from sky.provision.fluidstack import fluidstack_utils +from sky.utils import resources_utils from sky.utils.resources_utils import DiskTier _CREDENTIAL_FILES = [ @@ -178,7 +179,7 @@ def get_zone_shell_cmd(cls) -> Optional[str]: def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: clouds.Region, zones: Optional[List[clouds.Zone]], dryrun: bool = False, @@ -193,7 +194,7 @@ def make_deploy_resources_variables( else: custom_resources = None cuda_installation_commands = """ - sudo wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb -O /usr/local/cuda-keyring_1.1-1_all.deb; + sudo wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb -O /usr/local/cuda-keyring_1.1-1_all.deb; sudo dpkg -i /usr/local/cuda-keyring_1.1-1_all.deb; sudo apt-get update; sudo apt-get -y install cuda-toolkit-12-3; diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 7babf34ac52..4bbfa990347 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -393,7 +393,7 @@ def get_default_instance_type( def make_deploy_resources_variables( self, resources: 'resources.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], dryrun: bool = False) -> Dict[str, Optional[str]]: @@ -484,15 +484,15 @@ def make_deploy_resources_variables( firewall_rule = None if resources.ports is not None: - firewall_rule = ( - USER_PORTS_FIREWALL_RULE_NAME.format(cluster_name_on_cloud)) + firewall_rule = (USER_PORTS_FIREWALL_RULE_NAME.format( + cluster_name.name_on_cloud)) resources_vars['firewall_rule'] = firewall_rule # For TPU nodes. TPU VMs do not need TPU_NAME. tpu_node_name = resources_vars.get('tpu_node_name') if gcp_utils.is_tpu(resources) and not gcp_utils.is_tpu_vm(resources): if tpu_node_name is None: - tpu_node_name = cluster_name_on_cloud + tpu_node_name = cluster_name.name_on_cloud resources_vars['tpu_node_name'] = tpu_node_name @@ -979,8 +979,8 @@ def query_status(cls, name: str, tag_filters: Dict[str, str], assert False, 'This code path should not be used.' @classmethod - def create_image_from_cluster(cls, cluster_name: str, - cluster_name_on_cloud: str, + def create_image_from_cluster(cls, + cluster_name: resources_utils.ClusterName, region: Optional[str], zone: Optional[str]) -> str: del region # unused @@ -989,7 +989,7 @@ def create_image_from_cluster(cls, cluster_name: str, # `ray-cluster-name` tag, which is guaranteed by the current `ray` # backend. Once the `provision.query_instances` is implemented for GCP, # we should be able to get rid of this assumption. - tag_filters = {'ray-cluster-name': cluster_name_on_cloud} + tag_filters = {'ray-cluster-name': cluster_name.name_on_cloud} label_filter_str = cls._label_filter_str(tag_filters) instance_name_cmd = ('gcloud compute instances list ' f'--filter="({label_filter_str})" ' @@ -1001,7 +1001,8 @@ def create_image_from_cluster(cls, cluster_name: str, subprocess_utils.handle_returncode( returncode, instance_name_cmd, - error_msg=f'Failed to get instance name for {cluster_name!r}', + error_msg= + f'Failed to get instance name for {cluster_name.display_name!r}', stderr=stderr, stream_logs=True) instance_names = json.loads(stdout) @@ -1012,7 +1013,7 @@ def create_image_from_cluster(cls, cluster_name: str, f'instance, but got: {instance_names}') instance_name = instance_names[0]['name'] - image_name = f'skypilot-{cluster_name}-{int(time.time())}' + image_name = f'skypilot-{cluster_name.display_name}-{int(time.time())}' create_image_cmd = (f'gcloud compute images create {image_name} ' f'--source-disk {instance_name} ' f'--source-disk-zone {zone}') @@ -1024,7 +1025,8 @@ def create_image_from_cluster(cls, cluster_name: str, subprocess_utils.handle_returncode( returncode, create_image_cmd, - error_msg=f'Failed to create image for {cluster_name!r}', + error_msg= + f'Failed to create image for {cluster_name.display_name!r}', stderr=stderr, stream_logs=True) @@ -1038,7 +1040,8 @@ def create_image_from_cluster(cls, cluster_name: str, subprocess_utils.handle_returncode( returncode, image_uri_cmd, - error_msg=f'Failed to get image uri for {cluster_name!r}', + error_msg= + f'Failed to get image uri for {cluster_name.display_name!r}', stderr=stderr, stream_logs=True) diff --git a/sky/clouds/ibm.py b/sky/clouds/ibm.py index 880ad212e25..bec156c56aa 100644 --- a/sky/clouds/ibm.py +++ b/sky/clouds/ibm.py @@ -171,7 +171,7 @@ def is_same_cloud(self, other): def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], dryrun: bool = False, @@ -187,7 +187,7 @@ def make_deploy_resources_variables( Returns: A dictionary of cloud-specific node type variables. """ - del cluster_name_on_cloud, dryrun # Unused. + del cluster_name, dryrun # Unused. def _get_profile_resources(instance_profile): """returns a dict representing the diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index fcf8c2f87ac..e116ad58056 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -222,11 +222,11 @@ def get_image_size(cls, image_id: str, region: Optional[str]) -> int: def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: Optional['clouds.Region'], zones: Optional[List['clouds.Zone']], dryrun: bool = False) -> Dict[str, Optional[str]]: - del cluster_name_on_cloud, zones, dryrun # Unused. + del cluster_name, zones, dryrun # Unused. if region is None: region = self._regions[0] diff --git a/sky/clouds/lambda_cloud.py b/sky/clouds/lambda_cloud.py index 37750355a88..198cd20d15e 100644 --- a/sky/clouds/lambda_cloud.py +++ b/sky/clouds/lambda_cloud.py @@ -160,11 +160,11 @@ def get_zone_shell_cmd(cls) -> Optional[str]: def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], dryrun: bool = False) -> Dict[str, Optional[str]]: - del cluster_name_on_cloud, dryrun # Unused. + del cluster_name, dryrun # Unused. assert zones is None, 'Lambda does not support zones.' r = resources diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py index 03351fc4cf6..02b8e55a1a3 100644 --- a/sky/clouds/oci.py +++ b/sky/clouds/oci.py @@ -191,11 +191,11 @@ def get_zone_shell_cmd(cls) -> Optional[str]: def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: Optional['clouds.Region'], zones: Optional[List['clouds.Zone']], dryrun: bool = False) -> Dict[str, Optional[str]]: - del cluster_name_on_cloud, dryrun # Unused. + del cluster_name, dryrun # Unused. assert region is not None, resources acc_dict = self.get_accelerators_from_instance_type( diff --git a/sky/clouds/paperspace.py b/sky/clouds/paperspace.py index f76772ab8b7..f23f3644c66 100644 --- a/sky/clouds/paperspace.py +++ b/sky/clouds/paperspace.py @@ -177,11 +177,11 @@ def get_zone_shell_cmd(cls) -> Optional[str]: def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], dryrun: bool = False) -> Dict[str, Optional[str]]: - del zones, dryrun + del zones, dryrun, cluster_name r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 0f9e5c68169..2577adcd4c8 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -170,11 +170,11 @@ def get_zone_shell_cmd(cls) -> Optional[str]: def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], dryrun: bool = False) -> Dict[str, Optional[str]]: - del zones, dryrun # unused + del zones, dryrun, cluster_name # unused r = resources acc_dict = self.get_accelerators_from_instance_type(r.instance_type) diff --git a/sky/clouds/scp.py b/sky/clouds/scp.py index 1d6cb6cf20f..e41ad4706db 100644 --- a/sky/clouds/scp.py +++ b/sky/clouds/scp.py @@ -183,11 +183,11 @@ def get_zone_shell_cmd(cls) -> Optional[str]: def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], dryrun: bool = False) -> Dict[str, Optional[str]]: - del cluster_name_on_cloud, dryrun # Unused. + del cluster_name, dryrun # Unused. assert zones is None, 'SCP does not support zones.' r = resources diff --git a/sky/clouds/vsphere.py b/sky/clouds/vsphere.py index 02a794d7d58..9f5d13fdce3 100644 --- a/sky/clouds/vsphere.py +++ b/sky/clouds/vsphere.py @@ -175,13 +175,13 @@ def get_zone_shell_cmd(cls) -> Optional[str]: def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', - cluster_name_on_cloud: str, + cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], dryrun: bool = False, ) -> Dict[str, Optional[str]]: # TODO get image id here. - del cluster_name_on_cloud, dryrun # unused + del cluster_name, dryrun # unused assert zones is not None, (region, zones) zone_names = [zone.name for zone in zones] r = resources diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index df9a9fcc58a..6e3886828e5 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -25,6 +25,7 @@ from sky.provision import metadata_utils from sky.skylet import constants from sky.utils import common_utils +from sky.utils import resources_utils from sky.utils import rich_utils from sky.utils import ux_utils @@ -38,23 +39,11 @@ _TITLE = '\n\n' + '=' * 20 + ' {} ' + '=' * 20 + '\n' -@dataclasses.dataclass -class ClusterName: - display_name: str - name_on_cloud: str - - def __repr__(self) -> str: - return repr(self.display_name) - - def __str__(self) -> str: - return self.display_name - - def _bulk_provision( cloud: clouds.Cloud, region: clouds.Region, zones: Optional[List[clouds.Zone]], - cluster_name: ClusterName, + cluster_name: resources_utils.ClusterName, bootstrap_config: provision_common.ProvisionConfig, ) -> provision_common.ProvisionRecord: provider_name = repr(cloud) @@ -135,7 +124,7 @@ def bulk_provision( cloud: clouds.Cloud, region: clouds.Region, zones: Optional[List[clouds.Zone]], - cluster_name: ClusterName, + cluster_name: resources_utils.ClusterName, num_nodes: int, cluster_yaml: str, prev_cluster_ever_up: bool, @@ -225,7 +214,7 @@ def bulk_provision( raise -def teardown_cluster(cloud_name: str, cluster_name: ClusterName, +def teardown_cluster(cloud_name: str, cluster_name: resources_utils.ClusterName, terminate: bool, provider_config: Dict) -> None: """Deleting or stopping a cluster. @@ -411,8 +400,8 @@ def wait_for_ssh(cluster_info: provision_common.ClusterInfo, def _post_provision_setup( - cloud_name: str, cluster_name: ClusterName, cluster_yaml: str, - provision_record: provision_common.ProvisionRecord, + cloud_name: str, cluster_name: resources_utils.ClusterName, + cluster_yaml: str, provision_record: provision_common.ProvisionRecord, custom_resource: Optional[str]) -> provision_common.ClusterInfo: config_from_yaml = common_utils.read_yaml(cluster_yaml) provider_config = config_from_yaml.get('provider') @@ -563,8 +552,8 @@ def _post_provision_setup( def post_provision_runtime_setup( - cloud_name: str, cluster_name: ClusterName, cluster_yaml: str, - provision_record: provision_common.ProvisionRecord, + cloud_name: str, cluster_name: resources_utils.ClusterName, + cluster_yaml: str, provision_record: provision_common.ProvisionRecord, custom_resource: Optional[str], log_dir: str) -> provision_common.ClusterInfo: """Run internal setup commands after provisioning and before user setup. diff --git a/sky/resources.py b/sky/resources.py index bc599a3cf29..0fb70109302 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -920,6 +920,20 @@ def _try_validate_ports(self) -> None: """ if self.ports is None: return + if self.cloud is None or isinstance(self.cloud, clouds.AWS): + security_group_name = skypilot_config.get_nested( + ('aws', 'security_group_name'), None) + if security_group_name is not None: + with ux_utils.print_exception_no_traceback(): + logger.warning( + f'Ports {self.ports} and security group name are ' + f'specified: {security_group_name}. It is not ' + 'guaranteed that the ports will be opened if the ' + 'specified security group is not correctly set up. ' + 'Please try to specify `ports` only and leave out ' + '`aws.security_group_name` in `~/.sky/config.yaml` to ' + 'allow SkyPilot to automatically create and configure ' + 'the security group.') if self.cloud is not None: self.cloud.check_features_are_supported( self, {clouds.CloudImplementationFeatures.OPEN_PORTS}) @@ -994,7 +1008,7 @@ def get_accelerators_str(self) -> str: def get_spot_str(self) -> str: return '[Spot]' if self.use_spot else '' - def make_deploy_variables(self, cluster_name_on_cloud: str, + def make_deploy_variables(self, cluster_name: resources_utils.ClusterName, region: clouds.Region, zones: Optional[List[clouds.Zone]], dryrun: bool) -> Dict[str, Optional[str]]: @@ -1006,7 +1020,7 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, variables are generated by this method. """ cloud_specific_variables = self.cloud.make_deploy_resources_variables( - self, cluster_name_on_cloud, region, zones, dryrun) + self, cluster_name, region, zones, dryrun) docker_image = self.extract_docker_image() return dict( cloud_specific_variables, diff --git a/sky/utils/resources_utils.py b/sky/utils/resources_utils.py index e2357b9eeb7..87a62dab95b 100644 --- a/sky/utils/resources_utils.py +++ b/sky/utils/resources_utils.py @@ -1,4 +1,5 @@ """Utility functions for resources.""" +import dataclasses import enum import itertools import re @@ -43,6 +44,18 @@ def __le__(self, other: 'DiskTier') -> bool: return types.index(self) <= types.index(other) +@dataclasses.dataclass +class ClusterName: + display_name: str + name_on_cloud: str + + def __repr__(self) -> str: + return repr(self.display_name) + + def __str__(self) -> str: + return self.display_name + + def check_port_str(port: str) -> None: if not port.isdigit(): with ux_utils.print_exception_no_traceback(): diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index e4cac19a9d5..71a712af218 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -30,6 +30,36 @@ def _check_not_both_fields_present(field1: str, field2: str): } +def _get_cloud_name_property_mapping(field: str): + return { + field: { + 'oneOf': [ + { + 'type': 'string' + }, + { + # A list of single-element dict to pretain the + # order. + # Example: + # property_name: + # - my-cluster1-*: my-property-1 + # - my-cluster2-*: my-property-2 + # - "*"": my-property-3 + 'type': 'array', + 'items': { + 'type': 'object', + 'additionalProperties': { + 'type': 'string' + }, + 'maxProperties': 1, + 'minProperties': 1, + }, + } + ] + } + } + + def _get_single_resources_schema(): """Schema for a single resource in a resources list.""" # To avoid circular imports, only import when needed. @@ -553,33 +583,6 @@ def get_default_remote_identity(cloud: str) -> str: } } -_REMOTE_IDENTITY_SCHEMA_AWS = { - 'remote_identity': { - 'oneOf': [ - { - 'type': 'string' - }, - { - # A list of single-element dict to pretain the order. - # Example: - # remote_identity: - # - my-cluster1-*: my-iam-role-1 - # - my-cluster2-*: my-iam-role-2 - # - "*"": my-iam-role-3 - 'type': 'array', - 'items': { - 'type': 'object', - 'additionalProperties': { - 'type': 'string' - }, - 'maxProperties': 1, - 'minProperties': 1, - }, - } - ] - } -} - _REMOTE_IDENTITY_SCHEMA_KUBERNETES = { 'remote_identity': { 'type': 'string' @@ -597,7 +600,7 @@ def get_config_schema(): # Validation may fail if $schema is included. if k != '$schema' } - resources_schema['properties'].pop('port', None) + resources_schema['properties'].pop('ports') controller_resources_schema = { 'type': 'object', 'required': [], @@ -619,31 +622,7 @@ def get_config_schema(): 'required': [], 'additionalProperties': False, 'properties': { - 'security_group_name': { - 'oneOf': [ - { - 'type': 'string' - }, - { - # A list of single-element dict to pretain the - # order. - # Example: - # security_group_name: - # - my-cluster1-*: my-security-group-1 - # - my-cluster2-*: my-security-group-2 - # - "*"": my-security-group-3 - 'type': 'array', - 'items': { - 'type': 'object', - 'additionalProperties': { - 'type': 'string' - }, - 'maxProperties': 1, - 'minProperties': 1, - }, - } - ] - }, + **_get_cloud_name_property_mapping('security_group_name'), **_LABELS_SCHEMA, **_NETWORK_CONFIG_SCHEMA, }, @@ -748,7 +727,8 @@ def get_config_schema(): for cloud, config in cloud_configs.items(): if cloud == 'aws': - config['properties'].update(_REMOTE_IDENTITY_SCHEMA_AWS) + config['properties'].update( + _get_cloud_name_property_mapping('remote_identity')) elif cloud == 'kubernetes': config['properties'].update(_REMOTE_IDENTITY_SCHEMA_KUBERNETES) else: