Do not require gradient for vjp choice
Last little bit to fix SciML/DiffEqFlux.jl#928 and make that nicer
ChrisRackauckas committed Jun 6, 2024
1 parent 4292ae6 commit bc3bcd5
Showing 3 changed files with 93 additions and 5 deletions.
4 changes: 3 additions & 1 deletion docs/pages.jl
@@ -8,6 +8,7 @@ pages = ["index.md",
     "Training Techniques and Tips" => Any["tutorials/training_tips/local_minima.md",
         "tutorials/training_tips/divergence.md",
         "tutorials/training_tips/multiple_nn.md"]],
+    "Frequently Asked Questions (FAQ)" => "faq.md",
     "Examples" => Any[
         "Ordinary Differential Equations (ODEs)" => Any["examples/ode/exogenous_input.md",
             "examples/ode/prediction_error_method.md",
@@ -28,7 +29,8 @@ pages = ["index.md",
         "Optimal and Model Predictive Control" => Any[
             "examples/optimal_control/optimal_control.md",
             "examples/optimal_control/feedback_control.md"]],
-    "Manual and APIs" => Any["manual/differential_equation_sensitivities.md",
+    "Manual and APIs" => Any[
+        "manual/differential_equation_sensitivities.md",
         "manual/nonlinear_solve_sensitivities.md",
         "manual/direct_forward_sensitivity.md",
         "manual/direct_adjoint_sensitivities.md"],
62 changes: 62 additions & 0 deletions docs/src/faq.md
@@ -0,0 +1,62 @@
# Frequently Asked Questions (FAQ)

## How do I isolate potential gradient issues and improve performance?

If you see the warnings:

```julia
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:145
┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:100
```

then you're in luck! Well, not really. But there are things you can do. You can isolate the
issue to the automatic differentiation of your `f` function, which lets you either fix `f`
or open an issue with the AD library directly, without the ODE solver involved.
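
The warning text above also names two knobs worth knowing about before digging into `f` itself:
`SciMLSensitivity.STACKTRACE_WITH_VJPWARN[]` to print the stack trace of the failed `EnzymeVJP`
attempt, and `verbose = false` in the `solve` call to silence the printing. A minimal sketch,
assuming `prob` is your problem and the solver choice is only illustrative:

```julia
using OrdinaryDiffEq, SciMLSensitivity

# Print the stack trace of the failed EnzymeVJP attempt on the next solve
SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true

# Or silence the warning printing entirely (Tsit5 is just an example solver)
sol = solve(prob, Tsit5(); verbose = false)
```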

If you have an in-place function, then you will want to isolate it to Enzyme. This is done
as follows for an arbitrary problem:

```julia
using Enzyme, DiffEqBase

u0 = prob.u0
p = prob.p
tmp2 = Enzyme.make_zero(p)       # shadow storage for the parameters
t = prob.tspan[1]
du = zero(u0)

# Enzyme differentiates mutating (in-place) functions, so wrap an
# out-of-place `f` into an in-place form if needed
if DiffEqBase.isinplace(prob)
    _f = prob.f
else
    _f = (du, u, p, t) -> (du .= prob.f(u, p, t); nothing)
end

_tmp6 = Enzyme.make_zero(_f)     # shadow for any state captured by `f`
tmp3 = zero(u0)                  # primal output buffer `du`
tmp4 = zero(u0)                  # adjoint seed paired with `du`
ytmp = zero(u0)                  # primal state `u`
tmp1 = zero(u0)                  # receives the VJP with respect to `u`

Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
    Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
    Enzyme.Duplicated(ytmp, tmp1),
    Enzyme.Duplicated(p, tmp2),
    Enzyme.Const(t))
```

This is exactly the inner core Enzyme call; if it fails, that failure is the issue that
needs to be fixed.

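The snippet needs a `prob` to act on; any in-place problem whose gradients you are debugging
will do. A purely illustrative stand-in (the right-hand side, states, and parameters below are
hypothetical):

```julia
using OrdinaryDiffEq

# Hypothetical in-place ODE right-hand side, only to give the snippet something to run on
function lotka!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end

prob = ODEProblem(lotka!, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])
```
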
And similarly, for out-of-place functions the Zygote isolation is as follows:

```julia
using Zygote

p = prob.p
y = prob.u0
f = prob.f
t = prob.tspan[1]
λ = zero(prob.u0)        # adjoint seed for the pullback
_dy, back = Zygote.pullback(y, p) do u, p
    vec(f(u, p, t))
end
tmp1, tmp2 = back(λ)     # VJPs with respect to `u` and `p`
```
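
For a `prob` to pair with this out-of-place snippet, a hypothetical example (again, any
out-of-place problem works and the names are illustrative only):

```julia
using OrdinaryDiffEq

# Hypothetical out-of-place right-hand side returning a new array
f_oop(u, p, t) = [p[1] * u[1] - p[2] * u[1] * u[2],
    -p[3] * u[2] + p[4] * u[1] * u[2]]

prob = ODEProblem(f_oop, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])
```

As with the Enzyme call, if this isolated pullback fails, that failure is what should be fixed
or reported to the AD library.
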
32 changes: 28 additions & 4 deletions src/concrete_solve.jl
@@ -89,10 +89,22 @@ function automatic_sensealg_choice(
     # so if out-of-place, try Zygote
 
     vjp = try
+        p = prob.p
+        y = prob.u0
+        f = prob.f
+        t = prob.tspan[1]
+        λ = zero(prob.u0)
+
         if p === nothing || p isa SciMLBase.NullParameters
-            Zygote.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
+            _dy, back = Zygote.pullback(y) do u
+                vec(f(u, p, t))
+            end
+            tmp1 = back(λ)
         else
-            Zygote.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p)
+            _dy, back = Zygote.pullback(y, p) do u, p
+                vec(f(u, p, t))
+            end
+            tmp1, tmp2 = back(λ)
         end
         ZygoteVJP()
     catch e
@@ -124,10 +136,22 @@ function automatic_sensealg_choice(
 
     if vjp == false
         vjp = try
+            p = prob.p
+            y = prob.u0
+            f = prob.f
+            t = prob.tspan[1]
+            λ = zero(prob.u0)
+
             if p === nothing || p isa SciMLBase.NullParameters
-                Tracker.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
+                _dy, back = Tracker.forward(y) do u
+                    vec(f(u, p, t))
+                end
+                tmp1 = back(λ)
             else
-                Tracker.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p)
+                _dy, back = Tracker.forward(y, p) do u, p
+                    vec(f(u, p, t))
+                end
+                tmp1, tmp2 = back(λ)
             end
             TrackerVJP()
         catch e
