
Commit

Compact is more flexible and allows for custom parameters
avik-pal committed Apr 12, 2024
1 parent b3e9a5f commit 11289ed
Showing 8 changed files with 157 additions and 51 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/SpellCheck.yml
@@ -0,0 +1,13 @@
name: Spell Check

on: [pull_request]

jobs:
  typos-check:
    name: Spell Check with Typos
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Actions Repository
        uses: actions/checkout@v4
      - name: Check spelling
        uses: crate-ci/typos@v1.18.0
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.5.34"
version = "0.5.35"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
4 changes: 2 additions & 2 deletions docs/src/introduction/index.md
@@ -61,8 +61,8 @@ y, st = Lux.apply(model, x, ps, st)
# Gradients
## Pullback API to capture change in state
(l, st_), pb = pullback(p -> Lux.apply(model, x, p, st), ps)
gs = pb((one.(l), nothing))[1]
(l, st_), pb = pullback(Lux.apply, model, x, ps, st)
gs = pb((one.(l), nothing))[3]
# Optimization
st_opt = Optimisers.setup(Adam(0.0001f0), ps)
11 changes: 9 additions & 2 deletions docs/src/manual/interface.md
@@ -20,6 +20,13 @@ First let's set the expectations straight.
functionality in the core library (and officially supported ones) **must** adhere to
the interface

!!! tip

While writing out a custom struct and defining dispatches manually is a good way to
understand the interface, it is not the most concise way. We recommend using the
[`Lux.@compact`](@ref) macro to define layers which makes handling the states and
parameters downright trivial.
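
For a quick taste of what that looks like, here is a minimal sketch of a one-parameter `@compact` layer (the parameter name `w` is purely illustrative):

```julia
using Lux, Random

layer = @compact(; w=ones(Float32, 3)) do x
    return w .* x  # `w` is registered as a trainable parameter
end

ps, st = Lux.setup(Xoshiro(0), layer)
y, st_ = layer(ones(Float32, 3), ps, st)  # standard Lux calling convention
```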

## Layer Interface

### Singular Layer
@@ -35,8 +42,8 @@ architecture cannot change.

!!! tip

For people coming from Flux.jl background this might be weird. We recommend checking out
[the Flux to Lux migration guide](@ref migrate-from-flux) first before proceeding.
For people coming from Flux.jl background, this might be weird. We recommend checking
out [the Flux to Lux migration guide](@ref migrate-from-flux) first before proceeding.

```@example layer_interface
using Lux, Random
33 changes: 29 additions & 4 deletions examples/NeuralODE/main.jl
@@ -42,7 +42,23 @@ function loadmnist(batchsize, train_split)
end

# ## Define the Neural ODE Layer
#
#
# First we will use the [`@compact`](@ref) macro to define the Neural ODE Layer.

function NeuralODECompact(
        model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        ## Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        return solve(prob, solver; kwargs...)
    end
end

# We recommend using the compact macro for creating custom layers. The below implementation
# exists mostly for historical reasons when `@compact` was not part of the stable API. Also,
# it helps users understand how the layer interface of Lux works.

# The NeuralODE is a ContainerLayer, which stores a `model`. The parameters and states of
# the NeuralODE are the same as those of the underlying model.
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <:
@@ -154,6 +170,8 @@ function train(model_function; cpu::Bool=false, kwargs...)
end
end

train(NeuralODECompact)

train(NeuralODE)

# We can also change the sensealg and train the model! `GaussAdjoint` allows you to use
@@ -173,8 +191,9 @@ train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)

# ## Alternate Implementation using Stateful Layer

# Starting `v0.5.5`, Lux provides a `Lux.Experimental.StatefulLuxLayer` which can be used
# to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276).
# Starting `v0.5.5`, Lux provides a [`StatefulLuxLayer`](@ref) which can be used
# to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276). Using
# the `@compact` API avoids this problem entirely.
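# As a minimal illustration (with a stand-in `Dense` layer; `Lux` and `Random` are assumed
# to be loaded, as at the top of this example), a `StatefulLuxLayer` wraps a layer together
# with its parameters and states so it can be called like a plain function:

demo_layer = Dense(2 => 2)
ps_demo, st_demo = Lux.setup(Xoshiro(0), demo_layer)
smodel_demo = StatefulLuxLayer(demo_layer, ps_demo, st_demo)  # holds `ps_demo` and `st_demo`
y_demo = smodel_demo(ones(Float32, 2))  # no closure over the state, hence no boxing
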
struct StatefulNeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <:
Lux.AbstractExplicitContainerLayer{(:model,)}
model::M
@@ -189,7 +208,7 @@ function StatefulNeuralODE(
end

function (n::StatefulNeuralODE)(x, ps, st)
st_model = Lux.StatefulLuxLayer(n.model, ps, st)
st_model = StatefulLuxLayer(n.model, ps, st)
dudt(u, p, t) = st_model(u, p)
prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
return solve(prob, n.solver; n.kwargs...), st_model.st
@@ -217,5 +236,11 @@ x = gpu_device()(ones(Float32, 28, 28, 1, 3));

@code_warntype model_stateful(x, ps_stateful, st_stateful)

# Finally checking the compact model

model_compact, ps_compact, st_compact = create_model(NeuralODECompact)

@code_warntype model_compact(x, ps_compact, st_compact)

# Note that we still recommend using this layer internally and not exposing it as the
# default API to the users.
31 changes: 28 additions & 3 deletions examples/SimpleRNN/main.jl
@@ -99,6 +99,26 @@ function (s::SpiralClassifier)(
return vec(y), st
end

# ## Using the `@compact` API

# We can also define the model using the [`Lux.@compact`](@ref) API, which is a more concise
# way of defining models. This macro automatically handles the boilerplate code for you and
# as such we recommend this way of defining custom layers.

function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell=lstm_cell,
        classifier=classifier) do x::AbstractArray{T, 3} where {T}
        x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        return vec(classifier(y))
    end
end

# ## Defining Accuracy, Loss and Optimiser

# Now let's define the binarycrossentropy loss. Typically it is recommended to use
@@ -125,12 +145,12 @@ accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

# ## Training the Model

function main()
function main(model_type)
## Get the dataloaders
(train_loader, val_loader) = get_dataloaders()

## Create the model
model = SpiralClassifier(2, 8, 1)
model = model_type(2, 8, 1)
rng = Xoshiro(0)

dev = gpu_device()
@@ -164,7 +184,12 @@ function main()
return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main()
ps_trained, st_trained = main(SpiralClassifier)
nothing #hide

# We can also train the compact model with the exact same code!

ps_trained2, st_trained2 = main(SpiralClassifierCompact)
nothing #hide

# ## Saving the Model
113 changes: 75 additions & 38 deletions src/helpers/compact.jl
@@ -13,6 +13,9 @@ end
@compact(kw...) do x
...
end
@compact(kw...) do x, p
...
end
@compact(forward::Function; name=nothing, dispatch=nothing, parameters...)
Creates a layer by specifying some `parameters`, in the form of keywords, and (usually as a
Expand All @@ -21,12 +24,23 @@ Creates a layer by specifying some `parameters`, in the form of keywords, and (u
be used within the body of the `forward` function. Note that unlike typical Lux models, the
forward function doesn't need to explicitly manage states.
Defining the version with `p` allows you to access the parameters in the forward pass. This
is useful when using it with SciML tools which require passing in the parameters explicitly.
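For instance, here is a minimal sketch of the two-argument form (the `dense` sub-layer and the names below are purely illustrative):

```julia
using Lux, Random

model = @compact(; dense=Dense(2 => 2)) do x, p
    # `p` is the full parameter collection, so e.g. `p.dense` can be handed to external
    # tooling (such as an ODEProblem) that expects the parameters explicitly.
    @assert p.dense.weight isa AbstractMatrix
    return dense(x)
end

ps, st = Lux.setup(Xoshiro(0), model)
y, _ = model(ones(Float32, 2), ps, st)
```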
## Reserved Kwargs:
1. `name`: The name of the layer.
2. `dispatch`: The constructed layer has the type
`Lux.Experimental.CompactLuxLayer{dispatch}` which can be used for custom dispatches.
!!! tip
Check the Lux tutorials for more examples of using `@compact`.
If you are passing in kwargs by splatting them, they will be passed as-is to the function
body. This means that if your splatted kwargs contain a Lux layer, it won't be registered
in the `CompactLuxLayer`.
## Examples
Here is a linear model:
@@ -161,32 +175,18 @@ function __compact_macro_impl(_exs...)
kwexs = (kwexs1..., kwexs2...)

# check if user has named layer
name_idx = findfirst(ex -> ex.args[1] == :name, kwexs)
name = nothing
if name_idx !== nothing && kwexs[name_idx].args[2] !== nothing
if length(kwexs) == 1
throw(LuxCompactModelParsingException("expects keyword arguments"))
end
name = kwexs[name_idx].args[2]
# remove name from kwexs (a tuple)
kwexs = (kwexs[1:(name_idx - 1)]..., kwexs[(name_idx + 1):end]...)
end
name, kwexs = __extract_reserved_kwarg(kwexs, :name)

# check if user has provided a custom dispatch
dispatch_idx = findfirst(ex -> ex.args[1] == :dispatch, kwexs)
dispatch = nothing
if dispatch_idx !== nothing && kwexs[dispatch_idx].args[2] !== nothing
if length(kwexs) == 1
throw(LuxCompactModelParsingException("expects keyword arguments"))
end
dispatch = kwexs[dispatch_idx].args[2]
# remove dispatch from kwexs (a tuple)
kwexs = (kwexs[1:(dispatch_idx - 1)]..., kwexs[(dispatch_idx + 1):end]...)
end
dispatch, kwexs = __extract_reserved_kwarg(kwexs, :dispatch)

# Extract splatted kwargs
splat_idxs = findall(ex -> ex.head == :..., kwexs)
splatted_kwargs = map(first ∘ Base.Fix2(getproperty, :args), kwexs[splat_idxs])
kwexs = filter(ex -> ex.head != :..., kwexs)

# make strings
layer = "@compact"
setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
input = try
fex_args = fex.args[1]
isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ")
@@ -198,28 +198,43 @@ function __compact_macro_impl(_exs...)

# edit expressions
vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs)
fex = supportself(fex, vars)
fex = supportself(fex, vars, splatted_kwargs)

display(fex)

# assemble
return esc(:($CompactLuxLayer{$dispatch}(
$fex, $name, ($layer, $input, $block), $setup; $(kwexs...))))
return esc(:($CompactLuxLayer{$dispatch}($fex, $name, ($layer, $input, $block),
(($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...))))
end

function __extract_reserved_kwarg(kwexs, sym::Symbol)
idx = findfirst(ex -> ex.args[1] == sym, kwexs)
val = nothing
if idx !== nothing && kwexs[idx].args[2] !== nothing
length(kwexs) == 1 &&
throw(LuxCompactModelParsingException("expects keyword arguments"))
val = kwexs[idx].args[2]
kwexs = (kwexs[1:(idx - 1)]..., kwexs[(idx + 1):end]...)
end
return val, kwexs
end

function supportself(fex::Expr, vars)
function supportself(fex::Expr, vars, splatted_kwargs)
@gensym self ps st curried_f res
# To avoid having to manipulate fex's arguments and body explicitly, we split the input
# function body and add the required arguments to the function definition.
sdef = splitdef(fex)
if length(sdef[:args]) != 1
throw(LuxCompactModelParsingException("expects exactly 1 argument"))
end
args = [self, sdef[:args]..., ps, st]
custom_param = length(sdef[:args]) == 2
length(sdef[:args]) > 2 &&
throw(LuxCompactModelParsingException("expects at most 2 arguments"))
args = [self, sdef[:args][1], ps, st]
calls = []
for var in vars
push!(calls,
:($var = $(__maybe_make_stateful)($(_getproperty)($self, $(Val(var))),
$(_getproperty)($ps, $(Val(var))), $(_getproperty)($st, $(Val(var))))))
end
custom_param && push!(calls, :($(sdef[:args][2]) = $ps))
body = Expr(:let, Expr(:block, calls...), sdef[:body])
sdef[:body] = body
sdef[:args] = args
@@ -229,7 +244,7 @@ end
@inline function __maybe_make_stateful(layer::AbstractExplicitLayer, ps, st)
return StatefulLuxLayer(layer, ps, st)
end
@inline __maybe_make_stateful(::Nothing, ps, st) = ps === nothing ? st : ps
@inline __maybe_make_stateful(::Nothing, ps, st) = ifelse(ps === nothing, st, ps)
@inline function __maybe_make_stateful(model::Union{AbstractVector, Tuple}, ps, st)
return map(i -> __maybe_make_stateful(model[i], ps[i], st[i]), eachindex(model))
end
@@ -248,13 +263,13 @@ end
function ValueStorage(; kwargs...)
ps_init_fns, st_init_fns = [], []
for (key, val) in pairs(kwargs)
push!(val isa AbstractArray ? ps_init_fns : st_init_fns, key => () -> val)
push!(val isa AbstractArray ? ps_init_fns : st_init_fns, key => Returns(val))
end
return ValueStorage(NamedTuple(ps_init_fns), NamedTuple(st_init_fns))
end

function (v::ValueStorage)(x, ps, st)
throw(ArgumentError("ValueStorage isn't meant to be used as a layer!!!"))
throw(ArgumentError("`ValueStorage` isn't meant to be used as a layer!!!"))
end

function initialparameters(::AbstractRNG, v::ValueStorage)
@@ -273,6 +288,7 @@ end
setup_strings
layers
value_storage
stored_kwargs
end

function ConstructionBase.constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch}
Expand Down Expand Up @@ -300,27 +316,48 @@ function __try_make_lux_layer(x)
return fmap(__maybe_convert_layer, x)
end

function CompactLuxLayer{dispatch}(f::Function, name::NAME_TYPE, str::Tuple,
setup_str::NamedTuple; kws...) where {dispatch}
function CompactLuxLayer{dispatch}(
f::F, name::NAME_TYPE, str::Tuple, splatted_kwargs; kws...) where {F, dispatch}
layers, others = [], []
setup_strings = NamedTuple()
for (name, val) in pairs(kws)
is_lux_layer = false
if val isa AbstractExplicitLayer
is_lux_layer = true
push!(layers, name => val)
elseif LuxCore.contains_lux_layer(val)
# TODO: Rearrange Tuple and Vectors to NamedTuples for proper CA.jl support
# FIXME: This might lead to incorrect constructions? If the function is a closure over the provided keyword arguments?
# FIXME: This might lead to incorrect constructions? If the function is a
# closure over the provided keyword arguments?
val = __try_make_lux_layer(val)
if LuxCore.check_fmap_condition(
!Base.Fix2(isa, AbstractExplicitLayer), nothing, val)
throw(LuxCompactModelParsingException("A container `$(name) = $(val)` is found which combines Lux layers with non-Lux layers. This is not supported."))
throw(LuxCompactModelParsingException("A container `$(name) = $(val)` is \
found which combines Lux layers \
with non-Lux layers. This is not \
supported."))
end
is_lux_layer = true
push!(layers, name => val)
else
push!(others, name => val)
end

if is_lux_layer
setup_strings = merge(setup_strings, NamedTuple((name => val,)))
else
setup_strings = merge(setup_strings,
NamedTuple((name => sprint(
show, val; context=(:compact => true, :limit => true)),)))
end
end
return CompactLuxLayer{dispatch}(
f, name, str, setup_str, NamedTuple((; layers...)), ValueStorage(; others...))

for (kw_name, kw_val) in zip(splatted_kwargs[1], splatted_kwargs[2])
push!(others, kw_name => kw_val)
end

return CompactLuxLayer{dispatch}(f, name, str, setup_strings, NamedTuple((; layers...)),
ValueStorage(; others...), nothing)
end

function (m::CompactLuxLayer)(x, ps, st::NamedTuple{fields}) where {fields}