fix: Model server override logic (#4733)
* fix: Model server override logic

* Fix formatting

---------

Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com>
samruds and benieric authored Jun 19, 2024
1 parent 0984e8d commit 4496072
Showing 2 changed files with 177 additions and 32 deletions.
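In effect, build() now runs _build_validations() up front and honors an explicit model_server before any auto-detection, with TorchServe remaining the final fallback. A minimal sketch of the new routing order (illustrative names only; this restates the diff below, it is not the SDK's actual code):

def route(model_server, model):
    """Hypothetical sketch of the new build() dispatch order."""
    if model_server is not None:
        # Explicit override wins: dispatch straight to the matching builder.
        return "dispatch via _build_for_model_server()"
    if isinstance(model, str):
        # No override and a string model id: use the auto-detection paths.
        return "auto-detect a model server from the model id"
    # Final fallback: TorchServe is the default model server.
    return "dispatch via _build_for_torchserve()"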
84 changes: 56 additions & 28 deletions src/sagemaker/serve/builder/model_builder.py
@@ -94,11 +94,15 @@

logger = logging.getLogger(__name__)

supported_model_server = {
# Any new server type should be added here
supported_model_servers = {
ModelServer.TORCHSERVE,
ModelServer.TRITON,
ModelServer.DJL_SERVING,
ModelServer.TENSORFLOW_SERVING,
ModelServer.MMS,
ModelServer.TGI,
ModelServer.TEI,
}


@@ -288,31 +292,6 @@ class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing,
},
)

def _build_validations(self):
"""Placeholder docstring"""
# TODO: Beta validations - remove after the launch
if self.mode == Mode.IN_PROCESS:
raise ValueError("IN_PROCESS mode is not supported yet!")

if self.inference_spec and self.model:
raise ValueError("Cannot have both the Model and Inference spec in the builder")

if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
raise ValueError(
"Model_server must be set when non-first-party image_uri is set. "
+ "Supported model servers: %s" % supported_model_server
)

# Set TorchServe as default model server
if not self.model_server:
self.model_server = ModelServer.TORCHSERVE

if self.model_server not in supported_model_server:
raise ValueError(
"%s is not supported yet! Supported model servers: %s"
% (self.model_server, supported_model_server)
)

def _save_model_inference_spec(self):
"""Placeholder docstring"""
# check if path exists and create if not
@@ -839,6 +818,11 @@ def build( # pylint: disable=R0911

self._handle_mlflow_input()

self._build_validations()

if self.model_server:
return self._build_for_model_server()

if isinstance(self.model, str):
model_task = None
if self.model_metadata:
@@ -870,7 +854,41 @@ def build( # pylint: disable=R0911
else:
return self._build_for_transformers()

self._build_validations()
# Set TorchServe as default model server
if not self.model_server:
self.model_server = ModelServer.TORCHSERVE
return self._build_for_torchserve()

raise ValueError("%s model server is not supported" % self.model_server)

def _build_validations(self):
"""Validations needed for model server overrides, or auto-detection or fallback"""
if self.mode == Mode.IN_PROCESS:
raise ValueError("IN_PROCESS mode is not supported yet!")

if self.inference_spec and self.model:
raise ValueError("Can only set one of the following: model, inference_spec.")

if self.image_uri and not is_1p_image_uri(self.image_uri) and self.model_server is None:
raise ValueError(
"Model_server must be set when non-first-party image_uri is set. "
+ "Supported model servers: %s" % supported_model_servers
)

def _build_for_model_server(self): # pylint: disable=R0911, R1710
"""Model server overrides"""
if self.model_server not in supported_model_servers:
raise ValueError(
"%s is not supported yet! Supported model servers: %s"
% (self.model_server, supported_model_servers)
)

mlflow_path = None
if self.model_metadata:
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)

if not self.model and not mlflow_path:
raise ValueError("Missing required parameter `model` or 'ml_flow' path")

if self.model_server == ModelServer.TORCHSERVE:
return self._build_for_torchserve()
@@ -881,7 +899,17 @@ def build( # pylint: disable=R0911
if self.model_server == ModelServer.TENSORFLOW_SERVING:
return self._build_for_tensorflow_serving()

raise ValueError("%s model server is not supported" % self.model_server)
if self.model_server == ModelServer.DJL_SERVING:
return self._build_for_djl()

if self.model_server == ModelServer.TEI:
return self._build_for_tei()

if self.model_server == ModelServer.TGI:
return self._build_for_tgi()

if self.model_server == ModelServer.MMS:
return self._build_for_transformers()

def save(
self,
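For reference, a minimal usage sketch of the new override path. The placeholder model id and mocked session mirror the unit tests below, where _ServeSettings and the _build_for_* methods are patched; the ModelServer import shown here is an assumption and may differ by SDK version:

from unittest.mock import MagicMock

# ModelServer is used inside model_builder.py, so it can be imported from
# there; the canonical import path may differ (assumption).
from sagemaker.serve.builder.model_builder import ModelBuilder, ModelServer

builder = ModelBuilder(
    model_server=ModelServer.DJL_SERVING,  # explicit override
    model="gpt_llm_burt",  # placeholder model id, as in the tests
)
# With the override set, build() dispatches directly to _build_for_djl()
# instead of auto-detecting a model server.
builder.build(sagemaker_session=MagicMock())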
125 changes: 121 additions & 4 deletions tests/unit/sagemaker/serve/builder/test_model_builder.py
@@ -50,11 +50,14 @@
mock_secret_key = "mock_secret_key"
mock_instance_type = "mock instance type"

supported_model_server = {
supported_model_servers = {
ModelServer.TORCHSERVE,
ModelServer.TRITON,
ModelServer.DJL_SERVING,
ModelServer.TENSORFLOW_SERVING,
ModelServer.MMS,
ModelServer.TGI,
ModelServer.TEI,
}

mock_session = MagicMock()
@@ -78,7 +81,7 @@ def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSet
builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object))
self.assertRaisesRegex(
Exception,
"Cannot have both the Model and Inference spec in the builder",
"Can only set one of the following: model, inference_spec.",
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
@@ -91,7 +94,7 @@ def test_validation_unsupported_model_server_type(self, mock_serveSettings):
self.assertRaisesRegex(
Exception,
"%s is not supported yet! Supported model servers: %s"
% (builder.model_server, supported_model_server),
% (builder.model_server, supported_model_servers),
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
@@ -104,7 +107,7 @@ def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings
self.assertRaisesRegex(
Exception,
"Model_server must be set when non-first-party image_uri is set. "
+ "Supported model servers: %s" % supported_model_server,
+ "Supported model servers: %s" % supported_model_servers,
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
@@ -125,6 +128,120 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
mock_session,
)

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl")
def test_model_server_override_djl_with_model(self, mock_build_for_djl, mock_serve_settings):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.DJL_SERVING, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_djl.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_settings):
builder = ModelBuilder(
model_server=ModelServer.DJL_SERVING, model=None, inference_spec=None
)
self.assertRaisesRegex(
Exception,
"Missing required parameter `model` or 'ml_flow' path",
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
mock_session,
)

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve")
def test_model_server_override_torchserve_with_model(
self, mock_build_for_ts, mock_serve_settings
):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings):
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
self.assertRaisesRegex(
Exception,
"Missing required parameter `model` or 'ml_flow' path",
builder.build,
Mode.SAGEMAKER_ENDPOINT,
mock_role_arn,
mock_session,
)

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_triton")
def test_model_server_override_triton_with_model(self, mock_build_for_ts, mock_serve_settings):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TRITON, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tensorflow_serving")
def test_model_server_override_tensor_with_model(self, mock_build_for_ts, mock_serve_settings):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TENSORFLOW_SERVING, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tei")
def test_model_server_override_tei_with_model(self, mock_build_for_ts, mock_serve_settings):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TEI, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_tgi")
def test_model_server_override_tgi_with_model(self, mock_build_for_ts, mock_serve_settings):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.TGI, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("sagemaker.serve.builder.model_builder._ServeSettings")
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
def test_model_server_override_transformers_with_model(
self, mock_build_for_ts, mock_serve_settings
):
mock_setting_object = mock_serve_settings.return_value
mock_setting_object.role_arn = mock_role_arn
mock_setting_object.s3_model_data_url = mock_s3_model_data_url

builder = ModelBuilder(model_server=ModelServer.MMS, model="gpt_llm_burt")
builder.build(sagemaker_session=mock_session)

mock_build_for_ts.assert_called_once()

@patch("os.makedirs", Mock())
@patch("sagemaker.serve.builder.model_builder._detect_framework_and_version")
@patch("sagemaker.serve.builder.model_builder.prepare_for_torchserve")
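As a final illustration, a sketch of the failure mode the new tests pin down: an explicit model_server with neither a model nor an MLflow path fails validation early. Same assumed import as above; in the unit tests, _ServeSettings is also patched:

from unittest.mock import MagicMock

from sagemaker.serve.builder.model_builder import ModelBuilder, ModelServer  # import path assumed

builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)  # no model, no MLflow path
try:
    builder.build(sagemaker_session=MagicMock())
except ValueError as err:
    print(err)  # Missing required parameter `model` or MLflow model path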
