CompileLoss truncates multiple loss inputs #19855
The same problem persists if you use dict inputs, btw:

```python
def my_loss(y_true: dict[str, tf.Tensor], y_pred: tf.Tensor):
    y1 = y_true["x1"]
    y2 = y_true["x2"]
    pred_sum = keras.ops.sum(y_pred)
    return keras.ops.abs(keras.ops.sum(y1) - pred_sum) + keras.ops.abs(
        keras.ops.sum(y2) - pred_sum
    )

...

model.compute_loss(y={"x1": x1, "x2": x2}, y_pred=y_pred)
...
```

There it's even worse, as the dict gets converted to a (sorted) list.
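The key-sorting can be seen with a plain-Python stand-in for the flattening step (a simplified mimic for illustration; the real code path goes through Keras's nested-structure utilities):

```python
# Simplified mimic of how nested-structure flattening treats a dict:
# values are emitted in sorted-key order, not insertion order.
def flatten(structure):
    if isinstance(structure, dict):
        return [v for _, v in sorted(structure.items())]
    if isinstance(structure, (list, tuple)):
        return list(structure)
    return [structure]

y_true = {"x2": "second", "x1": "first"}
print(flatten(y_true))  # ['first', 'second'] -- ordered by key "x1", "x2"
```

So even with a dict, the loss ends up receiving a positional slice of the sorted values rather than the dict itself.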
Happy to hear your thoughts on #19879. With that PR and a few small changes, your code works:

```python
import keras


class MyLoss(keras.Loss):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, y_true, y_pred):
        y1, y2 = y_true
        pred_sum = keras.ops.sum(y_pred)
        return keras.ops.abs(keras.ops.sum(y1) - pred_sum) + keras.ops.abs(
            keras.ops.sum(y2) - pred_sum
        )


def main():
    input1 = keras.Input((2,))
    input2 = keras.Input((3, 6))
    x1 = keras.ops.expand_dims(keras.layers.Dense(10)(input1), 1)
    x2 = keras.layers.Dense(10)(input2)
    x = keras.ops.sum(x1 + x2, axis=1)
    out = keras.layers.Dense(8)(x)
    model = keras.Model(inputs=[input1, input2], outputs=out)
    my_loss = MyLoss()
    my_loss.set_specs([input1, input2], out)  # <-- newly introduced feature
    model.compile(loss=my_loss)
    x1 = keras.random.uniform((10, 2))
    x2 = keras.random.uniform((10, 3, 6))
    y_pred = model((x1, x2))
    loss1 = my_loss((x1, x2), y_pred)
    print(loss1)
    loss2 = model.compute_loss(y=(x1, x2), y_pred=y_pred)
    print(loss2)


if __name__ == "__main__":
    main()
```

```
tf.Tensor(68.81332, shape=(), dtype=float32)
tf.Tensor(68.81332, shape=(), dtype=float32)
```
Wow, fantastic, that PR looks great. The changes are way above my understanding of the inner workings of Keras, but I'll definitely have a look! Thanks a lot for the quick response in the form of a fix PR 👏
Environment
Issue
Upon first call of `model.compute_loss`, `CompileLoss.call` is triggered, which fails if `y_true` is e.g. a tuple `(y1, y2)` and the loss function expects `y_true: tuple[tf.Tensor, tf.Tensor]`, as only the first element `y1` gets passed to the loss function.

Details
Let's say I have a loss function

i.e. the ground truth consists of two tensors (I can't stack them; their sizes are incompatible) and the prediction from my model is a single tensor.
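For concreteness, such a loss could look as follows (a NumPy sketch with made-up shapes; the actual loss in this thread uses `keras.ops`, but the structure is the same):

```python
import numpy as np

# Hypothetical loss: y_true is a *tuple* of two tensors with incompatible
# shapes (so they cannot be stacked), y_pred is a single tensor.
def my_loss(y_true, y_pred):
    y1, y2 = y_true  # this unpacking fails if only y1 is passed through
    pred_sum = np.sum(y_pred)
    return np.abs(np.sum(y1) - pred_sum) + np.abs(np.sum(y2) - pred_sum)

y1 = np.ones((10, 2))       # sum = 20
y2 = np.ones((10, 3, 6))    # sum = 180
y_pred = np.zeros((10, 8))  # sum = 0
print(my_loss((y1, y2), y_pred))  # 200.0
```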
Let's also say I have a model that takes two inputs `x1` and `x2` and produces one output `y_pred`. Due to the nature of my problem, the input is automatically the ground truth, i.e. `y_true = (x1, x2)`.

So I would like to execute

where the last call causes an exception in `my_loss`, because only `x1` gets passed as `y_true` to `my_loss`.
I have pinpointed the problem to the call to `y_true = self._flatten_y(y_true)` here and the subsequent `zip` iteration here.

The problem seems to be that the `zip` iteration expects `y_true` and `y_pred` to be iterables of the same length as `self.flat_losses` (which in my case is just 1). Now `y_pred = self._flatten_y(y_pred)` wraps the single tensor `y_pred` into a single-element list, which is correct. However, `y_true = self._flatten_y(y_true)` converts `y_true` from a 2-element tuple into a 2-element list, where it should be a nested list with a single 2-element list, `[[y1, y2]]`.

Consequently, the `zip` iteration takes only `y1` from `y_true` in its single iteration (since all other iterables have length one) and passes it as the `y_true` argument to `my_loss`.
I imagine this behavior comes from the fact that in cases where one has multiple losses (each taking single tensors as `y_true` and `y_pred`), the correct way of calling `compute_loss` is to pass sequences for `y_true` and `y_pred` to `compute_loss`, one element per loss function.

Isn't there a way to reconcile both cases, i.e. sequence inputs to single loss functions while still supporting multiple loss functions? All the information is there, i.e. how many loss functions there are and what their signatures look like. For starters, it is not checked that all iterables in the `zip` iteration have the same length...

Reproducible Example
Here is a reproducible minimal example
which produces the following error
Many thanks for all the great effort on Keras, I greatly appreciate all the awesome features!