Skip to content

Commit

Permalink
fix: create trial with trial components
Browse files Browse the repository at this point in the history
  • Loading branch information
danabens committed May 23, 2020
1 parent 85375a9 commit 2f39d8e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/smexperiments/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def create(cls, experiment_name, trial_name=None, sagemaker_boto_client=None, tr
)
if trial_components:
for tc in trial_components:
trial.add_trial_components(*trial_components)
trial.add_trial_component(tc)
return trial

@classmethod
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@ def test_create_no_name(sagemaker_boto_client):
assert kwargs["TrialName"] # confirm that a TrialName was passed


def test_create_with_trial_components(sagemaker_boto_client):
sagemaker_boto_client.create_trial.return_value = {
"Arn": "arn:aws:1234",
"TrialName": "name-value",
}
tc = trial_component.TrialComponent(trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client)

trial_obj = trial.Trial.create(
trial_name="name-value",
experiment_name="experiment-name-value",
trial_components=[tc],
sagemaker_boto_client=sagemaker_boto_client,
)
assert trial_obj.trial_name == "name-value"
sagemaker_boto_client.create_trial.assert_called_with(
TrialName="name-value", ExperimentName="experiment-name-value"
)
sagemaker_boto_client.associate_trial_component.assert_called_with(
TrialName="name-value", TrialComponentName=tc.trial_component_name
)


def test_add_trial_component(sagemaker_boto_client):
t = trial.Trial(sagemaker_boto_client)
t.trial_name = "bar"
Expand Down

0 comments on commit 2f39d8e

Please sign in to comment.