
The FFJORD copy-pasteable code doesn't work #931

Closed
a1ix2 opened this issue Jun 13, 2024 · 1 comment
Labels: bug (Something isn't working)

a1ix2 commented Jun 13, 2024

The copy-pasteable FFJORD example here doesn't quite work as-is.

I copy-pasted the code into a brand-new --temp environment (the only change being maxiters = 1) and got the following error:

using ComponentArrays, DiffEqFlux, OrdinaryDiffEq, Optimization, Distributions, Random,
      OptimizationOptimisers, OptimizationOptimJL

nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh))
tspan = (0.0f0, 10.0f0)

ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5(); ad = AutoZygote())
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
ps = ComponentArray(ps)
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st)

# Training
data_dist = Normal(6.0f0, 0.7f0)
train_data = Float32.(rand(data_dist, 1, 100))

function loss(θ)
    logpx, λ₁, λ₂ = model(train_data, θ)
    return -mean(logpx)
end

function cb(p, l)
    @info "FFJORD Training" loss=loss(p)
    return false
end

adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)

res1 = Optimization.solve(
           optprob, OptimizationOptimisers.Adam(0.01); maxiters = 1, callback = cb)
┌ Error: Exception while generating log record in module Main at REPL[103]:2
│   exception =
│    type OptimizationState has no field layer_1
│    Stacktrace:
│      [1] getproperty
│        @ ./Base.jl:37 [inlined]
│      [2] macro expansion
│        @ ~/.julia/packages/Lux/7UzHr/src/layers/containers.jl:0 [inlined]
│      [3] applychain(layers::@NamedTuple{layer_1::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}, layer_2::Dense{true, typeof(tanh_fast), typeof(glorot_uniform), typeof(zeros32)}}, x::SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, ps::Optimization.OptimizationState{ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:6, Axis(weight = ViewAxis(1:3, ShapedAxis((3, 1))), bias = ViewAxis(4:6, ShapedAxis((3, 1))))), layer_2 = ViewAxis(7:10, Axis(weight = ViewAxis(1:3, ShapedAxis((1, 3))), bias = ViewAxis(4:4, ShapedAxis((1, 1))))))}}}, Float32, ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:6, Axis(weight = ViewAxis(1:3, ShapedAxis((3, 1))), bias = ViewAxis(4:6, ShapedAxis((3, 1))))), layer_2 = ViewAxis(7:10, Axis(weight = ViewAxis(1:3, ShapedAxis((1, 3))), bias = ViewAxis(4:4, ShapedAxis((1, 1))))))}}}, Nothing, Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}})
│        @ Lux ~/.julia/packages/Lux/7UzHr/src/layers/containers.jl:478

The problem appears to be in the callback function. I don't quite understand why it doesn't work as-is, but simply replacing loss = loss(p) with loss = l does the trick:

function cb(p, l)
    @info "FFJORD Training" loss=l
    return false
end
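
For context on why the original callback fails: the stacktrace above shows that the callback's first argument is an Optimization.OptimizationState, not the parameter ComponentArray, so loss(p) ends up threading the state struct itself through the Lux layers. A minimal sketch to confirm what the callback actually receives (the cb_debug name is hypothetical, not from the docs); the parameters live in the state's u field:

function cb_debug(state, l)
    # state is an Optimization.OptimizationState{...}, not a ComponentArray
    @show typeof(state)
    # the ComponentArray of parameters is stored in state.u
    @show typeof(state.u)
    return false
end

Optimization.solve(optprob, OptimizationOptimisers.Adam(0.01); maxiters = 1, callback = cb_debug)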
a1ix2 added the bug label on Jun 13, 2024
ChrisRackauckas (Member) commented

Yes, thanks for the report. The callback's first argument is now an Optimization.OptimizationState rather than the raw parameter vector, so the direct translation is:

function cb(state, l)
    @info "FFJORD Training" loss=loss(state.u)
    return false
end

But of course, as you found, the better approach is simply to use the pre-computed l, which avoids re-solving the FFJORD ODE just to log the loss:

function cb(state, l)
    @info "FFJORD Training" loss=l
    return false
end
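
As a side note, the state object carries more than just the parameters, so the log line can be extended at no extra cost. A minimal sketch, assuming Optimization.jl's documented OptimizationState fields (iter alongside u); no additional forward pass is needed:

function cb(state, l)
    # iter is assumed from Optimization.jl's OptimizationState fields
    @info "FFJORD Training" iter=state.iter loss=l
    return false
end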

This is fixed in 2691d59
