Skip to content

Commit

Permalink
Add test case for simple model with custom loss
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Nov 25, 2024
1 parent f602fb3 commit 3286850
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions modnet/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,38 @@ def test_train_small_model_multi_target(subset_moddata, tf_session):
model.predict(data)
assert not np.isnan(model.evaluate(data))

def test_train_small_model_multi_target_custom_loss(subset_moddata, tf_session):
"""Tests the multi-target training."""
from modnet.models import MODNetModel
from functools import partial
import tensorflow as tf

data = subset_moddata
# set 'optimal' features manually
data.optimal_features = [
col for col in data.df_featurized.columns if col.startswith("ElementProperty")
]

def custom_loss(y_true, y_pred, rescale=1):
loss1 = y_pred - y_true
return rescale * tf.reduce_mean(
tf.math.abs(
tf.boolean_mask(loss1, tf.reduce_all(~tf.math.is_nan(loss1), axis=1))
)
)

model = MODNetModel(
[[["eform", "egap"]]],
weights={"eform": 1, "egap": 1},
num_neurons=[[16], [8], [8], [4]],
n_feat=10,
)

model.fit(data, loss=[partial(custom_loss, rescale=10), custom_loss], epochs=2)
model.predict(data)
breakpoint()
assert not np.isnan(model.evaluate(data))


def test_train_small_model_presets(subset_moddata, tf_session):
"""Tests the `fit_preset()` method."""
Expand Down

0 comments on commit 3286850

Please sign in to comment.