Skip to content

Commit

Permalink
Codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
jinyoung-lim authored and knikure committed Apr 17, 2024
1 parent 5b4e7be commit 5369f6b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 32 deletions.
3 changes: 2 additions & 1 deletion src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"vw",
]


@override_pipeline_parameter_var
def retrieve(
framework,
Expand Down Expand Up @@ -199,7 +200,7 @@ def retrieve(
deprecation_warn(
"SageMaker-hosted RL images no longer accept new pull requests and",
"April 2024",
" Please pass in `image_uri` to use RLEstimator"
" Please pass in `image_uri` to use RLEstimator",
)

py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
Expand Down
68 changes: 38 additions & 30 deletions src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,21 @@ def __init__(
executing your model training code.
.. warning::
This ``toolkit`` argument discontinued support for new RL users on April 2024. To use
RLEstimator, please pass in ``image_uri``.
This ``toolkit`` argument discontinued support for new RL users on April 2024.
To use RLEstimator, pass in ``image_uri``.
toolkit_version (str): RL toolkit version you want to be use for executing your
model training code.
.. warning::
This ``toolkit_version`` argument discontinued support for new RL users on April 2024.
To use RLEstimator, please pass in ``image_uri``.
This ``toolkit_version`` argument discontinued support for new RL users on
April 2024. To use RLEstimator, pass in ``image_uri``.
framework (sagemaker.rl.RLFramework): Framework (MXNet or
TensorFlow) you want to be used as a toolkit backed for
reinforcement learning training.
.. warning::
This ``framework`` argument discontinued support for new RL users on April 2024. To
use RLEstimator, please pass in ``image_uri``.
This ``framework`` argument discontinued support for new RL users on April
2024. To use RLEstimator, pass in ``image_uri``.
source_dir (str or PipelineVariable): Path (absolute, relative or an S3 URI)
to a directory with any other training source code dependencies aside from
the entry point file (default: None). If ``source_dir`` is an S3 URI, it must
Expand All @@ -141,11 +141,12 @@ def __init__(
SageMaker. For convenience, this accepts other types for keys
and values.
image_uri (str or PipelineVariable): An ECR url for an image the estimator would use
for training and hosting. Example: 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
for training and hosting.
Example: 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
metric_definitions (list[dict[str, str] or list[dict[str, PipelineVariable]]):
A list of dictionaries that defines the metric(s) used to evaluate the
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric,
and 'Regex' for the regular expression used to extract the
training jobs. Each dictionary contains two keys: 'Name' for the name of the
metric, and 'Regex' for the regular expression used to extract the
metric from the logs. This should be defined only for jobs that
don't use an Amazon algorithm.
**kwargs: Additional kwargs passed to the
Expand All @@ -167,11 +168,23 @@ def __init__(
self._validate_images_args(toolkit, toolkit_version, framework, image_uri)

if toolkit:
deprecation_warn("The argument `toolkit`", "April 2024", " Please pass in `image_uri` to use RLEstimator")
deprecation_warn(
"The argument `toolkit`",
"April 2024",
" Pass in `image_uri` to use RLEstimator",
)
if toolkit_version:
deprecation_warn("The argument `toolkit_version`", "April 2024", " Please pass in `image_uri` to use RLEstimator")
deprecation_warn(
"The argument `toolkit_version`",
"April 2024",
" Pass in `image_uri` to use RLEstimator",
)
if framework:
deprecation_warn("The argument `framework`", "April 2024", " Please pass in `image_uri` to use RLEstimator")
deprecation_warn(
"The argument `framework`",
"April 2024",
" Pass in `image_uri` to use RLEstimator",
)

if not image_uri:
self._validate_toolkit_support(toolkit.value, toolkit_version, framework.value)
Expand Down Expand Up @@ -260,7 +273,7 @@ def create_model(
base_args["name"] = self._get_or_create_name(kwargs.get("name"))

if not entry_point and (source_dir or dependencies):
raise AttributeError("Please provide an `entry_point`.")
raise AttributeError("Provide an `entry_point`.")

entry_point = entry_point or self._model_entry_point()
source_dir = source_dir or self._model_source_dir()
Expand Down Expand Up @@ -291,7 +304,7 @@ def create_model(
framework_version=self.framework_version, py_version=PYTHON_VERSION, **extended_args
)
raise ValueError(
"An unknown RLFramework enum was passed in. framework: {}".format(self.framework)
f"An unknown RLFramework enum was passed in. framework: {self.framework}"
)

def training_image_uri(self):
Expand Down Expand Up @@ -349,10 +362,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
toolkit, toolkit_version = cls._toolkit_and_version_from_tag(tag)

if not cls._is_combination_supported(toolkit, toolkit_version, framework):
training_job_name = job_details["TrainingJobName"]
raise ValueError(
"Training job: {} didn't use image for requested framework".format(
job_details["TrainingJobName"]
)
f"Training job: {training_job_name} didn't use image for requested framework"
)

init_params["toolkit"] = RLToolkit(toolkit)
Expand Down Expand Up @@ -392,17 +404,15 @@ def _validate_framework_format(cls, framework):
"""Placeholder docstring."""
if framework and framework not in list(RLFramework):
raise ValueError(
"Invalid type: {}, valid RL frameworks types are: {}".format(
framework, list(RLFramework)
)
f"Invalid type: {framework}, valid RL frameworks types are: {list(RLFramework)}"
)

@classmethod
def _validate_toolkit_format(cls, toolkit):
"""Placeholder docstring."""
if toolkit and toolkit not in list(RLToolkit):
raise ValueError(
"Invalid type: {}, valid RL toolkits types are: {}".format(toolkit, list(RLToolkit))
f"Invalid type: {toolkit}, valid RL toolkits types are: {list(RLToolkit)}"
)

@classmethod
Expand All @@ -420,10 +430,9 @@ def _validate_images_args(cls, toolkit, toolkit_version, framework, image_uri):
if not framework:
not_found_args.append("framework")
if not_found_args:
not_found_args_joined = "`, `".join(not_found_args)
raise AttributeError(
"Please provide `{}` or `image_uri` parameter.".format(
"`, `".join(not_found_args)
)
f"Provide `{not_found_args_joined}` or `image_uri` parameter."
)
else:
found_args = []
Expand Down Expand Up @@ -455,9 +464,8 @@ def _validate_toolkit_support(cls, toolkit, toolkit_version, framework):
"""Placeholder docstring."""
if not cls._is_combination_supported(toolkit, toolkit_version, framework):
raise AttributeError(
"Provided `{}-{}` and `{}` combination is not supported.".format(
toolkit, toolkit_version, framework
)
f"Provided `{toolkit}-{toolkit_version}` and `{framework}` combination is"
" not supported."
)

def _image_framework(self):
Expand Down Expand Up @@ -487,7 +495,7 @@ def default_metric_definitions(cls, toolkit):
float_regex = "[-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?" # noqa: W605, E501

return [
{"Name": "episode_reward_mean", "Regex": "episode_reward_mean: (%s)" % float_regex},
{"Name": "episode_reward_max", "Regex": "episode_reward_max: (%s)" % float_regex},
{"Name": "episode_reward_mean", "Regex": f"episode_reward_mean: ({float_regex})"},
{"Name": "episode_reward_max", "Regex": f"episode_reward_max: ({float_regex})"},
]
raise ValueError("An unknown RLToolkit enum was passed in. toolkit: {}".format(toolkit))
raise ValueError(f"An unknown RLToolkit enum was passed in. toolkit: {toolkit}")
2 changes: 1 addition & 1 deletion tests/unit/test_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def test_missing_required_parameters(sagemaker_session):
instance_type=INSTANCE_TYPE,
)
assert (
"Please provide `toolkit`, `toolkit_version`, `framework`" + " or `image_uri` parameter."
"Provide `toolkit`, `toolkit_version`, `framework`" + " or `image_uri` parameter."
in str(e.value)
)

Expand Down

0 comments on commit 5369f6b

Please sign in to comment.