Refactor `CompileLoss` to support nested `y_true` and `y_pred` #19879

Conversation
Very cool! Is there any way we can integrate the additional required step to call
I am wondering if we really need to tinker with the
So, 4.ii. is the only case we have to figure out how to make unambiguous. One easy solution would be to always require the
What do you think? Did I miss any scenarios?
Thanks for the PR!
```python
self.y_true_spec = self.SPEC_TENSOR
self.y_pred_spec = self.SPEC_TENSOR

def set_specs(self, y_true, y_pred):
```
The notion of having specs for `y_true` and `y_pred` feels overengineered to me.
It is difficult, or even impossible, to avoid this function if `y_true` or `y_pred` is nested instead of being a plain tensor. We have no information about the structure of `y_true` and `y_pred` when compiling the losses. This was previously resolved through complicated analysis and a very strong assumption that each loss takes single tensors as inputs. Obviously, the single-tensor input format is insufficient for many cases. (Personally, in an object detection task, I had to stack multiple `y_true` values into a single tensor to make `CompileLoss` work. It took me some time to work around it.)

This solution for nested input structures is the only one I could come up with. It feels overengineered, but it works. Additionally, the `build` part is much cleaner than before.
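The stacking workaround mentioned above can be sketched like this (a minimal NumPy illustration with hypothetical target names, not code from this PR):

```python
import numpy as np

# Hypothetical targets standing in for, e.g., class ids and objectness scores.
y1 = np.full((8, 1), 0.5)
y2 = np.full((8, 1), 0.25)

# Workaround: stack both targets into one tensor so CompileLoss
# only ever sees a single y_true ...
y_true = np.concatenate([y1, y2], axis=-1)  # shape (8, 2)

def my_stacked_loss(y_true, y_pred):
    # ... and split them apart again inside the custom loss.
    t1, t2 = y_true[..., 0:1], y_true[..., 1:2]
    return np.mean(np.abs(t1 - y_pred)) + np.mean(np.abs(t2 - y_pred))

loss = my_stacked_loss(y_true, np.zeros((8, 1)))  # 0.5 + 0.25 = 0.75
```

This works, but it forces every target to share one shape and pushes the packing logic into user code, which is exactly the friction the spec-based approach tries to remove.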
keras/src/trainers/compile_utils.py (Outdated)

```python
    f"Received: loss_weights length={len(flat_loss_weights)}, "
    f"loss length={len(flat_losses)}"
)
flat_y_true_specs, flat_y_pred_specs = self._get_specs(flat_losses)
```
I don't quite follow how this is used. Could it be simplified?
I have added some comments to clarify the entire procedure. Essentially, we need `y_true_spec` and `y_pred_spec` in a flat form because we have flattened the losses in that manner. In `call`, these specs are necessary to correctly pack `y_true` and `y_pred` before iterating through them with `self.flat_losses`.
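The packing described here can be illustrated with a small pure-Python sketch (the helpers `flatten` and `pack_per_spec` and the leaf-count specs are hypothetical stand-ins, not the actual `CompileLoss` internals):

```python
def flatten(structure):
    """Flatten an arbitrarily nested list/dict into a flat list of leaves."""
    if isinstance(structure, dict):
        return [x for k in sorted(structure) for x in flatten(structure[k])]
    if isinstance(structure, (list, tuple)):
        return [x for item in structure for x in flatten(item)]
    return [structure]

def pack_per_spec(flat_leaves, specs):
    """Group flat leaves so each loss receives the number of tensors
    its spec declares (1 for a plain tensor, more for nested inputs)."""
    packed, i = [], 0
    for n_leaves in specs:
        packed.append(flat_leaves[i : i + n_leaves])
        i += n_leaves
    return packed

# Two losses: the first expects a single tensor, the second a pair.
flat_losses = ["mse", "my_nested_loss"]
specs = [1, 2]  # hypothetical leaf counts per loss
y_true = {"a": "t0", "b": ["t1", "t2"]}
packed = pack_per_spec(flatten(y_true), specs)
# packed[0] -> ["t0"] for "mse"; packed[1] -> ["t1", "t2"] for the nested loss
```

Because the losses are stored flat, the specs are the only record of how many leaves of the nested `y_true`/`y_pred` belong to each loss.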
Hey @Darkdragon84, thanks for your input.
It is quite difficult for me. The root cause is that we have no information about the structure of
Yeah, this will make
Yeah, I understand, especially since the latest major release was very recent... Would it be possible to enable arbitrarily nested tensor inputs only if the losses and inputs are given as lists (i.e. conforming to the format that the metrics also use)? Perhaps we can also find an easier way to resolve the ambiguity without having to require list inputs.
If we found a way to disambiguate 4.ii., we could defer the call to
I tried to build this testing table, and we should pass all of them with the algorithm. All possible valid scenarios in tests:
We should be careful that both
There are also some cases that should be considered invalid, such as missing keys, but I haven't listed them above to simplify the table.
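One way to picture the kind of structure validation the table implies is a sketch like the following (a hypothetical helper; the real matching rules in Keras are more involved):

```python
def structures_match(loss, outputs):
    """Check that a nested `loss` structure is compatible with the
    nested structure of model outputs; unknown keys are rejected."""
    if isinstance(outputs, dict):
        return (
            isinstance(loss, dict)
            and set(loss) <= set(outputs)  # no keys absent from the outputs
            and all(structures_match(loss[k], outputs[k]) for k in loss)
        )
    if isinstance(outputs, (list, tuple)):
        # A single loss may apply to every output in a list.
        if not isinstance(loss, (list, tuple)):
            return True
        return len(loss) == len(outputs) and all(
            structures_match(l, o) for l, o in zip(loss, outputs)
        )
    return not isinstance(loss, (dict, list, tuple))

outputs = {"a": "out_a", "b": ["out_b1", "out_b2"]}
assert structures_match({"a": "mse"}, outputs)          # subset of keys: ok
assert structures_match({"b": ["mse", "mae"]}, outputs)  # matching list: ok
assert not structures_match({"c": "mse"}, outputs)       # unknown key: invalid
```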
I don't think we should make changes to the base
If we're going to enable nested structures for
This would be the right approach -- if we need to keep track of input specs, let's base it on the structure of the inputs received at the first call to the loss.
But how can we do this? (I'm going to separate the unrelated fix from this PR to concentrate the discussion here.)
So could the logic live entirely in
I have not thought about how to do it, but if the only way to make the feature work is to add a user-facing
@james77777778
I think this is not feasible. We should retain the original behavior.
Hey @Darkdragon84. I'm unsure what you mean by
I'm thinking of another way to support nested
A simple usage like this:

```python
import keras

def my_loss(y_true, y_pred):
    y1, y2 = y_true
    pred_sum = keras.ops.sum(y_pred)
    return keras.ops.abs(keras.ops.sum(y1) - pred_sum) + keras.ops.abs(
        keras.ops.sum(y2) - pred_sum
    )

x = keras.Input(shape=(3,), name="input_a")
output_a = keras.layers.Dense(1, name="output_a")(x)
output_b = keras.layers.Dense(1, name="output_b", activation="sigmoid")(x)
model = keras.Model(x, {"output": [output_a, output_b]})  # <- dict
model.compile(loss={"output": my_loss})  # <- dict

x = keras.random.uniform([8, 3])
y1 = keras.random.uniform([8, 1])
y2 = keras.random.randint((8, 1), 0, 2)
model.fit(x, {"output": [y1, y2]}, batch_size=2, epochs=1)  # <- dict
"""
Currently, this causes an error:
ValueError: In the dict argument `loss`, key 'output' does not correspond
to any model output. Received:
loss={'output': <function my_loss at 0x73e7325fe2a0>}
"""
```
Hi @james77777778. Can you please resolve the conflicts? Thank you!
This PR should be closed now. I don't have a concrete idea for
Fix #19855

I think the current implementation of `CompileLoss` restricts many possibilities for custom losses:

- It requires `y_true` and `y_pred` to be single tensors.
- It relies on the outputs of the `Functional` model instead of the actual number of pairs of (`y_true`, `y_pred`).

Additionally, there are many pitfalls that can cause `CompileLoss` to malfunction. #19855

So, I have refactored `CompileLoss` by introducing the concept of `y_true_spec` and `y_pred_spec`. The idea behind this is to utilize `tree` and the specs to automatically pack `y_true` and `y_pred` for `CompileLoss.build` and `CompileLoss.call`.

For backward compatibility, they default to a plain `"tensor"` structure. If someone needs a nested structure for `y_true` or `y_pred`, they can use the newly introduced `Loss.set_specs` to set `y_true_spec` and `y_pred_spec`.

Tests have been added to verify this new behavior:

- `test_custom_loss_with_list_y_true`: `y_true` as a list
- `test_custom_loss_with_nested_dict`: `y_true` as a dict with a list, and `y_pred` as a dict of dicts

Additionally, the missing `dtype` in all losses has been corrected.
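The backward-compatible defaulting described in the PR can be mimicked with a minimal pure-Python sketch (a hypothetical stand-in class, not the real `keras.losses.Loss`):

```python
class Loss:
    """Hypothetical stand-in illustrating the proposed spec defaulting."""

    SPEC_TENSOR = "tensor"

    def __init__(self):
        # Backward-compatible default: each input is a single plain tensor.
        self.y_true_spec = self.SPEC_TENSOR
        self.y_pred_spec = self.SPEC_TENSOR

    def set_specs(self, y_true, y_pred):
        # Record the nested structure this loss expects, e.g. a list or
        # dict template mirroring the shapes of y_true and y_pred.
        self.y_true_spec = y_true
        self.y_pred_spec = y_pred

loss = Loss()
assert loss.y_true_spec == "tensor"  # default: plain tensor

# Opt in to a nested y_true (a pair of tensors) while keeping y_pred flat.
loss.set_specs(["tensor", "tensor"], "tensor")
assert loss.y_true_spec == ["tensor", "tensor"]
```

Since the specs default to plain tensors, existing user code never has to call `set_specs`, which is how the PR preserves the original behavior.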