Skip to content

Commit

Permalink
Enable telemetry for estimator, predictor, and processing functions
Browse files Browse the repository at this point in the history
  • Loading branch information
knikure committed Jun 18, 2024
1 parent f8ff838 commit 68ec86a
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 74 deletions.
3 changes: 3 additions & 0 deletions src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
from sagemaker.compute_resource_requirements.resource_requirements import (
ResourceRequirements,
)
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.telemetry.constants import Feature

LOGGER = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -145,6 +147,7 @@ def __init__(
self._content_type = None
self._accept = None

@_telemetry_emitter(Feature.PREDICTOR, "sagemaker.predictor.predict")
def predict(
self,
data,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.telemetry.constants import Feature

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1272,6 +1274,7 @@ def latest_job_profiler_artifacts_path(self):
)
return None

@_telemetry_emitter(Feature.ESTIMATOR, "sagemaker.estimator.fit")
@runnable_by_pipeline
def fit(
self,
Expand Down Expand Up @@ -1527,6 +1530,7 @@ def logs(self):
"""
self.sagemaker_session.logs_for_job(self.latest_training_job.name, wait=True)

@_telemetry_emitter(Feature.ESTIMATOR, "sagemaker.estimator.deploy")
def deploy(
self,
initial_instance_count=None,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
from sagemaker.apiutils._base_types import ApiObject
from sagemaker.s3 import S3Uploader
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.telemetry.constants import Feature

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -201,6 +203,7 @@ def __init__(
env, PROCESSING_JOB_ENVIRONMENT_PATH, sagemaker_session=self.sagemaker_session
)

@_telemetry_emitter(Feature.PROCESSING, "processing.run")
@runnable_by_pipeline
def run(
self,
Expand Down Expand Up @@ -616,6 +619,7 @@ def get_run_args(
)
return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)

@_telemetry_emitter(Feature.PROCESSING, "processing.run")
@runnable_by_pipeline
def run(
self,
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/telemetry/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ class Feature(Enum):
SDK_DEFAULTS = 1
LOCAL_MODE = 2
REMOTE_FUNCTION = 3
ESTIMATOR = 4
PREDICTOR = 5
PROCESSING = 6

def __str__(self): # pylint: disable=E0307
"""Return the feature name."""
Expand Down
161 changes: 88 additions & 73 deletions src/sagemaker/telemetry/telemetry_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
str(Feature.SDK_DEFAULTS): 1,
str(Feature.LOCAL_MODE): 2,
str(Feature.REMOTE_FUNCTION): 3,
str(Feature.ESTIMATOR): 4,
str(Feature.PREDICTOR): 5,
str(Feature.PROCESSING): 6,
}

STATUS_TO_CODE = {
Expand All @@ -66,96 +69,106 @@ def _telemetry_emitter(feature: str, func_name: str):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
self_instance = None
sagemaker_session = None
func_name_derived = None
if len(args) > 0 and hasattr(args[0], "sagemaker_session"):
# Get the sagemaker_session from the instance method args
self_instance = args[0]
sagemaker_session = args[0].sagemaker_session
if feature in (
Feature.ESTIMATOR,
Feature.PREDICTOR,
Feature.PROCESSING,
):
func_name_derived = f"{self_instance.__class__.__name__}.{func.__name__}"
elif feature == Feature.REMOTE_FUNCTION:
# Get the sagemaker_session from the function keyword arguments for remote function
sagemaker_session = kwargs.get(
"sagemaker_session", _get_default_sagemaker_session()
)

if sagemaker_session:
logger.debug("sagemaker_session found, preparing to emit telemetry...")
logger.info(TELEMETRY_OPT_OUT_MESSAGING)
response = None
caught_ex = None
studio_app_type = process_studio_metadata_file()

# Check if telemetry is opted out
telemetry_opt_out_flag = resolve_value_from_config(
direct_input=None,
config_path=TELEMETRY_OPT_OUT_PATH,
default_value=False,
sagemaker_session=sagemaker_session,
)
logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag)

# Construct the feature list to track feature combinations
feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]

if sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)])

if sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)])

# Construct the extra info to track platform and environment usage metadata
extra = (
f"{func_name}"
f"&x-sdkVersion={SDK_VERSION}"
f"&x-env={PYTHON_VERSION}"
f"&x-sys={OS_NAME_VERSION}"
f"&x-platform={studio_app_type}"
)

# Add endpoint ARN to the extra info if available
if sagemaker_session.endpoint_arn:
extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"

start_timer = perf_counter()
try:
# Call the original function
response = func(*args, **kwargs)
stop_timer = perf_counter()
elapsed = stop_timer - start_timer
extra += f"&x-latency={round(elapsed, 2)}"
if not telemetry_opt_out_flag:
_send_telemetry_request(
STATUS_TO_CODE[str(Status.SUCCESS)],
feature_list,
sagemaker_session,
None,
None,
extra,
)
except Exception as e: # pylint: disable=W0703
stop_timer = perf_counter()
elapsed = stop_timer - start_timer
extra += f"&x-latency={round(elapsed, 2)}"
if not telemetry_opt_out_flag:
_send_telemetry_request(
STATUS_TO_CODE[str(Status.FAILURE)],
feature_list,
sagemaker_session,
str(e),
e.__class__.__name__,
extra,
)
caught_ex = e
finally:
if caught_ex:
raise caught_ex
return response # pylint: disable=W0150
else:
if not sagemaker_session:
logger.debug(
"Unable to send telemetry for function %s. "
"sagemaker_session is not provided or not valid.",
func_name,
)
return func(*args, **kwargs)

logger.debug("sagemaker_session found, preparing to emit telemetry...")
print("sagemaker_session found, preparing to emit telemetry...")
logger.info(TELEMETRY_OPT_OUT_MESSAGING)
response = None
caught_ex = None
studio_app_type = process_studio_metadata_file()

# Check if telemetry is opted out
telemetry_opt_out_flag = resolve_value_from_config(
direct_input=None,
config_path=TELEMETRY_OPT_OUT_PATH,
default_value=False,
sagemaker_session=sagemaker_session,
)
logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag)

# Construct the feature list to track feature combinations
feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]

if sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)])

if sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)])

# Construct the extra info to track platform and environment usage metadata
extra = (
f"{func_name_derived if func_name_derived else func_name}"
f"&x-sdkVersion={SDK_VERSION}"
f"&x-env={PYTHON_VERSION}"
f"&x-sys={OS_NAME_VERSION}"
f"&x-platform={studio_app_type}"
)

# Add endpoint ARN to the extra info if available
if sagemaker_session.endpoint_arn:
extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"

start_timer = perf_counter()
try:
# Call the original function
response = func(*args, **kwargs)
stop_timer = perf_counter()
elapsed = stop_timer - start_timer
extra += f"&x-latency={round(elapsed, 2)}"
if not telemetry_opt_out_flag:
_send_telemetry_request(
STATUS_TO_CODE[str(Status.SUCCESS)],
feature_list,
sagemaker_session,
None,
None,
extra,
)
except Exception as e: # pylint: disable=W0703
stop_timer = perf_counter()
elapsed = stop_timer - start_timer
extra += f"&x-latency={round(elapsed, 2)}"
if not telemetry_opt_out_flag:
_send_telemetry_request(
STATUS_TO_CODE[str(Status.FAILURE)],
feature_list,
sagemaker_session,
str(e),
e.__class__.__name__,
extra,
)
caught_ex = e
finally:
if caught_ex:
raise caught_ex
return response # pylint: disable=W0150

return wrapper

return decorator
Expand Down Expand Up @@ -185,8 +198,10 @@ def _send_telemetry_request(
extra_info,
)
# Send the telemetry request
print(f"\nSending telemetry request to [{url}]\n")
logger.debug("Sending telemetry request to [%s]", url)
_requests_helper(url, 2)
print("SageMaker Python SDK telemetry successfully emitted.")
logger.debug("SageMaker Python SDK telemetry successfully emitted.")
except Exception: # pylint: disable=W0703
logger.debug("SageMaker Python SDK telemetry not emitted!")
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/local/test_local_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def local_sagemaker_session(boto_session):
# For tests that don't verify config file injection, operate with an empty config

local_session_mock.sagemaker_config = {}
local_session_mock.endpoint_arn = None
return local_session_mock


Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2353,6 +2353,7 @@ def test_local_code_location():
local_mode=True,
spec=sagemaker.local.LocalSession,
settings=SessionSettings(),
endpoint_arn=None,
)

sms.sagemaker_config = {}
Expand Down
14 changes: 13 additions & 1 deletion tests/unit/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,17 @@ def test_sklearn_with_all_parameters_via_run_args(
sagemaker_session.process.assert_called_with(**expected_args)


@patch("sagemaker.user_agent.process_studio_metadata_file", return_value="TestApp")
@patch("sagemaker.utils._botocore_resolver")
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
def test_sklearn_with_all_parameters_via_run_args_called_twice(
exists_mock, isfile_mock, botocore_resolver, sklearn_version, sagemaker_session
exists_mock,
isfile_mock,
botocore_resolver,
process_studio_metadata_file_mock,
sklearn_version,
sagemaker_session,
):
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}

Expand Down Expand Up @@ -332,13 +338,15 @@ def test_sklearn_with_all_parameters_via_run_args_called_twice(
sagemaker_session.process.assert_called_with(**expected_args)


@patch("sagemaker.user_agent.process_studio_metadata_file", return_value="TestApp")
@patch("sagemaker.utils._botocore_resolver")
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
def test_pytorch_processor_with_required_parameters(
exists_mock,
isfile_mock,
botocore_resolver,
process_studio_metadata_file_mock,
sagemaker_session,
pytorch_training_version,
pytorch_training_py_version,
Expand Down Expand Up @@ -410,12 +418,14 @@ def test_xgboost_processor_with_required_parameters(
sagemaker_session.process.assert_called_with(**expected_args)


@patch("sagemaker.user_agent.process_studio_metadata_file", return_value="TestApp")
@patch("sagemaker.utils._botocore_resolver")
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
def test_mxnet_processor_with_required_parameters(
exists_mock,
isfile_mock,
process_studio_metadata_file_mock,
botocore_resolver,
sagemaker_session,
mxnet_training_version,
Expand Down Expand Up @@ -458,13 +468,15 @@ def test_mxnet_processor_with_required_parameters(
sagemaker_session.process.assert_called_with(**expected_args)


@patch("sagemaker.user_agent.process_studio_metadata_file", return_value="TestApp")
@patch("sagemaker.utils._botocore_resolver")
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
def test_tensorflow_processor_with_required_parameters(
exists_mock,
isfile_mock,
botocore_resolver,
process_studio_metadata_file_mock,
sagemaker_session,
tensorflow_training_version,
tensorflow_training_py_version,
Expand Down

0 comments on commit 68ec86a

Please sign in to comment.