-
-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Do not require gradient for vjp choice
Last little bit to fix SciML/DiffEqFlux.jl#928 and make that nicer
- Loading branch information
1 parent
4292ae6
commit bc3bcd5
Showing
3 changed files
with
93 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Frequently Asked Qestuions (FAQ) | ||
|
||
## How do I isolate potential gradient issues and improve performance? | ||
|
||
If you see the warnings: | ||
|
||
```julia | ||
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs | ||
└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:145 | ||
┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call. | ||
└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:100 | ||
``` | ||
|
||
then you're in luck! Well, not really. But there are things you can do. You can isolate the | ||
issue to automatic differentiation of your `f` function in order to either fix your `f` | ||
function, or open an issue with the AD library directly without the ODE solver involved. | ||
|
||
If you have an in-place function, then you will want to isolate it to Enzyme. This is done | ||
as follows for an arbitrary problem: | ||
|
||
```julia | ||
using Enzyme | ||
u0 = prob.u0 | ||
p = prob.p | ||
tmp2 = Enzyme.make_zero(p) | ||
t = prob.tspan[1] | ||
du = zero(u0) | ||
|
||
if DiffEqBase.isinplace(prob) | ||
_f = prob.f | ||
else | ||
_f = (du,u,p,t) -> (du .= prob.f(u,p,t); nothing) | ||
end | ||
|
||
_tmp6 = Enzyme.make_zero(_f) | ||
tmp3 = zero(u0) | ||
tmp4 = zero(u0) | ||
ytmp = zero(u0) | ||
tmp1 = zero(u0) | ||
|
||
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6), | ||
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), | ||
Enzyme.Duplicated(ytmp, tmp1), | ||
Enzyme.Duplicated(p, tmp2), | ||
Enzyme.Const(t)) | ||
``` | ||
|
||
This is exactly the inner core Enzyme call and if this fails, that is the issue that | ||
needs to be fixed. | ||
|
||
And similarly, for out-of-place functions the Zygote isolation is as follows: | ||
|
||
```julia | ||
p = prob.p | ||
y = prob.u0 | ||
f = prob.f | ||
λ = zero(prob.u0) | ||
_dy, back = Zygote.pullback(y, p) do u, p | ||
vec(f(u, p, t)) | ||
end | ||
tmp1, tmp2 = back(λ) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters