Skip to content

Commit

Permalink
refactor: use ClusterName data class
Browse files Browse the repository at this point in the history
  • Loading branch information
JGSweets committed May 13, 2024
1 parent e8fe54c commit 9c71b97
Show file tree
Hide file tree
Showing 21 changed files with 148 additions and 136 deletions.
11 changes: 9 additions & 2 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@
# we need to take this field from the new yaml.
('provider', 'tpu_node'),
('provider', 'security_group', 'GroupName'),
('available_node_types', 'ray.head.default', 'node_config',
'IamInstanceProfile'),
('available_node_types', 'ray.head.default', 'node_config', 'UserData'),
('available_node_types', 'ray.worker.default', 'node_config',
'IamInstanceProfile'),
('available_node_types', 'ray.worker.default', 'node_config', 'UserData'),
]

Expand Down Expand Up @@ -792,8 +796,11 @@ def write_cluster_config(
# move the check out of this function, i.e. the caller should be responsible
# for the validation.
# TODO(tian): Move more cloud agnostic vars to resources.py.
resources_vars = to_provision.make_deploy_variables(cluster_name_on_cloud,
region, zones, dryrun)
resources_vars = to_provision.make_deploy_variables(
resources_utils.ClusterName(
cluster_name,
cluster_name_on_cloud,
), region, zones, dryrun)
config_dict = {}

specific_reservations = set(
Expand Down
16 changes: 9 additions & 7 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,8 +1545,8 @@ def _retry_zones(
to_provision.cloud,
region,
zones,
provisioner.ClusterName(cluster_name,
handle.cluster_name_on_cloud),
resources_utils.ClusterName(
cluster_name, handle.cluster_name_on_cloud),
num_nodes=num_nodes,
cluster_yaml=handle.cluster_yaml,
prev_cluster_ever_up=prev_cluster_ever_up,
Expand All @@ -1556,8 +1556,10 @@ def _retry_zones(
# caller.
resources_vars = (
to_provision.cloud.make_deploy_resources_variables(
to_provision, handle.cluster_name_on_cloud, region,
zones))
to_provision,
resources_utils.ClusterName(
cluster_name, handle.cluster_name_on_cloud),
region, zones))
config_dict['provision_record'] = provision_record
config_dict['resources_vars'] = resources_vars
config_dict['handle'] = handle
Expand Down Expand Up @@ -2869,8 +2871,8 @@ def _provision(
# 4. Starting ray cluster and skylet.
cluster_info = provisioner.post_provision_runtime_setup(
repr(handle.launched_resources.cloud),
provisioner.ClusterName(handle.cluster_name,
handle.cluster_name_on_cloud),
resources_utils.ClusterName(handle.cluster_name,
handle.cluster_name_on_cloud),
handle.cluster_yaml,
provision_record=provision_record,
custom_resource=resources_vars.get('custom_resources'),
Expand Down Expand Up @@ -3838,7 +3840,7 @@ def teardown_no_lock(self,

try:
provisioner.teardown_cluster(repr(cloud),
provisioner.ClusterName(
resources_utils.ClusterName(
cluster_name,
cluster_name_on_cloud),
terminate=terminate,
Expand Down
2 changes: 1 addition & 1 deletion sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3695,7 +3695,7 @@ def _generate_task_with_service(
env: List[Tuple[str, str]],
gpus: Optional[str],
instance_type: Optional[str],
ports: Tuple[str],
ports: Optional[Tuple[str]],
cpus: Optional[str],
memory: Optional[str],
disk_size: Optional[int],
Expand Down
33 changes: 18 additions & 15 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,12 +374,13 @@ def get_vcpus_mem_from_instance_type(
return service_catalog.get_vcpus_mem_from_instance_type(instance_type,
clouds='aws')

def make_deploy_resources_variables(self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Any]:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name: resources_utils.ClusterName,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Any]:
del dryrun # unused
assert zones is not None, (region, zones)

Expand All @@ -406,15 +407,15 @@ def make_deploy_resources_variables(self,
if user_security_group is not None and not isinstance(
user_security_group, str):
for profile in user_security_group:
if fnmatch.fnmatchcase(cluster_name_on_cloud,
if fnmatch.fnmatchcase(cluster_name.name_on_cloud,
list(profile.keys())[0]):
user_security_group = list(profile.values())[0]
break
security_group = user_security_group
if user_security_group is None and resources.ports is not None:
# Already checked in Resources._try_validate_ports
security_group = USER_PORTS_SECURITY_GROUP_NAME.format(
cluster_name_on_cloud)
cluster_name.name_on_cloud)
elif user_security_group is None:
security_group = DEFAULT_SECURITY_GROUP_NAME

Expand Down Expand Up @@ -848,22 +849,24 @@ def query_status(cls, name: str, tag_filters: Dict[str, str],
assert False, 'This code path should not be used.'

@classmethod
def create_image_from_cluster(cls, cluster_name: str,
cluster_name_on_cloud: str,
def create_image_from_cluster(cls,
cluster_name: resources_utils.ClusterName,
region: Optional[str],
zone: Optional[str]) -> str:
assert region is not None, (cluster_name, cluster_name_on_cloud, region)
assert region is not None, (cluster_name.display_name,
cluster_name.name_on_cloud, region)
del zone # unused

image_name = f'skypilot-{cluster_name}-{int(time.time())}'
image_name = f'skypilot-{cluster_name.display_name}-{int(time.time())}'

status = provision_lib.query_instances('AWS', cluster_name_on_cloud,
status = provision_lib.query_instances('AWS',
cluster_name.name_on_cloud,
{'region': region})
instance_ids = list(status.keys())
if not instance_ids:
with ux_utils.print_exception_no_traceback():
raise RuntimeError(
f'Failed to find the source cluster {cluster_name!r} on '
f'Failed to find the source cluster {cluster_name.display_name!r} on '
'AWS.')

if len(instance_ids) != 1:
Expand All @@ -890,7 +893,7 @@ def create_image_from_cluster(cls, cluster_name: str,
stream_logs=True)

rich_utils.force_update_status(
f'Waiting for the source image {cluster_name!r} from {region} to be available on AWS.'
f'Waiting for the source image {cluster_name.display_name!r} from {region} to be available on AWS.'
)
# Wait for the image to be available
wait_image_cmd = (
Expand Down
4 changes: 2 additions & 2 deletions sky/clouds/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
def make_deploy_resources_variables(
self,
resources: 'resources.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Optional[str]]:
Expand Down Expand Up @@ -374,7 +374,7 @@ def _failover_disk_tier() -> Optional[resources_utils.DiskTier]:
'disk_tier': Azure._get_disk_type(_failover_disk_tier()),
'cloud_init_setup_commands': cloud_init_setup_commands,
'azure_subscription_id': self.get_project_id(dryrun),
'resource_group': f'{cluster_name_on_cloud}-{region_name}',
'resource_group': f'{cluster_name.name_on_cloud}-{region_name}',
}

def _get_feasible_launchable_resources(
Expand Down
6 changes: 3 additions & 3 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def is_same_cloud(self, other: 'Cloud'):
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: 'Region',
zones: Optional[List['Zone']],
dryrun: bool = False,
Expand Down Expand Up @@ -726,8 +726,8 @@ def query_status(cls, name: str, tag_filters: Dict[str, str],
# cloud._cloud_unsupported_features().

@classmethod
def create_image_from_cluster(cls, cluster_name: str,
cluster_name_on_cloud: str,
def create_image_from_cluster(cls,
cluster_name: resources_utils.ClusterName,
region: Optional[str],
zone: Optional[str]) -> str:
"""Creates an image from the cluster.
Expand Down
4 changes: 2 additions & 2 deletions sky/clouds/cudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,12 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False,
) -> Dict[str, Optional[str]]:
del zones
del zones, cluster_name # unused
r = resources
acc_dict = self.get_accelerators_from_instance_type(r.instance_type)
if acc_dict is not None:
Expand Down
5 changes: 3 additions & 2 deletions sky/clouds/fluidstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sky import status_lib
from sky.clouds import service_catalog
from sky.provision.fluidstack import fluidstack_utils
from sky.utils import resources_utils
from sky.utils.resources_utils import DiskTier

_CREDENTIAL_FILES = [
Expand Down Expand Up @@ -178,7 +179,7 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: clouds.Region,
zones: Optional[List[clouds.Zone]],
dryrun: bool = False,
Expand All @@ -193,7 +194,7 @@ def make_deploy_resources_variables(
else:
custom_resources = None
cuda_installation_commands = """
sudo wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb -O /usr/local/cuda-keyring_1.1-1_all.deb;
sudo wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.1-1_all.deb -O /usr/local/cuda-keyring_1.1-1_all.deb;
sudo dpkg -i /usr/local/cuda-keyring_1.1-1_all.deb;
sudo apt-get update;
sudo apt-get -y install cuda-toolkit-12-3;
Expand Down
25 changes: 14 additions & 11 deletions sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def get_default_instance_type(
def make_deploy_resources_variables(
self,
resources: 'resources.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Optional[str]]:
Expand Down Expand Up @@ -484,15 +484,15 @@ def make_deploy_resources_variables(

firewall_rule = None
if resources.ports is not None:
firewall_rule = (
USER_PORTS_FIREWALL_RULE_NAME.format(cluster_name_on_cloud))
firewall_rule = (USER_PORTS_FIREWALL_RULE_NAME.format(
cluster_name.name_on_cloud))
resources_vars['firewall_rule'] = firewall_rule

# For TPU nodes. TPU VMs do not need TPU_NAME.
tpu_node_name = resources_vars.get('tpu_node_name')
if gcp_utils.is_tpu(resources) and not gcp_utils.is_tpu_vm(resources):
if tpu_node_name is None:
tpu_node_name = cluster_name_on_cloud
tpu_node_name = cluster_name.name_on_cloud

resources_vars['tpu_node_name'] = tpu_node_name

Expand Down Expand Up @@ -979,8 +979,8 @@ def query_status(cls, name: str, tag_filters: Dict[str, str],
assert False, 'This code path should not be used.'

@classmethod
def create_image_from_cluster(cls, cluster_name: str,
cluster_name_on_cloud: str,
def create_image_from_cluster(cls,
cluster_name: resources_utils.ClusterName,
region: Optional[str],
zone: Optional[str]) -> str:
del region # unused
Expand All @@ -989,7 +989,7 @@ def create_image_from_cluster(cls, cluster_name: str,
# `ray-cluster-name` tag, which is guaranteed by the current `ray`
# backend. Once the `provision.query_instances` is implemented for GCP,
# we should be able to get rid of this assumption.
tag_filters = {'ray-cluster-name': cluster_name_on_cloud}
tag_filters = {'ray-cluster-name': cluster_name.name_on_cloud}
label_filter_str = cls._label_filter_str(tag_filters)
instance_name_cmd = ('gcloud compute instances list '
f'--filter="({label_filter_str})" '
Expand All @@ -1001,7 +1001,8 @@ def create_image_from_cluster(cls, cluster_name: str,
subprocess_utils.handle_returncode(
returncode,
instance_name_cmd,
error_msg=f'Failed to get instance name for {cluster_name!r}',
error_msg=
f'Failed to get instance name for {cluster_name.display_name!r}',
stderr=stderr,
stream_logs=True)
instance_names = json.loads(stdout)
Expand All @@ -1012,7 +1013,7 @@ def create_image_from_cluster(cls, cluster_name: str,
f'instance, but got: {instance_names}')
instance_name = instance_names[0]['name']

image_name = f'skypilot-{cluster_name}-{int(time.time())}'
image_name = f'skypilot-{cluster_name.display_name}-{int(time.time())}'
create_image_cmd = (f'gcloud compute images create {image_name} '
f'--source-disk {instance_name} '
f'--source-disk-zone {zone}')
Expand All @@ -1024,7 +1025,8 @@ def create_image_from_cluster(cls, cluster_name: str,
subprocess_utils.handle_returncode(
returncode,
create_image_cmd,
error_msg=f'Failed to create image for {cluster_name!r}',
error_msg=
f'Failed to create image for {cluster_name.display_name!r}',
stderr=stderr,
stream_logs=True)

Expand All @@ -1038,7 +1040,8 @@ def create_image_from_cluster(cls, cluster_name: str,
subprocess_utils.handle_returncode(
returncode,
image_uri_cmd,
error_msg=f'Failed to get image uri for {cluster_name!r}',
error_msg=
f'Failed to get image uri for {cluster_name.display_name!r}',
stderr=stderr,
stream_logs=True)

Expand Down
4 changes: 2 additions & 2 deletions sky/clouds/ibm.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def is_same_cloud(self, other):
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False,
Expand All @@ -187,7 +187,7 @@ def make_deploy_resources_variables(
Returns:
A dictionary of cloud-specific node type variables.
"""
del cluster_name_on_cloud, dryrun # Unused.
del cluster_name, dryrun # Unused.

def _get_profile_resources(instance_profile):
"""returns a dict representing the
Expand Down
4 changes: 2 additions & 2 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,11 @@ def get_image_size(cls, image_id: str, region: Optional[str]) -> int:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: Optional['clouds.Region'],
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Optional[str]]:
del cluster_name_on_cloud, zones, dryrun # Unused.
del cluster_name, zones, dryrun # Unused.
if region is None:
region = self._regions[0]

Expand Down
4 changes: 2 additions & 2 deletions sky/clouds/lambda_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Optional[str]]:
del cluster_name_on_cloud, dryrun # Unused.
del cluster_name, dryrun # Unused.
assert zones is None, 'Lambda does not support zones.'

r = resources
Expand Down
4 changes: 2 additions & 2 deletions sky/clouds/oci.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,11 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: Optional['clouds.Region'],
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Optional[str]]:
del cluster_name_on_cloud, dryrun # Unused.
del cluster_name, dryrun # Unused.
assert region is not None, resources

acc_dict = self.get_accelerators_from_instance_type(
Expand Down
4 changes: 2 additions & 2 deletions sky/clouds/paperspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,11 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
cluster_name: resources_utils.ClusterName,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
dryrun: bool = False) -> Dict[str, Optional[str]]:
del zones, dryrun
del zones, dryrun, cluster_name

r = resources
acc_dict = self.get_accelerators_from_instance_type(r.instance_type)
Expand Down
Loading

0 comments on commit 9c71b97

Please sign in to comment.