Skip to content

Commit

Permalink
Enable telemetry for estimator, predictor, and processing functions
Browse files Browse the repository at this point in the history
  • Loading branch information
knikure committed Jun 18, 2024
1 parent f8ff838 commit e5ab1aa
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/sagemaker/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
from sagemaker.compute_resource_requirements.resource_requirements import (
ResourceRequirements,
)
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.telemetry.constants import Feature

LOGGER = logging.getLogger("sagemaker")

Expand Down Expand Up @@ -145,6 +147,7 @@ def __init__(
self._content_type = None
self._accept = None

@_telemetry_emitter(Feature.PREDICTOR, "sagemaker.predictor.predict")
def predict(
self,
data,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.telemetry.constants import Feature

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1272,6 +1274,7 @@ def latest_job_profiler_artifacts_path(self):
)
return None

@_telemetry_emitter(Feature.ESTIMATOR, "sagemaker.estimator.fit")
@runnable_by_pipeline
def fit(
self,
Expand Down Expand Up @@ -1527,6 +1530,7 @@ def logs(self):
"""
self.sagemaker_session.logs_for_job(self.latest_training_job.name, wait=True)

@_telemetry_emitter(Feature.ESTIMATOR, "sagemaker.estimator.deploy")
def deploy(
self,
initial_instance_count=None,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
from sagemaker.apiutils._base_types import ApiObject
from sagemaker.s3 import S3Uploader
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.telemetry.constants import Feature

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -201,6 +203,7 @@ def __init__(
env, PROCESSING_JOB_ENVIRONMENT_PATH, sagemaker_session=self.sagemaker_session
)

@_telemetry_emitter(Feature.PROCESSING, "processing.run")
@runnable_by_pipeline
def run(
self,
Expand Down Expand Up @@ -616,6 +619,7 @@ def get_run_args(
)
return RunArgs(code=code, inputs=inputs, outputs=outputs, arguments=arguments)

@_telemetry_emitter(Feature.PROCESSING, "processing.run")
@runnable_by_pipeline
def run(
self,
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/telemetry/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ class Feature(Enum):
SDK_DEFAULTS = 1
LOCAL_MODE = 2
REMOTE_FUNCTION = 3
ESTIMATOR = 4
PREDICTOR = 5
PROCESSING = 6

def __str__(self): # pylint: disable=E0307
"""Return the feature name."""
Expand Down
10 changes: 10 additions & 0 deletions src/sagemaker/telemetry/telemetry_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
str(Feature.SDK_DEFAULTS): 1,
str(Feature.LOCAL_MODE): 2,
str(Feature.REMOTE_FUNCTION): 3,
str(Feature.ESTIMATOR): 4,
str(Feature.PREDICTOR): 5,
str(Feature.PROCESSING): 6,
}

STATUS_TO_CODE = {
Expand All @@ -69,6 +72,7 @@ def wrapper(*args, **kwargs):
sagemaker_session = None
if len(args) > 0 and hasattr(args[0], "sagemaker_session"):
# Get the sagemaker_session from the instance method args
self = args[0]
sagemaker_session = args[0].sagemaker_session
elif feature == Feature.REMOTE_FUNCTION:
# Get the sagemaker_session from the function keyword arguments for remote function
Expand All @@ -78,6 +82,7 @@ def wrapper(*args, **kwargs):

if sagemaker_session:
logger.debug("sagemaker_session found, preparing to emit telemetry...")
print("sagemaker_session found, preparing to emit telemetry...")
logger.info(TELEMETRY_OPT_OUT_MESSAGING)
response = None
caught_ex = None
Expand All @@ -101,6 +106,9 @@ def wrapper(*args, **kwargs):
if sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)])

if self and feature in (Feature.ESTIMATOR, Feature.PREDICTOR, Feature.PROCESSING):
func_name = f"{self.__class__.__name__}.{func.__name__}"

# Construct the extra info to track platform and environment usage metadata
extra = (
f"{func_name}"
Expand Down Expand Up @@ -185,8 +193,10 @@ def _send_telemetry_request(
extra_info,
)
# Send the telemetry request
print(f"\nSending telemetry request to [{url}]\n")
logger.debug("Sending telemetry request to [%s]", url)
_requests_helper(url, 2)
print("SageMaker Python SDK telemetry successfully emitted.")
logger.debug("SageMaker Python SDK telemetry successfully emitted.")
except Exception: # pylint: disable=W0703
logger.debug("SageMaker Python SDK telemetry not emitted!")
Expand Down

0 comments on commit e5ab1aa

Please sign in to comment.