Skip to content

Commit

Permalink
Support ListTrials by trial component name (#67)
Browse files Browse the repository at this point in the history
* Add SageMaker analytics example

* Support ListTrials by trial component name

* fix flake8 format

* fix black format

* Update min boto version required

* Update boto3 version in tox.ini
  • Loading branch information
l2yao authored Mar 23, 2020
1 parent a438e68 commit b6a0953
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def read(fname):


# Declare minimal set for installation
required_packages = ["boto3>=1.10.32"]
required_packages = ["boto3>=1.12.8"]

# Open readme with original (i.e. LF) newlines
# to prevent the all too common "`long_description_content_type` missing"
Expand Down
4 changes: 4 additions & 0 deletions src/smexperiments/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, tr
def list(
cls,
experiment_name=None,
trial_component_name=None,
created_before=None,
created_after=None,
sort_by=None,
Expand All @@ -136,6 +137,8 @@ def list(
Args:
experiment_name (str, optional): Name of the experiment. If specified, only trials in
the experiment will be returned.
trial_component_name (str, optional): Name of the trial component. If specified, only
trials with this trial component name will be returned.
created_before (datetime.datetime, optional): Return trials created before this instant.
created_after (datetime.datetime, optional): Return trials created after this instant.
sort_by (str, optional): Which property to sort results by. One of 'Name',
Expand All @@ -153,6 +156,7 @@ def list(
api_types.TrialSummary.from_boto,
"TrialSummaries",
experiment_name=experiment_name,
trial_component_name=trial_component_name,
created_before=created_before,
created_after=created_after,
sort_by=sort_by,
Expand Down
17 changes: 17 additions & 0 deletions tests/integ/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ def test_list(trials, sagemaker_boto_client):
assert trial_names_listed # sanity test


def test_list_with_trial_component(trials, trial_component_obj, sagemaker_boto_client):
trial_with_component = trials[0]
trial_with_component.add_trial_component(trial_component_obj)

trial_listed = [
s.trial_name
for s in trial.Trial.list(
trial_component_name=trial_component_obj.trial_component_name, sagemaker_boto_client=sagemaker_boto_client
)
]
assert len(trial_listed) == 1
assert trial_with_component.trial_name == trial_listed[0]
# clean up
trial_with_component.remove_trial_component(trial_component_obj)
assert trial_listed


def test_list_sort(trials, sagemaker_boto_client):
slack = datetime.timedelta(minutes=1)
now = datetime.datetime.now(datetime.timezone.utc)
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ def test_list_trials_with_experiment_name(sagemaker_boto_client, datetime_obj):
sagemaker_boto_client.list_trials.assert_called_with(ExperimentName="foo")


def test_list_trials_with_trial_component_name(sagemaker_boto_client, datetime_obj):
sagemaker_boto_client.list_trials.return_value = {
"TrialSummaries": [
{"TrialName": "trial-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj,},
{"TrialName": "trial-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj,},
]
}
expected = [
api_types.TrialSummary(trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj),
api_types.TrialSummary(trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj),
]
assert expected == list(
trial.Trial.list(trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client)
)
sagemaker_boto_client.list_trials.assert_called_with(TrialComponentName="tc-foo")


def test_delete(sagemaker_boto_client):
obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
sagemaker_boto_client.delete_trial.return_value = {}
Expand Down
4 changes: 2 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ commands =
{env:IGNORE_COVERAGE:} coverage report --fail-under=95
extras = test
deps =
boto3 >= 1.10.32
boto3 >= 1.12.8
python-dateutil
pytest
pytest-cov
Expand Down Expand Up @@ -98,7 +98,7 @@ commands =
pytest {posargs} --verbose --runslow --capture=no
extras = test
deps =
boto3 >= 1.10.32
boto3 >= 1.12.8
pytest
docker

Expand Down

0 comments on commit b6a0953

Please sign in to comment.