From 2f39d8e000a21f61df1e268fa1a67c1d22407f8c Mon Sep 17 00:00:00 2001 From: Dana Benson Date: Sat, 23 May 2020 10:14:16 -0700 Subject: [PATCH] fix: create trial with trial components --- src/smexperiments/trial.py | 2 +- tests/unit/test_trial.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/smexperiments/trial.py b/src/smexperiments/trial.py index b5e5114..62987ed 100644 --- a/src/smexperiments/trial.py +++ b/src/smexperiments/trial.py @@ -118,7 +118,7 @@ def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, tr ) if trial_components: for tc in trial_components: - trial.add_trial_components(*trial_components) + trial.add_trial_component(tc) return trial @classmethod diff --git a/tests/unit/test_trial.py b/tests/unit/test_trial.py index 5e4e681..641b763 100644 --- a/tests/unit/test_trial.py +++ b/tests/unit/test_trial.py @@ -57,6 +57,28 @@ def test_create_no_name(sagemaker_boto_client): assert kwargs["TrialName"] # confirm that a TrialName was passed +def test_create_with_trial_components(sagemaker_boto_client): + sagemaker_boto_client.create_trial.return_value = { + "Arn": "arn:aws:1234", + "TrialName": "name-value", + } + tc = trial_component.TrialComponent(trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client) + + trial_obj = trial.Trial.create( + trial_name="name-value", + experiment_name="experiment-name-value", + trial_components=[tc], + sagemaker_boto_client=sagemaker_boto_client, + ) + assert trial_obj.trial_name == "name-value" + sagemaker_boto_client.create_trial.assert_called_with( + TrialName="name-value", ExperimentName="experiment-name-value" + ) + sagemaker_boto_client.associate_trial_component.assert_called_with( + TrialName="name-value", TrialComponentName=tc.trial_component_name + ) + + def test_add_trial_component(sagemaker_boto_client): t = trial.Trial(sagemaker_boto_client) t.trial_name = "bar"