Why doesn't the L2 loss sum over the errors? #193

Answered by mkunesch
brynhayder asked this question in General

Hi! Thanks a lot for the question!

In my opinion, leaving the loss unreduced has two main advantages:

  • Not every user needs the same reduction function. While it's easy to apply a reduction function to the optax loss, it wouldn't necessarily be easy to change a reduction function that has already been applied. For example, if optax applied jnp.sum to the l2_loss, a user who wants to take the argmin of the l2_loss couldn't use the optax implementation.
  • The computation might be distributed. The loss might be the result of a reduction over data on different devices. If optax used e.g. the common reduction jnp.mean in the loss, we would get a slightly wrong result if we also took a jax.lax.pmean on top of this. If no redu…
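
To make both points concrete, here is a rough sketch rather than optax's actual code. The helper elementwise_l2 is a hypothetical stand-in that I'm assuming behaves like an unreduced L2 loss (half the squared error, one value per element). The toy arrays, the "devices" axis name and the use of jax.pmap are also just choices for illustration.

```python
import jax
import jax.numpy as jnp


def elementwise_l2(predictions, targets):
    # Hypothetical stand-in for an unreduced L2 loss: returns one value
    # per element and leaves the reduction to the caller.
    return 0.5 * (predictions - targets) ** 2


predictions = jnp.array([[1.0, 2.0], [3.0, 5.0]])
targets = jnp.array([[1.0, 0.0], [2.0, 1.0]])

per_element = elementwise_l2(predictions, targets)  # shape (2, 2)

# Point 1: each caller picks the reduction they need.
total = jnp.sum(per_element)
average = jnp.mean(per_element)
# Index of the example with the smallest summed error; this would not be
# recoverable if the loss had already been summed to a scalar.
best_example = jnp.argmin(jnp.sum(per_element, axis=-1))


# Point 2: in a distributed setting the caller decides where the local
# reduction stops and the cross-device reduction starts.
def device_loss(device_predictions, device_targets):
    local_mean = jnp.mean(elementwise_l2(device_predictions, device_targets))
    # Average the per-device means across the "devices" axis.
    return jax.lax.pmean(local_mean, axis_name="devices")


n_devices = jax.local_device_count()
# Toy sharding: every device receives the same small batch.
sharded_predictions = jnp.stack([predictions] * n_devices)
sharded_targets = jnp.stack([targets] * n_devices)

global_mean = jax.pmap(device_loss, axis_name="devices")(
    sharded_predictions, sharded_targets
)
```

Because the loss itself does no reduction, the choice between jnp.sum, jnp.mean, a per-example argmin or a cross-device jax.lax.pmean is made exactly once, by the caller.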

Answer selected by mtthss