Skip to content

Commit

Permalink
Enable telemetry for estimator, predictor, and processing functions
Browse files Browse the repository at this point in the history
  • Loading branch information
knikure committed Jun 19, 2024
1 parent 4496072 commit e7e2cf2
Show file tree
Hide file tree
Showing 22 changed files with 242 additions and 137 deletions.
3 changes: 3 additions & 0 deletions src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
from sagemaker.compute_resource_requirements.resource_requirements import (
ResourceRequirements,
)
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.telemetry.constants import Feature

LOGGER = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -145,6 +147,7 @@ def __init__(
self._content_type = None
self._accept = None

@_telemetry_emitter(Feature.PREDICTOR, "sagemaker.predictor.predict")
def predict(
self,
data,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.telemetry.constants import Feature

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1272,6 +1274,7 @@ def latest_job_profiler_artifacts_path(self):
)
return None

@_telemetry_emitter(Feature.ESTIMATOR, "sagemaker.estimator.fit")
@runnable_by_pipeline
def fit(
self,
Expand Down Expand Up @@ -1527,6 +1530,7 @@ def logs(self):
"""
self.sagemaker_session.logs_for_job(self.latest_training_job.name, wait=True)

@_telemetry_emitter(Feature.ESTIMATOR, "sagemaker.estimator.deploy")
def deploy(
self,
initial_instance_count=None,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
from sagemaker.apiutils._base_types import ApiObject
from sagemaker.s3 import S3Uploader
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.telemetry.constants import Feature

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -201,6 +203,7 @@ def __init__(
env, PROCESSING_JOB_ENVIRONMENT_PATH, sagemaker_session=self.sagemaker_session
)

@_telemetry_emitter(Feature.PROCESSING, "processing.run")
@runnable_by_pipeline
def run(
self,
Expand Down Expand Up @@ -616,6 +619,7 @@ def get_run_args(
)
return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)

@_telemetry_emitter(Feature.PROCESSING, "processing.run")
@runnable_by_pipeline
def run(
self,
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/telemetry/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ class Feature(Enum):
SDK_DEFAULTS = 1
LOCAL_MODE = 2
REMOTE_FUNCTION = 3
ESTIMATOR = 4
PREDICTOR = 5
PROCESSING = 6

def __str__(self): # pylint: disable=E0307
"""Return the feature name."""
Expand Down
158 changes: 85 additions & 73 deletions src/sagemaker/telemetry/telemetry_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
str(Feature.SDK_DEFAULTS): 1,
str(Feature.LOCAL_MODE): 2,
str(Feature.REMOTE_FUNCTION): 3,
str(Feature.ESTIMATOR): 4,
str(Feature.PREDICTOR): 5,
str(Feature.PROCESSING): 6,
}

STATUS_TO_CODE = {
Expand All @@ -66,96 +69,105 @@ def _telemetry_emitter(feature: str, func_name: str):
def decorator(func):
    """Wrap ``func`` so that each call emits one telemetry record.

    ``feature`` and ``func_name`` are taken from the enclosing
    ``_telemetry_emitter`` call.

    Args:
        func: The function or method to instrument.

    Returns:
        The wrapped callable, which returns whatever ``func`` returns and
        re-raises whatever ``func`` raises.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        sagemaker_session = None
        func_name_derived = None
        if len(args) > 0 and hasattr(args[0], "sagemaker_session"):
            # Get the sagemaker_session from the instance method args
            self_instance = args[0]
            sagemaker_session = self_instance.sagemaker_session
            if feature in (
                Feature.ESTIMATOR,
                Feature.PREDICTOR,
                Feature.PROCESSING,
            ):
                # Report the concrete class so subclasses are distinguishable.
                func_name_derived = f"{self_instance.__class__.__name__}.{func.__name__}"
        elif feature == Feature.REMOTE_FUNCTION:
            # Get the sagemaker_session from the function keyword arguments for remote function
            sagemaker_session = kwargs.get(
                "sagemaker_session", _get_default_sagemaker_session()
            )

        if not sagemaker_session:
            # Without a session there is nothing to send telemetry with;
            # run the wrapped function untouched.
            logger.debug(
                "Unable to send telemetry for function %s. "
                "sagemaker_session is not provided or not valid.",
                func_name,
            )
            return func(*args, **kwargs)

        logger.debug("sagemaker_session found, preparing to emit telemetry...")
        logger.info(TELEMETRY_OPT_OUT_MESSAGING)
        studio_app_type = process_studio_metadata_file()

        # Check if telemetry is opted out
        telemetry_opt_out_flag = resolve_value_from_config(
            direct_input=None,
            config_path=TELEMETRY_OPT_OUT_PATH,
            default_value=False,
            sagemaker_session=sagemaker_session,
        )
        logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag)

        # Construct the feature list to track feature combinations
        feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]

        if sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
            feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)])

        if sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
            feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)])

        # Construct the extra info to track platform and environment usage metadata
        extra = (
            f"{func_name_derived if func_name_derived else func_name}"
            f"&x-sdkVersion={SDK_VERSION}"
            f"&x-env={PYTHON_VERSION}"
            f"&x-sys={OS_NAME_VERSION}"
            f"&x-platform={studio_app_type}"
        )

        # Add endpoint ARN to the extra info if available
        if sagemaker_session.endpoint_arn:
            extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"

        start_timer = perf_counter()
        try:
            # Call the original function
            response = func(*args, **kwargs)
        except Exception as e:  # pylint: disable=W0703
            extra += f"&x-latency={round(perf_counter() - start_timer, 2)}"
            if not telemetry_opt_out_flag:
                _send_telemetry_request(
                    STATUS_TO_CODE[str(Status.FAILURE)],
                    feature_list,
                    sagemaker_session,
                    str(e),
                    e.__class__.__name__,
                    extra,
                )
            # Bare raise keeps the original traceback. No ``return`` sits in a
            # ``finally`` clause, so BaseException subclasses such as
            # KeyboardInterrupt propagate instead of being discarded.
            raise

        # Success path lives outside the ``try`` so that an error raised while
        # sending telemetry is not reported as a failure of ``func``.
        extra += f"&x-latency={round(perf_counter() - start_timer, 2)}"
        if not telemetry_opt_out_flag:
            _send_telemetry_request(
                STATUS_TO_CODE[str(Status.SUCCESS)],
                feature_list,
                sagemaker_session,
                None,
                None,
                extra,
            )
        return response

    return wrapper

return decorator
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/huggingface/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_huggingface(
sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
assert sagemaker_call_names == ["train", "logs_for_job"]
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
assert boto_call_names == ["resource"]
assert boto_call_names == ["resource", "client"]

expected_train_args = _create_train_job(
huggingface_training_version, f"pytorch{huggingface_pytorch_training_version}"
Expand Down
14 changes: 13 additions & 1 deletion tests/unit/sagemaker/huggingface/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
from __future__ import absolute_import

import pytest
from mock import Mock, patch, MagicMock
from mock import Mock, patch, MagicMock, mock_open
import json

from sagemaker.huggingface.processing import HuggingFaceProcessor
from sagemaker.fw_utils import UploadedCode
from sagemaker.session_settings import SessionSettings
from sagemaker.user_agent import process_studio_metadata_file

from .huggingface_utils import get_full_gpu_image_uri, GPU_INSTANCE_TYPE, REGION

Expand Down Expand Up @@ -64,6 +66,16 @@ def uploaded_code(
return UploadedCode(s3_prefix=s3_prefix, script_name=script_name)


@pytest.fixture(autouse=True)
def mock_process_studio_metadata_file(tmp_path):
    """Make every test see a fake Studio resource-metadata file.

    Writes a metadata JSON document under ``tmp_path``, then patches
    ``os.path.exists`` to return True and ``sagemaker.user_agent.open`` to
    serve that document's text, yielding ``process_studio_metadata_file``
    while both patches are active.
    """
    metadata_path = tmp_path / "resource-metadata.json"
    metadata_path.write_text(json.dumps({"AppType": "TestAppType"}))

    fake_open = mock_open(read_data=metadata_path.read_text())
    exists_patch = patch("os.path.exists", return_value=True)
    open_patch = patch("sagemaker.user_agent.open", fake_open)

    with exists_patch, open_patch:
        yield process_studio_metadata_file


@patch("sagemaker.utils._botocore_resolver")
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/local/test_local_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def local_sagemaker_session(boto_session):
# For tests which don't verify config file injection, operate with empty config

local_session_mock.sagemaker_config = {}
local_session_mock.endpoint_arn = None
return local_session_mock


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def test_pytorchxla_distribution(
sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
assert sagemaker_call_names == ["train", "logs_for_job"]
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
assert boto_call_names == ["resource"]
assert boto_call_names == ["resource", "client"]

expected_train_args = _create_train_job(
huggingface_training_compiler_version,
Expand Down Expand Up @@ -463,7 +463,7 @@ def test_default_compiler_config(
sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
assert sagemaker_call_names == ["train", "logs_for_job"]
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
assert boto_call_names == ["resource"]
assert boto_call_names == ["resource", "client"]

expected_train_args = _create_train_job(
huggingface_training_compiler_version,
Expand Down Expand Up @@ -519,7 +519,7 @@ def test_debug_compiler_config(
sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
assert sagemaker_call_names == ["train", "logs_for_job"]
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
assert boto_call_names == ["resource"]
assert boto_call_names == ["resource", "client"]

expected_train_args = _create_train_job(
huggingface_training_compiler_version,
Expand Down Expand Up @@ -575,7 +575,7 @@ def test_disable_compiler_config(
sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
assert sagemaker_call_names == ["train", "logs_for_job"]
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
assert boto_call_names == ["resource"]
assert boto_call_names == ["resource", "client"]

expected_train_args = _create_train_job(
huggingface_training_compiler_version,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def test_default_compiler_config(
sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
assert sagemaker_call_names == ["train", "logs_for_job"]
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
assert boto_call_names == ["resource"]
assert boto_call_names == ["resource", "client"]

expected_train_args = _create_train_job(
huggingface_training_compiler_version,
Expand Down Expand Up @@ -407,7 +407,7 @@ def test_debug_compiler_config(
sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
assert sagemaker_call_names == ["train", "logs_for_job"]
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
assert boto_call_names == ["resource"]
assert boto_call_names == ["resource", "client"]

expected_train_args = _create_train_job(
huggingface_training_compiler_version,
Expand Down Expand Up @@ -465,7 +465,7 @@ def test_disable_compiler_config(
sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
assert sagemaker_call_names == ["train", "logs_for_job"]
boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
assert boto_call_names == ["resource"]
assert boto_call_names == ["resource", "client"]

expected_train_args = _create_train_job(
huggingface_training_compiler_version,
Expand Down
Loading

0 comments on commit e7e2cf2

Please sign in to comment.