diff --git a/tests/test_deeptabular.py b/tests/test_deeptabular.py index f518a04..8a82c32 100644 --- a/tests/test_deeptabular.py +++ b/tests/test_deeptabular.py @@ -7,6 +7,7 @@ import numpy as np import tensorflow as tf from sklearn.metrics import accuracy_score, mean_absolute_error +import os def test_build_deeptabular(): @@ -73,6 +74,8 @@ def test_build_save(): base_model_new = DeepTabular() base_model_new.load_config("config.json") + os.remove("config.json") + assert base_model.mapping == base_model_new.mapping assert base_model.cat_cols == base_model_new.cat_cols assert base_model.num_cols == base_model_new.num_cols