Default presentation of gradients confusing? #1519

Closed
NAThompson opened this issue Aug 18, 2024 · 1 comment

Comments

@NAThompson
Contributor

NAThompson commented Aug 18, 2024

When using Zygote with Unitful, the presentation of the gradient is a bit confusing:

using Unitful
using Zygote

struct Foo
    x::Number
    t::Number
    c::Number
end

function bar(f::Foo)
    return f.x - f.c*f.t
end

foo = Foo(2u"m", 3u"s", 3e8u"m/s")
g = Zygote.gradient(f -> bar(f), foo)  # one entry per argument; a NamedTuple of partials for the struct

println(g)

Output:

((x = 1.0, t = -3.0e8 m s⁻¹, c = -3.0 s),)
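As a worked check (not part of the original report), the printed values are the partial derivatives of `bar` with respect to each field of `Foo`. Each carries the unit (output unit)/(input unit), and the output of `bar` is in metres:

```latex
\begin{aligned}
\mathrm{bar}(x, t, c) &= x - c\,t \\
\frac{\partial\,\mathrm{bar}}{\partial x} &= 1 \\
\frac{\partial\,\mathrm{bar}}{\partial t} &= -c = -3.0\times10^{8}\ \mathrm{m\,s^{-1}} \\
\frac{\partial\,\mathrm{bar}}{\partial c} &= -t = -3.0\ \mathrm{s}
\end{aligned}
```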

The numbers are correct, but shouldn't the fields be labelled with (say) x\dot[TAB] (i.e. ẋ) rather than plain x?

Or perhaps:

((∂f/∂x = 1.0, ∂f/∂t = -3.0e8 m s⁻¹, ∂f/∂c = -3.0 s),)
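Because the gradient of a struct argument comes back as an ordinary `NamedTuple`, a labelled form like the one above can be produced on the user side. The following is a minimal sketch assuming a flat gradient with scalar entries; `partials_string` and its `fname` keyword are hypothetical, not part of Zygote, and plain numbers stand in for the Unitful quantities:

```julia
# Hypothetical helper (not a Zygote API): render a flat NamedTuple gradient
# as "∂f/∂name = value" pairs, using the NamedTuple's field names as labels.
function partials_string(g::NamedTuple; fname::AbstractString = "f")
    parts = ["∂$(fname)/∂$(k) = $(string(v))" for (k, v) in pairs(g)]
    return "(" * join(parts, ", ") * ")"
end

# Plain Float64 values stand in for the Unitful gradient from the example.
g = (x = 1.0, t = -3.0e8, c = -3.0)
println(partials_string(g))  # prints (∂f/∂x = 1.0, ∂f/∂t = -3.0e8, ∂f/∂c = -3.0)
```

This only works because every field is a scalar leaf. The helper has no sensible answer once a field is itself a struct, tuple or array.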
@ToucheSir
Member

ToucheSir commented Aug 21, 2024

There are many reasons to avoid being too clever when printing gradients: Zygote doesn't control how the types it returns are shown, indexed types like arrays introduce ambiguities, the field names of custom structs can't be changed in their gradients, etc. But the easiest way to demonstrate the problem is with a counterexample. What should be printed in this scenario?

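using Zygote

foo = (;
  x = (;
    y = 1,
    x = 2,
    c = 3
  ),
  c = (;
    x = 2,
    z = ((; x = 2), (; x = 2))
  ),
  # 100 entries: (x = 1, c = (; x = 2, c = 3)), (x = 2, c = (; x = 3, c = 4)), ...
  z = [(x = i, c = (; x = i + 1, c = i + 2)) for i in 1:100]
)

# f sums every number in the nested structure
total(x::Number) = x
total(x::Union{Tuple,NamedTuple,AbstractVector}) = sum(total, x)
f(foo) = total(foo)

∂f∂foo = gradient(f, foo)[1]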
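One way to see the ambiguity is to flatten a gradient of this shape into full access paths. The sketch below is illustrative only: `leaf_paths` is a hypothetical helper, and the input is a cut-down stand-in for `∂f∂foo` (three entries in `z` instead of 100). Since `f` sums every number, every entry of its gradient is 1:

```julia
# Hypothetical helper: collect (path, value) pairs for every number in a
# nested structure of NamedTuples, Tuples and Vectors.
function leaf_paths(x, path::String = "", out::Vector{Tuple{String,Any}} = Tuple{String,Any}[])
    if x isa Number
        push!(out, (path, x))
    elseif x isa NamedTuple
        for (k, v) in pairs(x)            # field access: path.name
            leaf_paths(v, path * "." * string(k), out)
        end
    elseif x isa Union{Tuple,AbstractVector}
        for (i, v) in enumerate(x)        # indexed access: path[i]
            leaf_paths(v, path * "[" * string(i) * "]", out)
        end
    end
    return out
end

# Cut-down stand-in for ∂f∂foo (all ones, since f is a plain sum).
g = (x = (y = 1, x = 1, c = 1),
     c = (x = 1, z = ((x = 1,), (x = 1,))),
     z = [(x = 1, c = (x = 1, c = 1)) for _ in 1:3])

paths = leaf_paths(g, "∂f∂foo")
# Ten different leaves here are all named `x`, so a bare "∂f/∂x" label cannot tell them apart.
foreach(println, filter(p -> endswith(p[1], ".x"), paths))
```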

It's clear from this counterexample that the proposed way of printing breaks down as soon as the gradient has any nested structure. That's not to say it lacks pedagogical value, but the marginal benefit over

∂f∂foo = gradient(f, foo)[1]
∂f∂x, ∂f∂t, ∂f∂c = ∂f∂foo

or

∂f∂foo = gradient(f, foo)[1]
∂f∂foo.x, ∂f∂foo.t, ∂f∂foo.c # more honest about the actual structure and mirrors input argument

is low, and not worth the myriad problems and limitations it would bring.
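Both alternatives work on any `NamedTuple`, which is what the gradient of a struct argument like `Foo` is returned as. On Julia 1.7 and later, property destructuring additionally binds by field name rather than by position. A minimal sketch, with a plain `NamedTuple` (units dropped) standing in for `gradient(bar, foo)[1]` from the original example:

```julia
# Stand-in for gradient(bar, foo)[1] from the original example, without units.
∂f∂foo = (x = 1.0, t = -3.0e8, c = -3.0)

# Positional destructuring: relies on the field order x, t, c.
∂f∂x, ∂f∂t, ∂f∂c = ∂f∂foo

# Property destructuring (Julia ≥ 1.7): binds by name, so order does not matter.
(; c, x) = ∂f∂foo

println((∂f∂x, ∂f∂t, ∂f∂c))
println((x, c))
```

Positional destructuring silently mislabels values if the field order ever changes. The name-based form does not have that failure mode.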

ToucheSir closed this as not planned Aug 21, 2024