Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Pulling in dependencies (in_process mode) using conda environment #4807

Merged
merged 53 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
2cc906b
InferenceSpec support for HF
Jun 26, 2024
b25295a
Merge branch 'aws:master' into hf-inf-spec-support
bryannahm1 Jun 27, 2024
fb28458
feat: InferenceSpec support for MMS and testing
Jun 27, 2024
3576ea9
Introduce changes for InProcess Mode
Jun 29, 2024
d3b8e9b
mb_inprocess updates
Jul 3, 2024
68cede1
In_Process mode for TGI transformers, edits
Jul 8, 2024
02e54ef
Remove InfSpec from branch
Jul 8, 2024
f39cca6
merge from master for inf spec
Jul 12, 2024
cc0ca14
changes to support in_process
Jul 13, 2024
18fc3f2
changes to get pre-checks passing
Jul 15, 2024
495c7b4
pylint fix
Jul 15, 2024
1121f47
unit test, test mb
Jul 15, 2024
b6062a7
period missing, added
Jul 15, 2024
1ec209c
suggestions and test added
Jul 16, 2024
ca6c818
pre-push fix
Jul 16, 2024
cd3dbaa
missing an @
Jul 16, 2024
f52f36c
fixes to test, added stubbing
Jul 17, 2024
1843210
removing for fixes
Jul 17, 2024
d0fe3ac
variable fixes
Jul 17, 2024
1b93244
init fix
Jul 17, 2024
b40f36c
tests for in process mode
Jul 18, 2024
68000e1
prepush fix
Jul 18, 2024
e53c47f
Merge branch 'mb_in_process' into dependencies-conda
Jul 23, 2024
2edec4f
deps and mb
Jul 23, 2024
54480df
changes
Jul 23, 2024
28c581e
fixing pkl
Jul 24, 2024
2ac83eb
testing
Jul 24, 2024
f2e1c4b
save pkl debug
Jul 24, 2024
3b1ddfd
changes
Jul 24, 2024
00794d0
conda create
Jul 24, 2024
d842049
Conda fixes
Jul 24, 2024
84b9d2d
random dep
Jul 24, 2024
c7c81de
subproces
Jul 24, 2024
05f118a
requirementsmanager.py script
Jul 24, 2024
298d1e1
requires manag
Jul 25, 2024
664294c
changing command
Jul 26, 2024
3f16cac
changing command
Jul 26, 2024
e826972
print
Jul 26, 2024
2d5ce13
shell=true
Jul 26, 2024
8419175
minor fix
Jul 26, 2024
2affabb
changes
Jul 26, 2024
d158f95
check=true
Jul 26, 2024
d97c261
unit test
Jul 28, 2024
1a92621
testing
Jul 31, 2024
0528cdc
unit test for requirementsmanager
Aug 1, 2024
9129ee9
removing in_process and minor edits
Aug 1, 2024
91a4a9f
format
Aug 1, 2024
1162751
.txt file
Aug 7, 2024
9449660
renaming functions
Aug 7, 2024
b22492d
fix path
Aug 7, 2024
fa044d5
making .txt evaluate to true
Aug 7, 2024
f593986
Merge branch 'master' into dependencies-conda
sage-maker Aug 7, 2024
ec592b4
Merge branch 'master' into dependencies-conda
sage-maker Aug 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,6 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str):
s3_upload_path, env_vars_sagemaker = self._prepare_for_mode()
self.pysdk_model.model_data = s3_upload_path
self.pysdk_model.env.update(env_vars_sagemaker)

elif overwrite_mode == Mode.LOCAL_CONTAINER:
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
self._prepare_for_mode()
Expand Down
100 changes: 100 additions & 0 deletions src/sagemaker/serve/builder/requirements_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Requirements Manager class to pull in client dependencies from a .txt or .yml file"""
from __future__ import absolute_import
import logging
import os
import subprocess

from typing import Optional

logger = logging.getLogger(__name__)


class RequirementsManager:
"""Manages dependency installation by detecting file types"""

def capture_and_install_dependencies(self, dependencies: Optional[str] = None) -> str:
"""Detects the type of file dependencies will be installed from

If a req.txt or conda.yml file is provided, it verifies their existence and
returns the local file path

Args:
dependencies (str): Local path where dependencies file exists.

Returns:
file path of the existing or generated dependencies file
"""
_dependencies = dependencies or self._detect_conda_env_and_local_dependencies()

# Dependencies specified as either req.txt or conda_env.yml
if _dependencies.endswith(".txt"):
self._install_requirements_txt()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're doing multiple things in one function. It's generally recommended to have functions as single tenant. Also, if you want to install while check path existence, you probably want to change the naming as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My naming was making the functions confusing to understand, I will rename so they better fit their purpose. Thank you for pointing this out.

elif _dependencies.endswith(".yml"):
self._update_conda_env_in_path()
else:
raise ValueError(f'Invalid dependencies provided: "{_dependencies}"')

def _install_requirements_txt(self):
"""Install requirements.txt file using pip"""
logger.info("Running command to pip install")
subprocess.run("pip install -r in_process_requirements.txt", shell=True, check=True)
logger.info("Command ran successfully")

def _update_conda_env_in_path(self):
"""Update conda env using conda yml file"""
logger.info("Updating conda env")
subprocess.run("conda env update -f conda_in_process.yml", shell=True, check=True)
logger.info("Conda env updated successfully")

def _get_active_conda_env_name(self) -> str:
"""Returns the conda environment name from the set environment variable. None otherwise."""
return os.getenv("CONDA_DEFAULT_ENV")

def _get_active_conda_env_prefix(self) -> str:
"""Returns the conda prefix from the set environment variable. None otherwise."""
return os.getenv("CONDA_PREFIX")

def _detect_conda_env_and_local_dependencies(self) -> str:
"""Generates dependencies list from the user's local runtime.

