feat: Pulling in dependencies (in_process mode) using conda environment #4807

Merged on Aug 7, 2024 · 53 commits · the diff shown below is from 45 of the 53 commits.

Commits
2cc906b  InferenceSpec support for HF (Jun 26, 2024)
b25295a  Merge branch 'aws:master' into hf-inf-spec-support (bryannahm1, Jun 27, 2024)
fb28458  feat: InferenceSpec support for MMS and testing (Jun 27, 2024)
3576ea9  Introduce changes for InProcess Mode (Jun 29, 2024)
d3b8e9b  mb_inprocess updates (Jul 3, 2024)
68cede1  In_Process mode for TGI transformers, edits (Jul 8, 2024)
02e54ef  Remove InfSpec from branch (Jul 8, 2024)
f39cca6  merge from master for inf spec (Jul 12, 2024)
cc0ca14  changes to support in_process (Jul 13, 2024)
18fc3f2  changes to get pre-checks passing (Jul 15, 2024)
495c7b4  pylint fix (Jul 15, 2024)
1121f47  unit test, test mb (Jul 15, 2024)
b6062a7  period missing, added (Jul 15, 2024)
1ec209c  suggestions and test added (Jul 16, 2024)
ca6c818  pre-push fix (Jul 16, 2024)
cd3dbaa  missing an @ (Jul 16, 2024)
f52f36c  fixes to test, added stubbing (Jul 17, 2024)
1843210  removing for fixes (Jul 17, 2024)
d0fe3ac  variable fixes (Jul 17, 2024)
1b93244  init fix (Jul 17, 2024)
b40f36c  tests for in process mode (Jul 18, 2024)
68000e1  prepush fix (Jul 18, 2024)
e53c47f  Merge branch 'mb_in_process' into dependencies-conda (Jul 23, 2024)
2edec4f  deps and mb (Jul 23, 2024)
54480df  changes (Jul 23, 2024)
28c581e  fixing pkl (Jul 24, 2024)
2ac83eb  testing (Jul 24, 2024)
f2e1c4b  save pkl debug (Jul 24, 2024)
3b1ddfd  changes (Jul 24, 2024)
00794d0  conda create (Jul 24, 2024)
d842049  Conda fixes (Jul 24, 2024)
84b9d2d  random dep (Jul 24, 2024)
c7c81de  subproces (Jul 24, 2024)
05f118a  requirementsmanager.py script (Jul 24, 2024)
298d1e1  requires manag (Jul 25, 2024)
664294c  changing command (Jul 26, 2024)
3f16cac  changing command (Jul 26, 2024)
e826972  print (Jul 26, 2024)
2d5ce13  shell=true (Jul 26, 2024)
8419175  minor fix (Jul 26, 2024)
2affabb  changes (Jul 26, 2024)
d158f95  check=true (Jul 26, 2024)
d97c261  unit test (Jul 28, 2024)
1a92621  testing (Jul 31, 2024)
0528cdc  unit test for requirementsmanager (Aug 1, 2024)
9129ee9  removing in_process and minor edits (Aug 1, 2024)
91a4a9f  format (Aug 1, 2024)
1162751  .txt file (Aug 7, 2024)
9449660  renaming functions (Aug 7, 2024)
b22492d  fix path (Aug 7, 2024)
fa044d5  making .txt evaluate to true (Aug 7, 2024)
f593986  Merge branch 'master' into dependencies-conda (sage-maker, Aug 7, 2024)
ec592b4  Merge branch 'master' into dependencies-conda (sage-maker, Aug 7, 2024)
31 changes: 25 additions & 6 deletions src/sagemaker/serve/builder/model_builder.py
@@ -36,6 +36,7 @@
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
+from sagemaker.serve.mode.in_process_mode import InProcessMode
from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
from sagemaker.serve.builder.serve_settings import _ServeSettings
from sagemaker.serve.builder.djl_builder import DJL
@@ -410,7 +411,7 @@ def _prepare_for_mode(
            )
            self.env_vars.update(env_vars_sagemaker)
            return self.s3_upload_path, env_vars_sagemaker
-        if self.mode == Mode.LOCAL_CONTAINER:
+        elif self.mode == Mode.LOCAL_CONTAINER:
            # init the LocalContainerMode object
            self.modes[str(Mode.LOCAL_CONTAINER)] = LocalContainerMode(
                inference_spec=self.inference_spec,
@@ -422,9 +423,22 @@
            )
            self.modes[str(Mode.LOCAL_CONTAINER)].prepare()
            return None
+        elif self.mode == Mode.IN_PROCESS:
+            # init the InProcessMode object
+            self.modes[str(Mode.IN_PROCESS)] = InProcessMode(
+                inference_spec=self.inference_spec,
+                schema_builder=self.schema_builder,
+                session=self.sagemaker_session,
+                model_path=self.model_path,
+                env_vars=self.env_vars,
+                model_server=self.model_server,
+            )
+            self.modes[str(Mode.IN_PROCESS)].prepare()
+            return None

        raise ValueError(
-            "Please specify mode in: %s, %s" % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT)
+            "Please specify mode in: %s, %s, %s"
+            % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS)
        )

    def _get_client_translators(self):
@@ -603,10 +617,12 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str):
            s3_upload_path, env_vars_sagemaker = self._prepare_for_mode()
            self.pysdk_model.model_data = s3_upload_path
            self.pysdk_model.env.update(env_vars_sagemaker)

        elif overwrite_mode == Mode.LOCAL_CONTAINER:
            self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            self._prepare_for_mode()
+        elif overwrite_mode == Mode.IN_PROCESS:
+            self.mode = self.pysdk_model.mode = Mode.IN_PROCESS
+            self._prepare_for_mode()
        else:
            raise ValueError("Mode %s is not supported!" % overwrite_mode)

@@ -796,9 +812,10 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
        self.dependencies.update({"requirements": mlflow_model_dependency_path})

    # Model Builder is a class to build the model for deployment.
-    # It supports two modes of deployment
+    # It supports two* modes of deployment
    # 1/ SageMaker Endpoint
    # 2/ Local launch with container
+    # 3/ In process mode with Transformers server in beta release
    def build(  # pylint: disable=R0911
        self,
        mode: Type[Mode] = None,
@@ -896,8 +913,10 @@ def build(  # pylint: disable=R0911

    def _build_validations(self):
        """Validations needed for model server overrides, or auto-detection or fallback"""
-        if self.mode == Mode.IN_PROCESS:
-            raise ValueError("IN_PROCESS mode is not supported yet!")
+        if self.mode == Mode.IN_PROCESS and self.model_server is not ModelServer.MMS:
+            raise ValueError(
+                "IN_PROCESS mode is only supported for MMS/Transformers server in beta release."
+            )

        if self.inference_spec and self.model:
            raise ValueError("Can only set one of the following: model, inference_spec.")
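For reviewers who want to exercise the new branch end to end, a minimal usage sketch (hypothetical model id and sample payloads; IN_PROCESS is beta and, per _build_validations above, only valid with the MMS/Transformers server):

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode

# Sample input/output only shape the (de)serializers; the values are made up.
schema = SchemaBuilder(
    sample_input="What is SageMaker?",
    sample_output="SageMaker is a managed ML service.",
)

model_builder = ModelBuilder(
    mode=Mode.IN_PROCESS,
    model="bert-base-uncased",  # hypothetical Hugging Face model id
    schema_builder=schema,
)
model = model_builder.build()  # routes through _prepare_for_mode() -> InProcessMode.prepare()
predictor = model.deploy()     # served in process by the MMS/Transformers backend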
98 changes: 98 additions & 0 deletions src/sagemaker/serve/builder/requirements_manager.py
@@ -0,0 +1,98 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Requirements Manager class to pull in client dependencies from a .txt or .yml file"""
from __future__ import absolute_import
import logging
import os
import subprocess

logger = logging.getLogger(__name__)


class RequirementsManager:
    """Transformers build logic with ModelBuilder()"""

> Collaborator (on the class docstring): Update this comment.
> Author: Updated.

    def detect_file_exists(self, dependencies: str = None) -> str:
        """Creates snapshot of the user's environment

        If a req.txt or conda.yml file is provided, it verifies their existence and
        returns the local file path

        Args:
            dependencies (str): Local path where dependencies file exists.

        Returns:
            file path of the existing or generated dependencies file
        """
        dependencies = self._capture_from_local_runtime()

        # Dependencies specified as either req.txt or conda_env.yml
        if dependencies.endswith(".txt"):
            self._install_requirements_txt()
        elif dependencies.endswith(".yml"):
            self._update_conda_env_in_path()
        else:
            raise ValueError(f'Invalid dependencies provided: "{dependencies}"')

> Collaborator (on the docstring): Nit: not sure this is entirely right? We are detecting here, not snapshotting anything.
> Author: Yes, that's right, good catch. I will rephrase.

> Collaborator (on the .txt branch): _detect_conda_env_and_local_dependencies() will either return conda_in_process.yml or raise ValueError("No conda environment seems to be active."). In other words, dependencies.endswith(".txt") will never evaluate to true.
> Author: This was a good find. Thanks for your help with modifying it!

> Collaborator (on self._install_requirements_txt()): I think you're doing multiple things in one function; it's generally recommended to keep functions single-purpose. Also, if you want to install while checking path existence, you probably want to change the naming as well.
> Author: My naming was making the functions confusing to understand. I will rename them so they better fit their purpose. Thank you for pointing this out.

    def _install_requirements_txt(self):
        """Install requirements.txt file using pip"""
        logger.info("Running command to pip install")
        subprocess.run("pip install -r requirements.txt", shell=True, check=True)
        logger.info("Command ran successfully")

    def _update_conda_env_in_path(self):
        """Update conda env using conda yml file"""
        logger.info("Updating conda env")
        subprocess.run("conda env update -f conda_in_process.yml", shell=True, check=True)
        logger.info("Conda env updated successfully")

    def _get_active_conda_env_name(self) -> str:
        """Returns the conda environment name from the set environment variable. None otherwise."""
        return os.getenv("CONDA_DEFAULT_ENV")

    def _get_active_conda_env_prefix(self) -> str:
        """Returns the conda prefix from the set environment variable. None otherwise."""
        return os.getenv("CONDA_PREFIX")

    def _capture_from_local_runtime(self) -> str:
        """Generates dependencies list from the user's local runtime.

        Raises RuntimeEnvironmentError if not able to.

        Currently supports: conda environments
        """
        # Try to capture dependencies from the conda environment, if any.
        conda_env_name = self._get_active_conda_env_name()
        logger.info("Found conda_env_name: '%s'", conda_env_name)
        conda_env_prefix = None

        if conda_env_name is None:
            conda_env_prefix = self._get_active_conda_env_prefix()

        if conda_env_name is None and conda_env_prefix is None:
            raise ValueError("No conda environment seems to be active.")

        if conda_env_name == "base":
            logger.warning(
                "We recommend using an environment other than base to "
                "isolate your project dependencies from conda dependencies"
            )

        local_dependencies_path = os.path.join(os.getcwd(), "inf_env_snapshot.yml")

        return local_dependencies_path

> Collaborator (on _capture_from_local_runtime): Maybe I missed something here, but isn't this function just generating an empty file? Can you point me to where dependencies are actually captured?
> Author: I will be making a new .txt file, similar to the .yml one, with the required dependencies. Good catch, thank you.


if __name__ == "__main__":
    RequirementsManager().detect_file_exists(dependencies="auto_capture")

> Collaborator (on dependencies="auto_capture"): How is auto_capture as a value used?
> Author: It ended up being unused, so I have removed it.

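To make the control flow concrete, a short sketch of how this class behaves at this revision, assuming an active conda environment (CONDA_DEFAULT_ENV or CONDA_PREFIX set); file names match the diff above:

from sagemaker.serve.builder.requirements_manager import RequirementsManager

# With a conda env active, _capture_from_local_runtime() returns a path ending
# in inf_env_snapshot.yml, so detect_file_exists() takes the .yml branch and runs:
#     conda env update -f conda_in_process.yml
RequirementsManager().detect_file_exists()

# With no conda env active (neither CONDA_DEFAULT_ENV nor CONDA_PREFIX set),
# the same call raises: ValueError("No conda environment seems to be active.")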
38 changes: 36 additions & 2 deletions src/sagemaker/serve/builder/transformers_builder.py
@@ -17,6 +17,7 @@
from abc import ABC, abstractmethod
from typing import Type
from pathlib import Path
+import subprocess
from packaging.version import Version

from sagemaker.model import Model
@@ -35,15 +36,21 @@
)
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.utils.optimize_utils import _is_optimized
-from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
+from sagemaker.serve.utils.predictors import (
+    TransformersLocalModePredictor,
+    TransformersInProcessModePredictor,
+)
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
from sagemaker.base_predictor import PredictorBase
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
+from sagemaker.serve.builder.requirements_manager import RequirementsManager


logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 1800
+LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS]


"""Retrieves images for different libraries - Pytorch, TensorFlow from HuggingFace hub
@@ -227,6 +234,22 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
            )
            return predictor

+        if self.mode == Mode.IN_PROCESS:
+            timeout = kwargs.get("model_data_download_timeout")
+
+            predictor = TransformersInProcessModePredictor(
+                self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
+            )
+
+            self.modes[str(Mode.IN_PROCESS)].create_server(
+                self.image_uri,
+                timeout if timeout else DEFAULT_TIMEOUT,
+                None,
+                predictor,
+                self.pysdk_model.env,
+            )
+            return predictor

> Collaborator (on the IN_PROCESS branch): Please clean this up.
> Author: It will be removed.

if "mode" in kwargs:
del kwargs["mode"]
if "role" in kwargs:
Expand Down Expand Up @@ -274,7 +297,7 @@ def _build_transformers_env(self):

        self.pysdk_model = self._create_transformers_model()

-        if self.mode == Mode.LOCAL_CONTAINER:
+        if self.mode in LOCAL_MODES:
            self._prepare_for_mode()

        return self.pysdk_model
@@ -358,6 +381,9 @@ def _build_for_transformers(self):
        save_pkl(code_path, (self.inference_spec, self.schema_builder))
        logger.info("PKL file saved to file: %s", code_path)

+        if self.mode == Mode.IN_PROCESS:
+            self._create_conda_env()
+
        self._auto_detect_container()

        self.secret_key = prepare_for_mms(
@@ -376,3 +402,11 @@
        if self.sagemaker_session:
            self.pysdk_model.sagemaker_session = self.sagemaker_session
        return self.pysdk_model
+
+    def _create_conda_env(self):
+        """Creating conda environment by running commands"""
+
+        try:
+            RequirementsManager().detect_file_exists(self)
+        except subprocess.CalledProcessError:
+            print("Failed to create and activate conda environment.")
89 changes: 89 additions & 0 deletions src/sagemaker/serve/mode/in_process_mode.py
@@ -0,0 +1,89 @@
"""Module that defines the InProcessMode class"""

from __future__ import absolute_import
from pathlib import Path
import logging
from typing import Dict, Type
import time
from datetime import datetime, timedelta

from sagemaker.base_predictor import PredictorBase
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.utils.exceptions import LocalDeepPingException
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
from sagemaker.session import Session

logger = logging.getLogger(__name__)

_PING_HEALTH_CHECK_FAIL_MSG = (
    "Ping health check did not pass. "
    + "Please increase container_timeout_seconds or review your inference code."
)


class InProcessMode(
    InProcessMultiModelServer,
):
    """A class that holds methods to deploy model to a container in process environment"""

    def __init__(
        self,
        model_server: ModelServer,
        inference_spec: Type[InferenceSpec],
        schema_builder: Type[SchemaBuilder],
        session: Session,
        model_path: str = None,
        env_vars: Dict = None,
    ):
        # pylint: disable=bad-super-call
        super().__init__()

        self.inference_spec = inference_spec
        self.model_path = model_path
        self.env_vars = env_vars
        self.session = session
        self.schema_builder = schema_builder
        self.model_server = model_server
        self._ping_container = None

    def load(self, model_path: str = None):
        """Loads model path, checks that path exists"""
        path = Path(model_path if model_path else self.model_path)
        if not path.exists():
            raise ValueError("model_path does not exist")
        if not path.is_dir():
            raise ValueError("model_path is not a valid directory")

        return self.inference_spec.load(str(path))

    def prepare(self):
        """Prepares the server"""

    def create_server(
        self,
        predictor: PredictorBase,
    ):
        """Creating the server and checking ping health."""
        logger.info("Waiting for model server %s to start up...", self.model_server)

        if self.model_server == ModelServer.MMS:
            self._ping_container = self._multi_model_server_deep_ping

        time_limit = datetime.now() + timedelta(seconds=5)
        while self._ping_container is not None:
            final_pull = datetime.now() > time_limit

            if final_pull:
                break

            time.sleep(10)

            healthy, response = self._ping_container(predictor)
            if healthy:
                logger.debug("Ping health check has passed. Returned %s", str(response))
                break

        if not healthy:
            raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
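A sketch of how this class is driven (hypothetical spec/schema stand-ins; in the real flow ModelBuilder._prepare_for_mode() constructs InProcessMode and the Transformers deploy wrapper starts the server). Two things reviewers may want to double-check: the loop sleeps 10 seconds against a 5-second time_limit, so it makes at most one ping attempt, and the deploy wrapper above calls create_server with five arguments while the signature here takes only a predictor.

from sagemaker.serve.mode.in_process_mode import InProcessMode
from sagemaker.serve.utils.types import ModelServer

mode = InProcessMode(
    model_server=ModelServer.MMS,
    inference_spec=my_inference_spec,  # hypothetical InferenceSpec instance
    schema_builder=my_schema_builder,  # hypothetical SchemaBuilder instance
    session=None,
    model_path="./model",              # load() requires an existing directory
    env_vars={},
)
mode.prepare()                 # a no-op at this revision
model = mode.load()            # validates model_path, delegates to inference_spec.load()
mode.create_server(predictor)  # predictor: a PredictorBase, e.g. TransformersInProcessModePredictor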