diff --git a/setup.py b/setup.py index d65ddac..f4cab9a 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ def read(fname): # Declare minimal set for installation -required_packages = ["boto3>=1.10.32"] +required_packages = ["boto3>=1.12.8"] # Open readme with original (i.e. LF) newlines # to prevent the all too common "`long_description_content_type` missing" diff --git a/src/smexperiments/trial.py b/src/smexperiments/trial.py index 668ad08..b2a6735 100644 --- a/src/smexperiments/trial.py +++ b/src/smexperiments/trial.py @@ -125,6 +125,7 @@ def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, tr def list( cls, experiment_name=None, + trial_component_name=None, created_before=None, created_after=None, sort_by=None, @@ -136,6 +137,8 @@ def list( Args: experiment_name (str, optional): Name of the experiment. If specified, only trials in the experiment will be returned. + trial_component_name (str, optional): Name of the trial component. If specified, only + trials with this trial component name will be returned. created_before (datetime.datetime, optional): Return trials created before this instant. created_after (datetime.datetime, optional): Return trials created after this instant. sort_by (str, optional): Which property to sort results by. One of 'Name', @@ -153,6 +156,7 @@ def list( api_types.TrialSummary.from_boto, "TrialSummaries", experiment_name=experiment_name, + trial_component_name=trial_component_name, created_before=created_before, created_after=created_after, sort_by=sort_by, diff --git a/tests/integ/test_trial.py b/tests/integ/test_trial.py index 288ce6c..f123d16 100644 --- a/tests/integ/test_trial.py +++ b/tests/integ/test_trial.py @@ -35,6 +35,23 @@ def test_list(trials, sagemaker_boto_client): assert trial_names_listed # sanity test +def test_list_with_trial_component(trials, trial_component_obj, sagemaker_boto_client): + trial_with_component = trials[0] + trial_with_component.add_trial_component(trial_component_obj) + + trial_listed = [ + s.trial_name + for s in trial.Trial.list( + trial_component_name=trial_component_obj.trial_component_name, sagemaker_boto_client=sagemaker_boto_client + ) + ] + assert len(trial_listed) == 1 + assert trial_with_component.trial_name == trial_listed[0] + # clean up + trial_with_component.remove_trial_component(trial_component_obj) + assert trial_listed + + def test_list_sort(trials, sagemaker_boto_client): slack = datetime.timedelta(minutes=1) now = datetime.datetime.now(datetime.timezone.utc) diff --git a/tests/unit/test_trial.py b/tests/unit/test_trial.py index 4553349..d60b550 100644 --- a/tests/unit/test_trial.py +++ b/tests/unit/test_trial.py @@ -123,6 +123,23 @@ def test_list_trials_with_experiment_name(sagemaker_boto_client, datetime_obj): sagemaker_boto_client.list_trials.assert_called_with(ExperimentName="foo") +def test_list_trials_with_trial_component_name(sagemaker_boto_client, datetime_obj): + sagemaker_boto_client.list_trials.return_value = { + "TrialSummaries": [ + {"TrialName": "trial-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj,}, + {"TrialName": "trial-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj,}, + ] + } + expected = [ + api_types.TrialSummary(trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj), + api_types.TrialSummary(trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj), + ] + assert expected == list( + trial.Trial.list(trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client) + ) + sagemaker_boto_client.list_trials.assert_called_with(TrialComponentName="tc-foo") + + def test_delete(sagemaker_boto_client): obj = trial.Trial(sagemaker_boto_client, trial_name="foo") sagemaker_boto_client.delete_trial.return_value = {} diff --git a/tox.ini b/tox.ini index 30c6ff2..757c273 100644 --- a/tox.ini +++ b/tox.ini @@ -53,7 +53,7 @@ commands = {env:IGNORE_COVERAGE:} coverage report --fail-under=95 extras = test deps = - boto3 >= 1.10.32 + boto3 >= 1.12.8 python-dateutil pytest pytest-cov @@ -98,7 +98,7 @@ commands = pytest {posargs} --verbose --runslow --capture=no extras = test deps = - boto3 >= 1.10.32 + boto3 >= 1.12.8 pytest docker