-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Pulling in dependencies (in_process mode) using conda environment #4807
Changes from 45 commits
2cc906b
b25295a
fb28458
3576ea9
d3b8e9b
68cede1
02e54ef
f39cca6
cc0ca14
18fc3f2
495c7b4
1121f47
b6062a7
1ec209c
ca6c818
cd3dbaa
f52f36c
1843210
d0fe3ac
1b93244
b40f36c
68000e1
e53c47f
2edec4f
54480df
28c581e
2ac83eb
f2e1c4b
3b1ddfd
00794d0
d842049
84b9d2d
c7c81de
05f118a
298d1e1
664294c
3f16cac
e826972
2d5ce13
8419175
2affabb
d158f95
d97c261
1a92621
0528cdc
9129ee9
91a4a9f
1162751
9449660
b22492d
fa044d5
f593986
ec592b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Requirements Manager class to pull in client dependencies from a .txt or .yml file""" | ||
from __future__ import absolute_import | ||
import logging | ||
import os | ||
import subprocess | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class RequirementsManager: | ||
"""Transformers build logic with ModelBuilder()""" | ||
|
||
def detect_file_exists(self, dependencies: str = None) -> str: | ||
"""Creates snapshot of the user's environment | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Not sure if this is entirely right? we are detecting not snapshotting anything? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes that is right, good catch, I will rephrase. |
||
|
||
If a req.txt or conda.yml file is provided, it verifies their existence and | ||
returns the local file path | ||
|
||
Args: | ||
dependencies (str): Local path where dependencies file exists. | ||
|
||
Returns: | ||
file path of the existing or generated dependencies file | ||
""" | ||
dependencies = self._capture_from_local_runtime() | ||
|
||
# Dependencies specified as either req.txt or conda_env.yml | ||
if dependencies.endswith(".txt"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. _detect_conda_env_and_local_dependencies() will either return conda_in_process.yml or raise ValueError("No conda environment seems to be active.") , in other words, if dependencies.endswith(".txt") will never evaluate to true There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was a good find. Thanks for your help with modifying it! |
||
self._install_requirements_txt() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you're doing multiple things in one function. It's generally recommended to have functions as single tenant. Also, if you want to install while check path existence, you probably want to change the naming as well. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My naming was making the functions confusing to understand, I will rename so they better fit their purpose. Thank you for pointing this out. |
||
elif dependencies.endswith(".yml"): | ||
self._update_conda_env_in_path() | ||
else: | ||
raise ValueError(f'Invalid dependencies provided: "{dependencies}"') | ||
|
||
def _install_requirements_txt(self): | ||
"""Install requirements.txt file using pip""" | ||
logger.info("Running command to pip install") | ||
subprocess.run("pip install -r requirements.txt", shell=True, check=True) | ||
logger.info("Command ran successfully") | ||
|
||
def _update_conda_env_in_path(self): | ||
"""Update conda env using conda yml file""" | ||
logger.info("Updating conda env") | ||
subprocess.run("conda env update -f conda_in_process.yml", shell=True, check=True) | ||
logger.info("Conda env updated successfully") | ||
|
||
def _get_active_conda_env_name(self) -> str: | ||
"""Returns the conda environment name from the set environment variable. None otherwise.""" | ||
return os.getenv("CONDA_DEFAULT_ENV") | ||
|
||
def _get_active_conda_env_prefix(self) -> str: | ||
"""Returns the conda prefix from the set environment variable. None otherwise.""" | ||
return os.getenv("CONDA_PREFIX") | ||
|
||
def _capture_from_local_runtime(self) -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe I missed something here, but isn't this function just to generate an empty file? Can you point me to where dependencies are actually captured? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will be making a new .txt file similar to the .yml one with required dependencies, good catch, thank you. |
||
"""Generates dependencies list from the user's local runtime. | ||
|
||
Raises RuntimeEnvironmentError if not able to. | ||
|
||
Currently supports: conda environments | ||
""" | ||
|
||
# Try to capture dependencies from the conda environment, if any. | ||
conda_env_name = self._get_active_conda_env_name() | ||
logger.info("Found conda_env_name: '%s'", conda_env_name) | ||
conda_env_prefix = None | ||
|
||
if conda_env_name is None: | ||
conda_env_prefix = self._get_active_conda_env_prefix() | ||
|
||
if conda_env_name is None and conda_env_prefix is None: | ||
raise ValueError("No conda environment seems to be active.") | ||
|
||
if conda_env_name == "base": | ||
logger.warning( | ||
"We recommend using an environment other than base to " | ||
"isolate your project dependencies from conda dependencies" | ||
) | ||
|
||
local_dependencies_path = os.path.join(os.getcwd(), "inf_env_snapshot.yml") | ||
|
||
return local_dependencies_path | ||
|
||
|
||
if __name__ == "__main__": | ||
RequirementsManager().detect_file_exists(dependencies="auto_capture") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How is auto_capture as a value used? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It ended up being unused so I have removed it. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
from abc import ABC, abstractmethod | ||
from typing import Type | ||
from pathlib import Path | ||
import subprocess | ||
from packaging.version import Version | ||
|
||
from sagemaker.model import Model | ||
|
@@ -35,15 +36,21 @@ | |
) | ||
from sagemaker.serve.detector.pickler import save_pkl | ||
from sagemaker.serve.utils.optimize_utils import _is_optimized | ||
from sagemaker.serve.utils.predictors import TransformersLocalModePredictor | ||
from sagemaker.serve.utils.predictors import ( | ||
TransformersLocalModePredictor, | ||
TransformersInProcessModePredictor, | ||
) | ||
from sagemaker.serve.utils.types import ModelServer | ||
from sagemaker.serve.mode.function_pointers import Mode | ||
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry | ||
from sagemaker.base_predictor import PredictorBase | ||
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata | ||
from sagemaker.serve.builder.requirements_manager import RequirementsManager | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
DEFAULT_TIMEOUT = 1800 | ||
LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS] | ||
|
||
|
||
"""Retrieves images for different libraries - Pytorch, TensorFlow from HuggingFace hub | ||
|
@@ -227,6 +234,22 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr | |
) | ||
return predictor | ||
|
||
if self.mode == Mode.IN_PROCESS: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please clean this up There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will be removed. |
||
timeout = kwargs.get("model_data_download_timeout") | ||
|
||
predictor = TransformersInProcessModePredictor( | ||
self.modes[str(Mode.IN_PROCESS)], serializer, deserializer | ||
) | ||
|
||
self.modes[str(Mode.IN_PROCESS)].create_server( | ||
self.image_uri, | ||
timeout if timeout else DEFAULT_TIMEOUT, | ||
None, | ||
predictor, | ||
self.pysdk_model.env, | ||
) | ||
return predictor | ||
|
||
if "mode" in kwargs: | ||
del kwargs["mode"] | ||
if "role" in kwargs: | ||
|
@@ -274,7 +297,7 @@ def _build_transformers_env(self): | |
|
||
self.pysdk_model = self._create_transformers_model() | ||
|
||
if self.mode == Mode.LOCAL_CONTAINER: | ||
if self.mode in LOCAL_MODES: | ||
self._prepare_for_mode() | ||
|
||
return self.pysdk_model | ||
|
@@ -358,6 +381,9 @@ def _build_for_transformers(self): | |
save_pkl(code_path, (self.inference_spec, self.schema_builder)) | ||
logger.info("PKL file saved to file: %s", code_path) | ||
|
||
if self.mode == Mode.IN_PROCESS: | ||
self._create_conda_env() | ||
|
||
self._auto_detect_container() | ||
|
||
self.secret_key = prepare_for_mms( | ||
|
@@ -376,3 +402,11 @@ def _build_for_transformers(self): | |
if self.sagemaker_session: | ||
self.pysdk_model.sagemaker_session = self.sagemaker_session | ||
return self.pysdk_model | ||
|
||
def _create_conda_env(self): | ||
"""Creating conda environment by running commands""" | ||
|
||
try: | ||
RequirementsManager().detect_file_exists(self) | ||
except subprocess.CalledProcessError: | ||
print("Failed to create and activate conda environment.") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
"""Module that defines the InProcessMode class""" | ||
|
||
from __future__ import absolute_import | ||
from pathlib import Path | ||
import logging | ||
from typing import Dict, Type | ||
import time | ||
from datetime import datetime, timedelta | ||
|
||
from sagemaker.base_predictor import PredictorBase | ||
from sagemaker.serve.spec.inference_spec import InferenceSpec | ||
from sagemaker.serve.builder.schema_builder import SchemaBuilder | ||
from sagemaker.serve.utils.types import ModelServer | ||
from sagemaker.serve.utils.exceptions import LocalDeepPingException | ||
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer | ||
from sagemaker.session import Session | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
_PING_HEALTH_CHECK_FAIL_MSG = ( | ||
"Ping health check did not pass. " | ||
+ "Please increase container_timeout_seconds or review your inference code." | ||
) | ||
|
||
|
||
class InProcessMode( | ||
InProcessMultiModelServer, | ||
): | ||
"""A class that holds methods to deploy model to a container in process environment""" | ||
|
||
def __init__( | ||
self, | ||
model_server: ModelServer, | ||
inference_spec: Type[InferenceSpec], | ||
schema_builder: Type[SchemaBuilder], | ||
session: Session, | ||
model_path: str = None, | ||
env_vars: Dict = None, | ||
): | ||
# pylint: disable=bad-super-call | ||
super().__init__() | ||
|
||
self.inference_spec = inference_spec | ||
self.model_path = model_path | ||
self.env_vars = env_vars | ||
self.session = session | ||
self.schema_builder = schema_builder | ||
self.model_server = model_server | ||
self._ping_container = None | ||
|
||
def load(self, model_path: str = None): | ||
"""Loads model path, checks that path exists""" | ||
path = Path(model_path if model_path else self.model_path) | ||
if not path.exists(): | ||
raise ValueError("model_path does not exist") | ||
if not path.is_dir(): | ||
raise ValueError("model_path is not a valid directory") | ||
|
||
return self.inference_spec.load(str(path)) | ||
|
||
def prepare(self): | ||
"""Prepares the server""" | ||
|
||
def create_server( | ||
self, | ||
predictor: PredictorBase, | ||
): | ||
"""Creating the server and checking ping health.""" | ||
logger.info("Waiting for model server %s to start up...", self.model_server) | ||
|
||
if self.model_server == ModelServer.MMS: | ||
self._ping_container = self._multi_model_server_deep_ping | ||
|
||
time_limit = datetime.now() + timedelta(seconds=5) | ||
while self._ping_container is not None: | ||
final_pull = datetime.now() > time_limit | ||
|
||
if final_pull: | ||
break | ||
|
||
time.sleep(10) | ||
|
||
healthy, response = self._ping_container(predictor) | ||
if healthy: | ||
logger.debug("Ping health check has passed. Returned %s", str(response)) | ||
break | ||
|
||
if not healthy: | ||
raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update this comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.