Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TPU VM] Attaching & Mounting Persistent Disk #3497

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,9 @@ def write_cluster_config(
config_dict['ray'] = tmp_yaml_path
return config_dict
_add_auth_to_cluster_config(cloud, tmp_yaml_path)
# _DEFAULT_DISK_SIZE_GB = 256
if to_provision.disk_size != 256:
_add_disk_size_to_cluster_config(to_provision.disk_size, tmp_yaml_path)

# Add kubernetes config fields from ~/.sky/config
if isinstance(cloud, clouds.Kubernetes):
Expand Down Expand Up @@ -997,6 +1000,12 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, cluster_config_file: str):
assert False, cloud
common_utils.dump_yaml(cluster_config_file, config)

def _add_disk_size_to_cluster_config(disk_size: int, cluster_config_file: str):
    """Record the requested disk size in the generated cluster YAML.

    Loads the cluster config at *cluster_config_file*, stores the disk
    size (in GB) under the top-level ``initDiskSize`` key as a string,
    and writes the file back in place.
    """
    cluster_config = common_utils.read_yaml(cluster_config_file)
    cluster_config['initDiskSize'] = str(disk_size)
    common_utils.dump_yaml(cluster_config_file, cluster_config)


def get_run_timestamp() -> str:
return 'sky-' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
Expand Down
1 change: 1 addition & 0 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4242,6 +4242,7 @@ def _check_existing_cluster(
to_provision = handle_before_refresh.launched_resources
self.check_resources_fit_cluster(handle_before_refresh, task)


logger.info(
f'{colorama.Fore.CYAN}Creating a new cluster: {cluster_name!r} '
f'[{task.num_nodes}x {to_provision}].'
Expand Down
1 change: 1 addition & 0 deletions sky/provision/gcp/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def get_order_key(node):
for instance_id in resumed_instance_ids:
resource.start_instance(instance_id, project_id,
availability_zone)
resource.resize_disk(project_id, availability_zone, config.node_config, instance_id)
resource.set_labels(project_id, availability_zone, instance_id,
labels)
to_start_count -= len(resumed_instance_ids)
Expand Down
88 changes: 80 additions & 8 deletions sky/provision/gcp/instance_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,15 +1368,87 @@ def start_instance(cls, node_id: str, project_id: str, zone: str) -> None:
cls.wait_for_operation(operation, project_id, zone)

@classmethod
def resize_disk(cls, project_id: str, availability_zone: str,
node_config: dict, instance_name: str) -> None:
"""Resize the disk when a machine image with a different size is used.
def resize_disk(cls, project_id: str, availability_zone: str,
                node_config: dict, instance_name: str) -> None:
    """Provide extra disk capacity for a TPU VM via a persistent disk.

    TPU VM boot disks are not resizable, so when the requested disk size
    exceeds the default boot disk size the difference is supplied by
    creating a persistent disk, attaching it to the TPU VM, and mounting
    it at ``/mnt/disks/persist``.

    Args:
        project_id: GCP project the TPU VM lives in.
        availability_zone: Zone of the TPU VM (the disk is created there).
        node_config: Provider node config; the requested total disk size
            (GB) is read from ``node_config['metadata']['diskSize']``.
        instance_name: TPU VM name (a full resource path is accepted; only
            the trailing component is used).
    """
    import time

    from googleapiclient import discovery
    from googleapiclient.errors import HttpError

    compute = discovery.build("compute", "v1")

    # Default boot disk size for TPU VMs; anything beyond this must come
    # from an additional persistent disk.
    default_disk_size = 100
    requested_size = int(node_config['metadata'].get(
        'diskSize', default_disk_size))

    # How much extra capacity (GB) the persistent disk must provide.
    additional_size = requested_size - default_disk_size
    if additional_size <= 0:
        return  # The boot disk already covers the request.

    logger.info(
        f"Requesting additional persistent disk of size: {additional_size}GB")

    tpu_name = instance_name.split("/")[-1]
    disk_name = f"{tpu_name}-extra-disk"
    disk_type = f"zones/{availability_zone}/diskTypes/pd-standard"

    disk_body = {
        "name": disk_name,
        "sizeGb": str(additional_size),
        "type": disk_type,
    }

    try:
        compute.disks().insert(project=project_id, zone=availability_zone,
                               body=disk_body).execute()
        # disks().insert is asynchronous; give GCP a moment before the
        # disk is referenced by the attach command below.
        time.sleep(3)
    except HttpError as e:
        if e.resp.status == 409:
            # The disk already exists — e.g. this TPU VM was stopped and
            # is being resumed.  Proceed to (re-)attach and mount it
            # instead of bailing out.
            logger.info(f"Persistent disk {disk_name!r} already exists; "
                        "reusing it.")
        else:
            logger.warning(f"Disk creation failed: {e.reason}")
            return

    # Attach the persistent disk to the TPU VM.
    attach_command = (
        f"gcloud alpha compute tpus tpu-vm attach-disk {tpu_name} "
        f"--zone {availability_zone} --disk {disk_name} --mode read-write"
    )
    if cls.execute_command_with_log(attach_command) != 0:
        logger.warning("Failed to attach disk to TPU VMs.")

    # Format (only if the disk has no filesystem yet — blkid exits
    # non-zero for an unformatted disk; reformatting unconditionally
    # would wipe a resumed disk's data) and mount the disk.
    # NOTE(review): assumes the attached disk shows up as /dev/sdb —
    # confirm on multi-disk TPU VM setups.
    mount_command = (
        f"gcloud compute tpus tpu-vm ssh {tpu_name} --zone={availability_zone} "
        f"--command='sudo blkid /dev/sdb || sudo mkfs.ext4 -m 0 -E "
        f"lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb ; "
        f"sudo mkdir -p /mnt/disks/persist ; sudo mount -o "
        f"discard,defaults /dev/sdb /mnt/disks/persist'"
    )
    if cls.execute_command_with_log(mount_command) != 0:
        logger.warning("Failed to format and mount persistent disk.")

@classmethod
def execute_command_with_log(cls, command: str) -> int:
    """Run *command* in a shell and return its exit code.

    Output is captured rather than streamed; stdout/stderr are emitted
    as a warning only when the command exits non-zero.
    """
    import os

    from sky.skylet import log_lib

    returncode, out, err = log_lib.run_with_log(
        command,
        os.devnull,
        shell=True,
        stream_logs=False,
        require_outputs=True,
    )
    if returncode != 0:
        logger.warning("Command failed.\n**** STDOUT ****\n"
                       f"{out}\n**** STDERR ****\n{err}")
    return returncode

TODO: Implement the feature to attach persistent disks for TPU VMs.
The boot disk of TPU VMs is not resizable, and users need to add a
persistent disk to expand disk capacity. Related issue: #2387
"""
return

@classmethod
def get_instance_info(cls, project_id: str, availability_zone: str,
Expand Down
3 changes: 3 additions & 0 deletions sky/provision/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ def bulk_provision(
tags={},
resume_stopped_nodes=True)

if 'initDiskSize' in original_config:
bootstrap_config.node_config['metadata']['diskSize'] = original_config['initDiskSize']

with provision_logging.setup_provision_logging(log_dir):
try:
logger.debug(f'SkyPilot version: {sky.__version__}; '
Expand Down
Loading