Raises RuntimeEnvironmentError if not able to.

Currently supports: conda environments
"""

# Try to capture dependencies from the conda environment, if any.
conda_env_name = self._get_active_conda_env_name()
logger.info("Found conda_env_name: '%s'", conda_env_name)
conda_env_prefix = None

if conda_env_name is None:
conda_env_prefix = self._get_active_conda_env_prefix()

if conda_env_name is None and conda_env_prefix is None:
local_dependencies_path = os.path.join(os.getcwd(), "in_process_requirements.txt")
logger.info(local_dependencies_path)

return local_dependencies_path

if conda_env_name == "base":
logger.warning(
"We recommend using an environment other than base to "
"isolate your project dependencies from conda dependencies"
)

local_dependencies_path = os.path.join(os.getcwd(), "conda_in_process.yml")
logger.info(local_dependencies_path)

return local_dependencies_path
14 changes: 14 additions & 0 deletions src/sagemaker/serve/builder/transformers_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from abc import ABC, abstractmethod
from typing import Type
from pathlib import Path
import subprocess
from packaging.version import Version

from sagemaker.model import Model
Expand All @@ -41,6 +42,8 @@
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
from sagemaker.base_predictor import PredictorBase
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
from sagemaker.serve.builder.requirements_manager import RequirementsManager


logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 1800
Expand Down Expand Up @@ -376,6 +379,9 @@ def _build_for_transformers(self):
save_pkl(code_path, (self.inference_spec, self.schema_builder))
logger.info("PKL file saved to file: %s", code_path)

if self.mode == Mode.IN_PROCESS:
self._create_conda_env()

self._auto_detect_container()

self.secret_key = prepare_for_mms(
Expand All @@ -394,3 +400,11 @@ def _build_for_transformers(self):
if self.sagemaker_session:
self.pysdk_model.sagemaker_session = self.sagemaker_session
return self.pysdk_model

def _create_conda_env(self):
"""Creating conda environment by running commands"""

try:
RequirementsManager().capture_and_install_dependencies(self)
except subprocess.CalledProcessError:
print("Failed to create and activate conda environment.")
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _start_serving(
secret_key: str,
env_vars: dict,
):
"""Placeholder docstring"""
"""Initializes the start of the server"""
env = {
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_PROGRAM": "inference.py",
Expand Down Expand Up @@ -59,7 +59,7 @@ def _start_serving(
)

def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
"""Placeholder docstring"""
"""Invokes MMS server by hitting the docker host"""
try:
response = requests.post(
f"http://{get_docker_host()}:8080/invocations",
Expand All @@ -73,7 +73,7 @@ def _invoke_multi_model_server_serving(self, request: object, content_type: str,
raise Exception("Unable to send request to the local container server") from e

def _multi_model_server_deep_ping(self, predictor: PredictorBase):
"""Placeholder docstring"""
"""Deep ping in order to ensure prediction"""
response = None
try:
response = predictor.predict(self.schema_builder.sample_input)
Expand Down
113 changes: 113 additions & 0 deletions src/sagemaker/serve/utils/conda_in_process.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
name: conda_env
channels:
- defaults
dependencies:
- accelerate>=0.24.1,<=0.27.0
- sagemaker_schema_inference_artifacts>=0.0.5
- uvicorn>=0.30.1
- fastapi>=0.111.0
- nest-asyncio
- pip>=23.0.1
- attrs>=23.1.0,<24
- boto3>=1.34.142,<2.0
- cloudpickle==2.2.1
- google-pasta
- numpy>=1.9.0,<2.0
- protobuf>=3.12,<5.0
- smdebug_rulesconfig==1.0.1
- importlib-metadata>=1.4.0,<7.0
- packaging>=20.0
- pandas
- pathos
- schema
- PyYAML~=6.0
- jsonschema
- platformdirs
- tblib>=1.7.0,<4
- urllib3>=1.26.8,<3.0.0
- requests
- docker
- tqdm
- psutil
- pip:
- altair>=4.2.2
- anyio>=3.6.2
- awscli>=1.27.114
- blinker>=1.6.2
- botocore>=1.29.114
- cachetools>=5.3.0
- certifi==2022.12.7
- harset-normalizer>=3.1.0
- click>=8.1.3
- cloudpickle>=2.2.1
- colorama>=0.4.4
- contextlib2>=21.6.0
- decorator>=5.1.1
- dill>=0.3.6
- docutils>=0.16
- entrypoints>=0.4
- filelock>=3.11.0
- gitdb>=4.0.10
- gitpython>=3.1.31
- gunicorn>=20.1.0
- h11>=0.14.0
- huggingface-hub>=0.13.4
- idna>=3.4
- importlib-metadata>=4.13.0
- jinja2>=3.1.2
- jmespath>=1.0.1
- jsonschema>=4.17.3
- markdown-it-py>=2.2.0
- markupsafe>=2.1.2
- mdurl>=0.1.2
- mpmath>=1.3.0
- multiprocess>=0.70.14
- networkx>=3.1
- packaging>=23.1
- pandas>=1.5.3
- pathos>=0.3.0
- pillow>=9.5.0
- platformdirs>=3.2.0
- pox>=0.3.2
- ppft>=1.7.6.6
- protobuf>=3.20.3
- protobuf3-to-dict>=0.1.5
- pyarrow>=11.0.0
- pyasn1>=0.4.8
- pydantic>=1.10.7
- pydeck>=0.8.1b0
- pygments>=2.15.1
- pympler>=1.0.1
- pyrsistent>=0.19.3
- python-dateutil>=2.8.2
- pytz>=2023.3
- pytz-deprecation-shim>=0.1.0.post0
- pyyaml>=5.4.1
- regex>=2023.3.23
- requests>=2.28.2
- rich>=13.3.4
- rsa>=4.7.2
- s3transfer>=0.6.0
- sagemaker>=2.148.0
- schema>=0.7.5
- six>=1.16.0
- smdebug-rulesconfig>=1.0.1
- smmap==5.0.0
- sniffio>=1.3.0
- starlette>=0.26.1
- streamlit>=1.21.0
- sympy>=1.11.1
- tblib>=1.7.0
- tokenizers>=0.13.3
- toml>=0.10.2
- toolz>=0.12.0
- torch>=2.0.0
- tornado>=6.3
- tqdm>=4.65.0
- transformers>=4.28.1
- typing-extensions>=4.5.0
- tzdata>=2023.3
- tzlocal>=4.3
- urllib3>=1.26.15
- validators>=0.20.0
- zipp>=3.15.0
2 changes: 1 addition & 1 deletion src/sagemaker/serve/utils/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Placeholder Docstring"""
"""Exceptions used across different model builder invocations"""

