Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][Serve] Support TailScale VPN #3458

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
6 changes: 5 additions & 1 deletion sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,9 @@ def write_cluster_config(
f'open(os.path.expanduser("{constants.SKY_REMOTE_RAY_PORT_FILE}"), "w", encoding="utf-8"))\''
)

# TODO(tian): Hack. Reformat here.
default_use_internal_ips = 'vpn_config' in resources_vars

# Use a tmp file path to avoid incomplete YAML file being re-used in the
# future.
tmp_yaml_path = yaml_path + '.tmp'
Expand All @@ -881,7 +884,8 @@ def write_cluster_config(

# Networking configs
'use_internal_ips': skypilot_config.get_nested(
(str(cloud).lower(), 'use_internal_ips'), False),
(str(cloud).lower(), 'use_internal_ips'),
default_use_internal_ips),
'ssh_proxy_command': ssh_proxy_command,
'vpc_name': skypilot_config.get_nested(
(str(cloud).lower(), 'vpc_name'), None),
Expand Down
75 changes: 71 additions & 4 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import math
import os
import pathlib
import pprint
import re
import signal
import subprocess
Expand All @@ -34,6 +35,7 @@
from sky import resources as resources_lib
from sky import serve as serve_lib
from sky import sky_logging
from sky import skypilot_config
from sky import status_lib
from sky import task as task_lib
from sky.backends import backend_utils
Expand Down Expand Up @@ -2110,11 +2112,12 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
- (optional) Launched num nodes
- (optional) Launched resources
- (optional) Docker user name
- (optional) The cluster VPN configuration (if used)
- (optional) If TPU(s) are managed, a path to a deletion script.
"""
# Bump if any fields get added/removed/changed, and add backward
# compaitibility logic in __setstate__.
_VERSION = 7
# compatibility logic in __setstate__.
_VERSION = 8

def __init__(
self,
Expand Down Expand Up @@ -2147,13 +2150,23 @@ def __init__(
self.launched_resources = launched_resources
self.docker_user: Optional[str] = None
self.ssh_user: Optional[str] = None
# TODO(tian): Should we store the APIs in the config YAML?
self.vpn_config: Optional[Dict[str, Any]] = self._get_vpn_config()
# Deprecated. SkyPilot new provisioner API handles the TPU node
# creation/deletion.
# Backward compatibility for TPU nodes created before #2943.
# TODO (zhwu): Remove this after 0.6.0.
self.tpu_create_script = tpu_create_script
self.tpu_delete_script = tpu_delete_script

def _get_vpn_config(self) -> Optional[Dict[str, Any]]:
    """Returns the VPN config used by the cluster."""
    # Read the recorded VPN config from the cluster YAML rather than
    # from `skypilot_config`: the global config may have been edited
    # after the cluster was brought UP, but the YAML reflects what the
    # cluster was actually launched with.
    cluster_config = common_utils.read_yaml(self.cluster_yaml)
    provider_section = cluster_config.get('provider', {})
    return provider_section.get('vpn_config', None)

def __repr__(self):
return (f'ResourceHandle('
f'\n\tcluster_name={self.cluster_name},'
Expand All @@ -2169,6 +2182,7 @@ def __repr__(self):
f'{self.launched_resources}, '
f'\n\tdocker_user={self.docker_user},'
f'\n\tssh_user={self.ssh_user},'
f'\n\tvpn_config={self.vpn_config},'
# TODO (zhwu): Remove this after 0.6.0.
f'\n\ttpu_create_script={self.tpu_create_script}, '
f'\n\ttpu_delete_script={self.tpu_delete_script})')
Expand Down Expand Up @@ -2440,6 +2454,9 @@ def __setstate__(self, state):
if version < 7:
self.ssh_user = None

if version < 8:
self.vpn_config = None

self.__dict__.update(state)

# Because the update_cluster_ips and update_ssh_ports
Expand Down Expand Up @@ -2533,23 +2550,68 @@ def check_resources_fit_cluster(
# was handled by ResourceHandle._update_cluster_region.
assert launched_resources.region is not None, handle

def _check_vpn_unchanged(
        resource: resources_lib.Resources) -> Optional[str]:
    """Check if the VPN configuration is unchanged.

    This function should only be called after checking cloud is the
    same. Current VPN configuration is per-cloud basis, so we could
    only check for the same cloud.

    Returns:
        None if the VPN configuration is unchanged, otherwise a string
        indicating the mismatch.
    """
    assert resource.cloud is None or resource.cloud.is_same_cloud(
        launched_resources.cloud)
    # resource.cloud may be None, so always resolve the cloud key from
    # launched_resources (clouds already verified identical above).
    now_vpn_config = skypilot_config.get_nested(
        (str(launched_resources.cloud).lower(), 'vpn', 'tailscale'),
        None)
    launched_vpn_config = handle.vpn_config
    use_or_not_mismatch_str = (
        '{} VPN, but current config requires the opposite')
    if launched_vpn_config is None and now_vpn_config is None:
        return None
    if launched_vpn_config is None:
        # Cluster launched without VPN; config now asks for one.
        return use_or_not_mismatch_str.format('without')
    if now_vpn_config is None:
        # Cluster launched with VPN; config no longer asks for one.
        return use_or_not_mismatch_str.format('with')
    if now_vpn_config == launched_vpn_config:
        return None
    # Both configured, but the contents differ.
    return (f'with VPN config\n{pprint.pformat(handle.vpn_config)}, '
            f'but current config is\n{pprint.pformat(now_vpn_config)}')

mismatch_str = (f'To fix: specify a new cluster name, or down the '
f'existing cluster first: sky down {cluster_name}')
valid_resource = None
requested_resource_list = []
resource_failure_reason: Dict[resources_lib.Resources, str] = {}
for resource in task.resources:
if (task.num_nodes <= handle.launched_nodes and
resource.less_demanding_than(
launched_resources,
requested_num_nodes=task.num_nodes,
check_ports=check_ports)):
valid_resource = resource
break
reason = _check_vpn_unchanged(resource)
if reason is None:
valid_resource = resource
break
else:
# TODO(tian): Maybe refactor the following into this dict
resource_failure_reason[resource] = (
f'Cloud {launched_resources.cloud} VPN config '
f'mismatch. Cluster {handle.cluster_name} is '
f'launched {reason}. Please update the VPN '
'configuration in skypilot_config.')
else:
requested_resource_list.append(f'{task.num_nodes}x {resource}')

if valid_resource is None:
for example_resource in task.resources:
if example_resource in resource_failure_reason:
with ux_utils.print_exception_no_traceback():
raise exceptions.ResourcesMismatchError(
resource_failure_reason[example_resource])
if (example_resource.region is not None and
example_resource.region != launched_resources.region):
with ux_utils.print_exception_no_traceback():
Expand Down Expand Up @@ -2849,6 +2911,9 @@ def _get_zone(runner):
return handle

def _open_ports(self, handle: CloudVmRayResourceHandle) -> None:
if handle.vpn_config is not None:
# Skip opening any ports if VPN is used.
return
cloud = handle.launched_resources.cloud
logger.debug(
f'Opening ports {handle.launched_resources.ports} for {cloud}')
Expand Down Expand Up @@ -3963,6 +4028,8 @@ def post_teardown_cleanup(self,
f'Failed to delete cloned image {image_id}. Please '
'remove it manually to avoid image leakage. Details: '
f'{common_utils.format_exception(e, use_bracket=True)}')
# We don't need to explicitly skip cleanup ports if VPN is used, as it
# will use the default security group and automatically skip it.
if terminate:
cloud = handle.launched_resources.cloud
config = common_utils.read_yaml(handle.cluster_yaml)
Expand Down
33 changes: 31 additions & 2 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,13 @@ def make_deploy_resources_variables(self,
image_id = self._get_image_id(image_id_to_use, region_name,
r.instance_type)

tailscale_config = skypilot_config.get_nested(
('aws', 'vpn', 'tailscale'), None)

user_security_group = skypilot_config.get_nested(
('aws', 'security_group_name'), None)
if resources.ports is not None:
# Only open ports if VPN is not enabled.
if resources.ports is not None and tailscale_config is None:
# Already checked in Resources._try_validate_ports
assert user_security_group is None
security_group = USER_PORTS_SECURITY_GROUP_NAME.format(
Expand All @@ -411,7 +415,7 @@ def make_deploy_resources_variables(self,
else:
security_group = DEFAULT_SECURITY_GROUP_NAME

return {
resources_vars = {
'instance_type': r.instance_type,
'custom_resources': custom_resources,
'use_spot': r.use_spot,
Expand All @@ -423,6 +427,31 @@ def make_deploy_resources_variables(self,
str(security_group != user_security_group).lower(),
**AWS._get_disk_specs(r.disk_tier)
}
resources_vars['vpn_config'] = tailscale_config
if tailscale_config is not None:
unique_id = cluster_name_on_cloud
resources_vars['vpn_unique_id'] = unique_id
resources_vars['vpn_cloud_init_commands'] = [
[
'sh', '-c',
'curl -fsSL https://tailscale.com/install.sh | sh'
],
[
'sh', '-c',
('echo \'net.ipv4.ip_forward = 1\' | '
'sudo tee -a /etc/sysctl.d/99-tailscale.conf && '
'echo \'net.ipv6.conf.all.forwarding = 1\' | '
'sudo tee -a /etc/sysctl.d/99-tailscale.conf && '
'sudo sysctl -p /etc/sysctl.d/99-tailscale.conf')
],
[
'tailscale', 'up',
f'--authkey={tailscale_config["auth_key"]}',
f'--hostname={unique_id}'
],
]

return resources_vars

def _get_feasible_launchable_resources(
self, resources: 'resources_lib.Resources'
Expand Down
64 changes: 62 additions & 2 deletions sky/provision/aws/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import time
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar

import requests

from sky import sky_logging
from sky import status_lib
from sky.adaptors import aws
Expand All @@ -19,6 +21,7 @@
from sky.provision.aws import utils
from sky.utils import common_utils
from sky.utils import resources_utils
from sky.utils import subprocess_utils
from sky.utils import ux_utils

logger = sky_logging.init_logger(__name__)
Expand Down Expand Up @@ -610,6 +613,40 @@ def terminate_instances(
included_instances=None,
excluded_instances=None)
instances.terminate()
# Cleanup VPN record.
vpn_config = provider_config.get('vpn_config', None)
if vpn_config is not None:
auth_headers = {'Authorization': f'Bearer {vpn_config["api_key"]}'}

def _get_node_id_from_hostname(network_name: str,
                               hostname: str) -> Optional[str]:
    """Returns the Tailscale node id for `hostname` in `network_name`.

    Queries the Tailscale API for all devices in the tailnet and
    matches on hostname. Returns None if the device is not found or
    the API call fails, letting the caller skip VPN record cleanup
    instead of aborting instance termination.
    """
    # TODO(tian): Refactor to a dedicated file for all
    # VPN related functions and constants.
    url_to_query = ('https://api.tailscale.com/api/v2/'
                    f'tailnet/{network_name}/devices')
    try:
        # Bounded timeout: an unreachable API must not hang teardown.
        resp = requests.get(url_to_query, headers=auth_headers,
                            timeout=30)
    except requests.RequestException as e:
        logger.warning('Failed to query Tailscale devices for '
                       f'tailnet {network_name}: '
                       f'{common_utils.format_exception(e)}')
        return None
    if resp.status_code != 200:
        # E.g. invalid/expired api key. Returning None (instead of
        # letting resp.json() raise) falls through to the caller's
        # "skip deleting vpn record" path.
        logger.warning('Tailscale device query for tailnet '
                       f'{network_name} returned status '
                       f'{resp.status_code}: {resp.text}')
        return None
    all_devices_in_network = resp.json().get('devices', [])
    for device_info in all_devices_in_network:
        if device_info.get('hostname') == hostname:
            return device_info.get('nodeId')
    return None

node_id_in_vpn = _get_node_id_from_hostname(
vpn_config['tailnet'], provider_config['vpn_unique_id'])
if node_id_in_vpn is None:
logger.warning('Cannot find node id for '
f'{provider_config["vpn_unique_id"]}. '
f'Skip deleting vpn record.')
else:
url_to_delete = ('https://api.tailscale.com/api/v2/'
f'device/{node_id_in_vpn}')
resp = requests.delete(url_to_delete, headers=auth_headers)
if resp.status_code != 200:
logger.warning('Failed to delete vpn record for '
f'{provider_config["vpn_unique_id"]}. '
f'Status code: {resp.status_code}, '
f'Response: {resp.text}')
if (sg_name == aws_cloud.DEFAULT_SECURITY_GROUP_NAME or
not managed_by_skypilot):
# Using default AWS SG or user specified security group. We don't need
Expand Down Expand Up @@ -843,7 +880,7 @@ def get_cluster_info(
cluster_name_on_cloud: str,
provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo:
"""See sky/provision/__init__.py"""
del provider_config # unused
assert provider_config is not None
ec2 = _default_ec2_resource(region)
filters = [
{
Expand All @@ -863,10 +900,33 @@ def get_cluster_info(
tags = [(t['Key'], t['Value']) for t in inst.tags]
# sort tags by key to support deterministic unit test stubbing
tags.sort(key=lambda x: x[0])
vpn_unique_id = provider_config.get('vpn_unique_id', None)
if vpn_unique_id is None:
private_ip = inst.private_ip_address
else:
# TODO(tian): Using cluster name as hostname is problematic for
# multi-node cluster. Should use f'{unique_id}-{node_id}'
# TODO(tian): max_retry=1000 ==> infinite retry.
# TODO(tian): Check cloud status and set a timeout after the
# instance is ready on the cloud.
query_cmd = f'tailscale ip -4 {vpn_unique_id}'
rc, stdout, stderr = subprocess_utils.run_with_retries(
query_cmd,
max_retry=1000,
retry_wait_time=5,
retry_stderrs=['no such host', 'server misbehaving'])
subprocess_utils.handle_returncode(
rc,
query_cmd,
error_msg=('Failed to query Private IP in VPN '
f'for cluster {cluster_name_on_cloud} '
f'with unique id {vpn_unique_id}'),
stderr=stdout + stderr)
private_ip = stdout.strip()
instances[inst.id] = [
common.InstanceInfo(
instance_id=inst.id,
internal_ip=inst.private_ip_address,
internal_ip=private_ip,
external_ip=inst.public_ip_address,
tags=dict(tags),
)
Expand Down
18 changes: 13 additions & 5 deletions sky/provision/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,18 @@ def _post_provision_setup(
custom_resource: Optional[str]) -> provision_common.ClusterInfo:
config_from_yaml = common_utils.read_yaml(cluster_yaml)
provider_config = config_from_yaml.get('provider')
cluster_info = provision.get_cluster_info(cloud_name,
provision_record.region,
cluster_name.name_on_cloud,
provider_config=provider_config)
if (provider_config is not None and
provider_config.get('vpn_config', None) is not None):
get_info_status = rich_utils.safe_status(
'[bold cyan]Launching - Waiting for VPN setup[/]')
else:
get_info_status = rich_utils.empty_status()
with get_info_status:
cluster_info = provision.get_cluster_info(
cloud_name,
provision_record.region,
cluster_name.name_on_cloud,
provider_config=provider_config)

if cluster_info.num_instances > 1:
# Only worker nodes have logs in the per-instance log directory. Head
Expand Down Expand Up @@ -456,7 +464,7 @@ def _post_provision_setup(
logger.debug(
f'\nWaiting for SSH to be available for {cluster_name!r} ...')
wait_for_ssh(cluster_info, ssh_credentials)
logger.debug(f'SSH Conection ready for {cluster_name!r}')
logger.debug(f'SSH Connection ready for {cluster_name!r}')
plural = '' if len(cluster_info.instances) == 1 else 's'
logger.info(f'{colorama.Fore.GREEN}Successfully provisioned '
f'or found existing instance{plural}.'
Expand Down
Loading
Loading