Skip to content

Commit

Permalink
feat: allow security group specification
Browse files Browse the repository at this point in the history
  • Loading branch information
JGSweets committed May 7, 2024
1 parent 904aa5c commit ac86ed7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
18 changes: 12 additions & 6 deletions sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,15 +400,21 @@ def make_deploy_resources_variables(self,

user_security_group = skypilot_config.get_nested(
('aws', 'security_group_name'), None)
if resources.ports is not None:
if user_security_group is not None and not isinstance(
user_security_group, str):
for sg_name in user_security_group:
if cluster_name_on_cloud.startswith(
sg_name) and sg_name != 'default':
user_security_group = user_security_group[sg_name]
break
elif sg_name == 'default':
user_security_group = user_security_group[sg_name]
security_group = user_security_group
if user_security_group is None and resources.ports is not None:
# Already checked in Resources._try_validate_ports
assert user_security_group is None
security_group = USER_PORTS_SECURITY_GROUP_NAME.format(
cluster_name_on_cloud)
elif user_security_group is not None:
assert resources.ports is None
security_group = user_security_group
else:
elif user_security_group is None:
security_group = DEFAULT_SECURITY_GROUP_NAME

return {
Expand Down
6 changes: 0 additions & 6 deletions sky/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,12 +920,6 @@ def _try_validate_ports(self) -> None:
"""
if self.ports is None:
return
if skypilot_config.get_nested(('aws', 'security_group_name'),
None) is not None:
with ux_utils.print_exception_no_traceback():
raise ValueError(
'Cannot specify ports when AWS security group name is '
'specified.')
if self.cloud is not None:
self.cloud.check_features_are_supported(
self, {clouds.CloudImplementationFeatures.OPEN_PORTS})
Expand Down
20 changes: 18 additions & 2 deletions sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def _get_single_resources_schema():
'type': 'integer',
}]
}
}, {
'type': 'null',
}],
},
'labels': {
Expand Down Expand Up @@ -567,7 +569,7 @@ def get_config_schema():
# Validation may fail if $schema is included.
if k != '$schema'
}
resources_schema['properties'].pop('ports')
resources_schema['properties'].pop('port', None)
controller_resources_schema = {
'type': 'object',
'required': [],
Expand All @@ -590,7 +592,21 @@ def get_config_schema():
'additionalProperties': False,
'properties': {
'security_group_name': {
'type': 'string'
'oneOf': [{
'type': 'string'
}, {
'type': 'object',
'additionalProperties': False,
'required': ['default'],
'properties': {
'sky-serve-controller': {
'type': 'string',
},
'default': {
'type': 'string'
}
}
}]
},
**_LABELS_SCHEMA,
**_NETWORK_CONFIG_SCHEMA,
Expand Down

0 comments on commit ac86ed7

Please sign in to comment.