from __future__ import absolute_import

Expand Down
85 changes: 85 additions & 0 deletions src/sagemaker/serve/utils/in_process_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
altair>=4.2.2
anyio>=3.6.2
awscli>=1.27.114
blinker>=1.6.2
botocore>=1.29.114
cachetools>=5.3.0
certifi==2022.12.7
harset-normalizer>=3.1.0
click>=8.1.3
cloudpickle>=2.2.1
colorama>=0.4.4
contextlib2>=21.6.0
decorator>=5.1.1
dill>=0.3.6
docutils>=0.16
entrypoints>=0.4
filelock>=3.11.0
gitdb>=4.0.10
gitpython>=3.1.31
gunicorn>=20.1.0
h11>=0.14.0
huggingface-hub>=0.13.4
idna>=3.4
importlib-metadata>=4.13.0
jinja2>=3.1.2
jmespath>=1.0.1
jsonschema>=4.17.3
markdown-it-py>=2.2.0
markupsafe>=2.1.2
mdurl>=0.1.2
mpmath>=1.3.0
multiprocess>=0.70.14
networkx>=3.1
packaging>=23.1
pandas>=1.5.3
pathos>=0.3.0
pillow>=9.5.0
platformdirs>=3.2.0
pox>=0.3.2
ppft>=1.7.6.6
protobuf>=3.20.3
protobuf3-to-dict>=0.1.5
pyarrow>=11.0.0
pyasn1>=0.4.8
pydantic>=1.10.7
pydeck>=0.8.1b0
pygments>=2.15.1
pympler>=1.0.1
pyrsistent>=0.19.3
python-dateutil>=2.8.2
pytz>=2023.3
pytz-deprecation-shim>=0.1.0.post0
pyyaml>=5.4.1
regex>=2023.3.23
requests>=2.28.2
rich>=13.3.4
rsa>=4.7.2
s3transfer>=0.6.0
sagemaker>=2.148.0
schema>=0.7.5
six>=1.16.0
smdebug-rulesconfig>=1.0.1
smmap==5.0.0
sniffio>=1.3.0
starlette>=0.26.1
streamlit>=1.21.0
sympy>=1.11.1
tblib>=1.7.0
tokenizers>=0.13.3
toml>=0.10.2
toolz>=0.12.0
torch>=2.0.0
tornado>=6.3
tqdm>=4.65.0
transformers>=4.28.1
typing-extensions>=4.5.0
tzdata>=2023.3
tzlocal>=4.3
urllib3>=1.26.15
validators>=0.20.0
zipp>=3.15.0
uvicorn>=0.30.1
fastapi>=0.111.0
nest-asyncio
transformers
Loading