From bc3bcd52989304e1342ccb6c26e55424a0931c04 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas
Date: Thu, 6 Jun 2024 09:07:00 -0400
Subject: [PATCH] Do not require gradient for vjp choice

Last little bit to fix https://github.com/SciML/DiffEqFlux.jl/issues/928 and
make that nicer
---
 docs/pages.jl         |  4 ++-
 docs/src/faq.md       | 62 +++++++++++++++++++++++++++++++++++++++++++
 src/concrete_solve.jl | 32 +++++++++++++++++++---
 3 files changed, 93 insertions(+), 5 deletions(-)
 create mode 100644 docs/src/faq.md

diff --git a/docs/pages.jl b/docs/pages.jl
index af24391c3..482e9e2c4 100644
--- a/docs/pages.jl
+++ b/docs/pages.jl
@@ -8,6 +8,7 @@ pages = ["index.md",
         "Training Techniques and Tips" => Any["tutorials/training_tips/local_minima.md",
             "tutorials/training_tips/divergence.md",
             "tutorials/training_tips/multiple_nn.md"]],
+    "Frequently Asked Questions (FAQ)" => "faq.md",
     "Examples" => Any[
         "Ordinary Differential Equations (ODEs)" => Any["examples/ode/exogenous_input.md",
             "examples/ode/prediction_error_method.md",
@@ -28,7 +29,8 @@ pages = ["index.md",
         "Optimal and Model Predictive Control" => Any[
             "examples/optimal_control/optimal_control.md",
             "examples/optimal_control/feedback_control.md"]],
-    "Manual and APIs" => Any["manual/differential_equation_sensitivities.md",
+    "Manual and APIs" => Any[
+        "manual/differential_equation_sensitivities.md",
         "manual/nonlinear_solve_sensitivities.md",
         "manual/direct_forward_sensitivity.md",
         "manual/direct_adjoint_sensitivities.md"],
diff --git a/docs/src/faq.md b/docs/src/faq.md
new file mode 100644
index 000000000..74de01a31
--- /dev/null
+++ b/docs/src/faq.md
@@ -0,0 +1,62 @@
+# Frequently Asked Questions (FAQ)
+
+## How do I isolate potential gradient issues and improve performance?
+
+If you see the warnings:
+
+```julia
+┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
+└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:145
+┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
+└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:100
+```
+
+then you're in luck! Well, not really. But there are things you can do. You can isolate
+the issue to the automatic differentiation of your `f` function, which lets you either
+fix `f` or open an issue with the AD library directly, without the ODE solver involved.
+
+If you have an in-place function, then you will want to isolate the issue to Enzyme.
+This is done as follows for an arbitrary problem:
+
+```julia
+using Enzyme, DiffEqBase
+u0 = prob.u0
+p = prob.p
+tmp2 = Enzyme.make_zero(p)
+t = prob.tspan[1]
+du = zero(u0)
+
+if DiffEqBase.isinplace(prob)
+    _f = prob.f
+else
+    _f = (du, u, p, t) -> (du .= prob.f(u, p, t); nothing)
+end
+
+_tmp6 = Enzyme.make_zero(_f)
+tmp3 = zero(u0)
+tmp4 = zero(u0)
+ytmp = zero(u0)
+tmp1 = zero(u0)
+
+Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
+    Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
+    Enzyme.Duplicated(ytmp, tmp1),
+    Enzyme.Duplicated(p, tmp2),
+    Enzyme.Const(t))
+```
+
+This is exactly the inner-core Enzyme call, and if it fails, that is the issue that
+needs to be fixed.
+
+And similarly, for out-of-place functions, the Zygote isolation is as follows:
+
+```julia
+using Zygote
+p = prob.p
+y = prob.u0
+f = prob.f
+t = prob.tspan[1]
+λ = zero(prob.u0)
+_dy, back = Zygote.pullback(y, p) do u, p
+    vec(f(u, p, t))
+end
+tmp1, tmp2 = back(λ)
+```
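+
+To make these checks concrete, here is a minimal sketch of the in-place Enzyme check
+applied to a toy problem. The system `lorenz!` and its parameters are only hypothetical
+stand-ins for your own `f` and `prob`:
+
+```julia
+using Enzyme, SciMLBase
+
+# A toy in-place ODE; replace this with your own right-hand side.
+function lorenz!(du, u, p, t)
+    du[1] = p[1] * (u[2] - u[1])
+    du[2] = u[1] * (p[2] - u[3]) - u[2]
+    du[3] = u[1] * u[2] - p[3] * u[3]
+    nothing
+end
+prob = ODEProblem(lorenz!, [1.0, 0.0, 0.0], (0.0, 1.0), [10.0, 28.0, 8 / 3])
+
+# The same inner-core call as above: if this throws, Enzyme cannot
+# differentiate `lorenz!`, independent of any ODE solver.
+u0 = prob.u0
+p = prob.p
+t = prob.tspan[1]
+_f = prob.f
+_tmp6 = Enzyme.make_zero(_f)
+tmp2 = Enzyme.make_zero(p)
+tmp3 = zero(u0)
+tmp4 = zero(u0)
+ytmp = zero(u0)
+tmp1 = zero(u0)
+
+Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
+    Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
+    Enzyme.Duplicated(ytmp, tmp1),
+    Enzyme.Duplicated(p, tmp2),
+    Enzyme.Const(t))
+```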
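+
+And a matching sketch for the out-of-place Zygote check, again with a hypothetical
+toy system (`lotka`) standing in for your own `f`:
+
+```julia
+using Zygote, SciMLBase
+
+# A toy out-of-place ODE; replace this with your own right-hand side.
+lotka(u, p, t) = [p[1] * u[1] - p[2] * u[1] * u[2],
+    -p[3] * u[2] + p[4] * u[1] * u[2]]
+prob = ODEProblem(lotka, [1.0, 1.0], (0.0, 1.0), [1.5, 1.0, 3.0, 1.0])
+
+# The same pullback as above: if this throws, Zygote cannot
+# differentiate `lotka`, independent of any ODE solver.
+p = prob.p
+y = prob.u0
+f = prob.f
+t = prob.tspan[1]
+λ = zero(prob.u0)
+_dy, back = Zygote.pullback(y, p) do u, p
+    vec(f(u, p, t))
+end
+tmp1, tmp2 = back(λ)
+```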
diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl
index 6a2202d00..d0d7f4f4f 100644
--- a/src/concrete_solve.jl
+++ b/src/concrete_solve.jl
@@ -89,10 +89,22 @@ function automatic_sensealg_choice(
     # so if out-of-place, try Zygote
     vjp = try
+        p = prob.p
+        y = prob.u0
+        f = prob.f
+        t = prob.tspan[1]
+        λ = zero(prob.u0)
+
         if p === nothing || p isa SciMLBase.NullParameters
-            Zygote.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
+            _dy, back = Zygote.pullback(y) do u
+                vec(f(u, p, t))
+            end
+            tmp1 = back(λ)
         else
-            Zygote.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p)
+            _dy, back = Zygote.pullback(y, p) do u, p
+                vec(f(u, p, t))
+            end
+            tmp1, tmp2 = back(λ)
         end
         ZygoteVJP()
     catch e
@@ -124,10 +136,22 @@ function automatic_sensealg_choice(
     if vjp == false
         vjp = try
+            p = prob.p
+            y = prob.u0
+            f = prob.f
+            t = prob.tspan[1]
+            λ = zero(prob.u0)
+
             if p === nothing || p isa SciMLBase.NullParameters
-                Tracker.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
+                _dy, back = Tracker.forward(y) do u
+                    vec(f(u, p, t))
+                end
+                tmp1 = back(λ)
             else
-                Tracker.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p)
+                _dy, back = Tracker.forward(y, p) do u, p
+                    vec(f(u, p, t))
+                end
+                tmp1, tmp2 = back(λ)
             end
             TrackerVJP()
         catch e