Skip to content

Commit

Permalink
Add a test for TF SavedModel format saving and loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Szubie committed Oct 15, 2021
1 parent 81d7ab7 commit 600f8e4
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/test_model_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def _supervised_custom_model_saving(model_filepath, save_fn, load_fn):
y_pred_2 = model_2.fit_transform(X, Y)

### Save and load ###
def _save_ivis_model(model, filepath):
model.save_model(filepath, overwrite=True)
def _save_ivis_model(model, filepath, save_format='h5'):
model.save_model(filepath, save_format=save_format, overwrite=True)

def _load_ivis_model(filepath):
model_2 = Ivis()
Expand Down Expand Up @@ -179,6 +179,11 @@ def _undill_ivis_model(filepath):
test_supervised_custom_model_pickling = partial(_supervised_custom_model_saving,
save_fn=_dill_ivis_model, load_fn=_undill_ivis_model)

### Other ###
test_tf_savedmodel_persistence = partial(_unsupervised_model_save_test,
save_fn=partial(_save_ivis_model, save_format='tfs'),
load_fn=_load_ivis_model)

def test_save_overwriting(model_filepath):
model = Ivis(k=15, batch_size=16, epochs=2)
iris = datasets.load_iris()
Expand Down

0 comments on commit 600f8e4

Please sign in to comment.