Skip to content

Commit

Permalink
clean test_condense
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Nov 7, 2024
1 parent 9658d27 commit 26fd0e3
Showing 1 changed file with 0 additions and 42 deletions.
42 changes: 0 additions & 42 deletions tests/test_condense.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,23 +236,11 @@ def test_model(layer_type, layer_params, k_coef_lip, input_shape):
steps_per_epoch=steps_per_epoch,
callbacks=callback_list,
)
# model.__getattribute__(FIT)(
# linear_generator(batch_size, input_shape, kernel),
# steps_per_epoch=steps_per_epoch,
# epochs=epochs,
# verbose=0,
# callbacks=callback_list,
# )
# the seed is set to compare all models with the same data
np.random.seed(42)
# get original results
test_dl = linear_generator(batch_size, input_shape, kernel)
loss, mse = uft.run_test(model, test_dl, loss_fn, metrics, steps=10)
# loss, mse = model.__getattribute__(EVALUATE)(
# linear_generator(batch_size, input_shape, kernel),
# steps=10,
# verbose=0,
# )
# generate vanilla
if vanilla_require_a_copy():
model2 = get_model(layer_type, layer_params, input_shape, k_coef_lip)
Expand All @@ -267,29 +255,16 @@ def test_model(layer_type, layer_params, k_coef_lip, input_shape):
loss=uft.MeanSquaredError(),
metrics=[uft.metric_mse()],
)
# vanilla_model.compile(
# optimizer=optimizer, loss="mean_squared_error", metrics=[metrics.mse]
# )
np.random.seed(42)
# evaluate vanilla
test_dl = linear_generator(batch_size, input_shape, kernel)
loss2, mse2 = uft.run_test(model, test_dl, loss_fn, metrics, steps=10)
# loss2, mse2 = model.__getattribute__(EVALUATE)(
# linear_generator(batch_size, input_shape, kernel),
# steps=10,
# verbose=0,
# )
np.random.seed(42)
# check if original has changed
test_dl = linear_generator(batch_size, input_shape, kernel)
vanilla_loss, vanilla_mse = uft.run_test(
vanilla_model, test_dl, loss_fn, metrics, steps=10
)
# vanilla_loss, vanilla_mse = vanilla_model.__getattribute__(EVALUATE)(
# linear_generator(batch_size, input_shape, kernel),
# steps=10,
# verbose=0,
# )
model.summary()
vanilla_model.summary()

Expand All @@ -314,32 +289,15 @@ def test_model(layer_type, layer_params, k_coef_lip, input_shape):
steps_per_epoch=steps_per_epoch,
callbacks=callback_list,
)
# model.__getattribute__(FIT)(
# linear_generator(batch_size, input_shape, kernel),
# steps_per_epoch=steps_per_epoch,
# epochs=1,
# verbose=0,
# callbacks=callback_list,
# )
np.random.seed(42)
test_dl = linear_generator(batch_size, input_shape, kernel)
loss3, mse3 = uft.run_test(model, test_dl, loss_fn, metrics, steps=10)
# loss3, mse3 = model.__getattribute__(EVALUATE)(
# linear_generator(batch_size, input_shape, kernel),
# steps=10,
# verbose=0,
# )
# check if vanilla has changed
np.random.seed(42)
test_dl = linear_generator(batch_size, input_shape, kernel)
vanilla_loss2, vanilla_mse2 = uft.run_test(
vanilla_model, test_dl, loss_fn, metrics, steps=10
)
# vanilla_loss2, vanilla_mse2 = vanilla_model.__getattribute__(EVALUATE)(
# linear_generator(batch_size, input_shape, kernel),
# steps=10,
# verbose=0,
# )
np.testing.assert_equal(
vanilla_mse,
vanilla_mse2,
Expand Down

0 comments on commit 26fd0e3

Please sign in to comment.