diff --git a/.github/workflows/SpellCheck.yml b/.github/workflows/SpellCheck.yml new file mode 100644 index 000000000..ed4fe1779 --- /dev/null +++ b/.github/workflows/SpellCheck.yml @@ -0,0 +1,13 @@ +name: Spell Check + +on: [pull_request] + +jobs: + typos-check: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - name: Checkout Actions Repository + uses: actions/checkout@v4 + - name: Check spelling + uses: crate-ci/typos@v1.18.0 diff --git a/.typos.toml b/.typos.toml new file mode 100644 index 000000000..e2b3e6f9a --- /dev/null +++ b/.typos.toml @@ -0,0 +1,2 @@ +[default.extend-words] +numer = "numer" \ No newline at end of file diff --git a/Project.toml b/Project.toml index f31e75d73..f2231f037 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.34" +version = "0.5.35" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md index b4cd1ec24..bf9c9c238 100644 --- a/docs/src/api/Lux/contrib.md +++ b/docs/src/api/Lux/contrib.md @@ -6,8 +6,8 @@ CurrentModule = Lux All features listed on this page are **experimental** which means: -1. No SemVer Guarantees. We use code here to iterate fast and most users should wait for - these features to be marked non-experimental. +1. No SemVer Guarantees. We use code here to iterate fast. That said, historically we have + never broken any code in this module and have always provided a deprecation period. 2. Expect edge-cases and report them. It will help us move these features out of experimental sooner. 3. None of the features are exported. @@ -74,8 +74,14 @@ Lux.Experimental.DebugLayer Lux.Experimental.share_parameters ``` +## StatefulLuxLayer + +[`Lux.StatefulLuxLayer`](@ref) used to be part of experimental features, but has been +promoted to stable API. It is now available via `Lux.StatefulLuxLayer`. Change all uses of +`Lux.Experimental.StatefulLuxLayer` to `Lux.StatefulLuxLayer`. + ## Compact Layer API -```@docs -Lux.Experimental.@compact -``` +[`Lux.@compact`](@ref) used to be part of experimental features, but has been promoted to +stable API. It is now available via `Lux.@compact`. Change all uses of +`Lux.Experimental.@compact` to `Lux.@compact`. diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index acfd4e245..0c8aa7dc5 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -54,6 +54,11 @@ Lux.f64 StatefulLuxLayer ``` +## Compact Layer + +```@docs +@compact +``` ## Truncated Stacktraces diff --git a/docs/src/introduction/index.md b/docs/src/introduction/index.md index a17d6ebd1..bb8e36d44 100644 --- a/docs/src/introduction/index.md +++ b/docs/src/introduction/index.md @@ -48,21 +48,21 @@ standard AD and Optimisers API. 
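A note on the `docs/src/api/Lux/contrib.md` change above: the page now tells users to swap `Lux.Experimental.StatefulLuxLayer` and `Lux.Experimental.@compact` for their stable counterparts, but does not show the change itself. The following is a minimal, hypothetical migration sketch (the `Dense` model and RNG are placeholders, not part of this diff); only the module prefix changes, the call signatures stay the same.

```julia
using Lux, Random

model = Dense(2 => 3)
ps, st = Lux.setup(Xoshiro(0), model)

# Deprecated: still works but emits a depwarn (see src/contrib/deprecated.jl)
# and is slated for removal in v0.6.
smodel_old = Lux.Experimental.StatefulLuxLayer(model, ps, st)

# Stable replacement promoted in this release:
smodel = Lux.StatefulLuxLayer(model, ps, st)
y = smodel(rand(Float32, 2, 4), ps)   # updated states are tracked in `smodel.st`

# Likewise, `Lux.Experimental.@compact` simply becomes `Lux.@compact`.
```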
```@example quickstart # Get the device determined by Lux -device = gpu_device() +dev = gpu_device() # Parameter and State Variables -ps, st = Lux.setup(rng, model) .|> device +ps, st = Lux.setup(rng, model) .|> dev # Dummy Input -x = rand(rng, Float32, 128, 2) |> device +x = rand(rng, Float32, 128, 2) |> dev # Run the model y, st = Lux.apply(model, x, ps, st) # Gradients ## Pullback API to capture change in state -(l, st_), pb = pullback(p -> Lux.apply(model, x, p, st), ps) -gs = pb((one.(l), nothing))[1] +(l, st_), pb = pullback(Lux.apply, model, x, ps, st) +gs = pb((one.(l), nothing))[3] # Optimization st_opt = Optimisers.setup(Adam(0.0001f0), ps) @@ -74,7 +74,6 @@ st_opt, ps = Optimisers.update(st_opt, ps, gs) ```@example custom_compact using Lux, Random, Optimisers, Zygote # using LuxCUDA, LuxAMDGPU, Metal # Optional packages for GPU support -import Lux.Experimental: @compact using Printf # For pretty printing ``` diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index ed916e9fc..948c9a0ad 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -20,6 +20,13 @@ First let's set the expectations straight. functionality in the core library (and officially supported ones) **must** adhere to the interface +!!! tip + + While writing out a custom struct and defining dispatches manually is a good way to + understand the interface, it is not the most concise way. We recommend using the + [`Lux.@compact`](@ref) macro to define layers which makes handling the states and + parameters downright trivial. + ## Layer Interface ### Singular Layer @@ -35,8 +42,8 @@ architecture cannot change. !!! tip - For people coming from Flux.jl background this might be weird. We recommend checking out - [the Flux to Lux migration guide](@ref migrate-from-flux) first before proceeding. + For people coming from Flux.jl background, this might be weird. We recommend checking + out [the Flux to Lux migration guide](@ref migrate-from-flux) first before proceeding. ```@example layer_interface using Lux, Random @@ -80,7 +87,7 @@ reconstruction of the parameters and states. println("Parameter Length: ", Lux.parameterlength(l), "; State Length: ", Lux.statelength(l)) -# But still recommened to define these +# But still recommended to define these Lux.parameterlength(l::Linear) = l.out_dims * l.in_dims + l.out_dims Lux.statelength(::Linear) = 0 diff --git a/docs/src/manual/migrate_from_flux.md b/docs/src/manual/migrate_from_flux.md index 36acda7a7..6a3248257 100644 --- a/docs/src/manual/migrate_from_flux.md +++ b/docs/src/manual/migrate_from_flux.md @@ -99,7 +99,7 @@ end # `A` is not trainable Optimisers.trainable(f::FluxLinear) = (B=f.B,) -# Needed so that both `A` and `B` can be transfered between devices +# Needed so that both `A` and `B` can be transferred between devices Flux.@functor FluxLinear (l::FluxLinear)(x) = l.A * l.B * x diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index 36b4afb4f..25a73ad55 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -59,7 +59,7 @@ const advanced = [ } ]; -const thrid_party = [ +const third_party = [ { href: "https://docs.sciml.ai/Overview/stable/showcase/pinngpu/", src: "../pinn.gif", @@ -114,7 +114,7 @@ of them are non-functional and we will try to get them updated. 
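One clarification on the quickstart hunk above: the gradient index changes from `[1]` to `[3]` because the pullback is now taken directly through `Lux.apply`, so the returned cotangents follow the argument order `(model, x, ps, st)`. A short sketch, reusing the quickstart's `model`, `x`, `ps`, and `st`, spells out the correspondence (the closure form mirrors the removed lines and is not part of the new code):

```julia
# Closure form: `ps` is the only differentiated argument, so its cotangent
# is the first (and only) entry returned by the pullback.
(l, st_), pb_closure = pullback(p -> Lux.apply(model, x, p, st), ps)
gs_closure = pb_closure((one.(l), nothing))[1]

# Direct form (as in the updated quickstart): cotangents come back in the
# order (model, x, ps, st), so the gradient w.r.t. `ps` is entry 3.
(l, st_), pb_direct = pullback(Lux.apply, model, x, ps, st)
gs_direct = pb_direct((one.(l), nothing))[3]   # gs_closure matches gs_direct
```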
::: - + ::: tip diff --git a/examples/DDIM/README.md b/examples/DDIM/README.md index d5bb55040..6e3dd073f 100644 --- a/examples/DDIM/README.md +++ b/examples/DDIM/README.md @@ -11,7 +11,7 @@ The model generates images from Gaussian noises by denoising iterativel # Usage Install Julia and instantiate `Project.toml`. -Follwoing scripts are tested on a single NVIDIA Tesla T4 instance. +Following scripts are tested on a single NVIDIA Tesla T4 instance. ## Dataset Download and extract `Dataset images` from [102 Category Flower Dataset](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/). diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl index 63e6fc37f..7d63a5ce9 100644 --- a/examples/GravitationalWaveForm/main.jl +++ b/examples/GravitationalWaveForm/main.jl @@ -163,7 +163,7 @@ function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where { m₁ = mass_ratio * m₂ orbit₁, orbit₂ = one2two(orbit, m₁, m₂) - waveform = h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2) + waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂) else waveform = h_22_strain_one_body(dt, orbit) end diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index b6b4401f5..824f02ab0 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -31,33 +31,25 @@ function load_datasets(n_train=1024, n_eval=32, batchsize=256) end # ## Implement a HyperNet Layer -struct HyperNet{W <: Lux.AbstractExplicitLayer, C <: Lux.AbstractExplicitLayer, A} <: - Lux.AbstractExplicitContainerLayer{(:weight_generator, :core_network)} - weight_generator::W - core_network::C - ca_axes::A -end - -function HyperNet(w::Lux.AbstractExplicitLayer, c::Lux.AbstractExplicitLayer) - ca_axes = Lux.initialparameters(Random.default_rng(), c) |> ComponentArray |> getaxes - return HyperNet(w, c, ca_axes) -end - -function Lux.initialparameters(rng::AbstractRNG, h::HyperNet) - return (weight_generator=Lux.initialparameters(rng, h.weight_generator),) +function HyperNet(weight_generator::Lux.AbstractExplicitLayer, + core_network::Lux.AbstractExplicitLayer) + ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |> + ComponentArray |> + getaxes + return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y) + ## Generate the weights + ps_new = ComponentArray(vec(weight_generator(x)), ca_axes) + return core_network(y, ps_new) + end end -function (hn::HyperNet)(x, ps, st::NamedTuple) - ps_new, st_ = hn.weight_generator(x, ps.weight_generator, st.weight_generator) - @set! st.weight_generator = st_ - return ComponentArray(vec(ps_new), hn.ca_axes), st -end +# Defining functions on the CompactLuxLayer requires some understanding of how the layer +# is structured, as such we don't recommend doing it unless you are familiar with the +# internals. In this case, we simply write it to ignore the initialization of the +# `core_network` parameters. -function (hn::HyperNet)((x, y)::T, ps, st::NamedTuple) where {T <: Tuple} - ps_ca, st = hn(x, ps, st) - pred, st_ = hn.core_network(y, ps_ca, st.core_network) - @set! 
st.core_network = st_ - return pred, st +function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet}) + return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),) end # ## Create and Initialize the HyperNet diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 57cce80dd..2901534a9 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -42,7 +42,23 @@ function loadmnist(batchsize, train_split) end # ## Define the Neural ODE Layer -# +# +# First we will use the [`@compact`](@ref) macro to define the Neural ODE Layer. + +function NeuralODECompact( + model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) + return @compact(; model, solver, tspan, kwargs...) do x, p + dudt(u, p, t) = vec(model(reshape(u, size(x)), p)) + ## Note the `p.model` here + prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model) + return solve(prob, solver; kwargs...) + end +end + +# We recommend using the compact macro for creating custom layers. The below implementation +# exists mostly for historical reasons when `@compact` was not part of the stable API. Also, +# it helps users understand how the layer interface of Lux works. + # The NeuralODE is a ContainerLayer, which stores a `model`. The parameters and states of # the NeuralODE are same as those of the underlying model. struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: @@ -154,6 +170,8 @@ function train(model_function; cpu::Bool=false, kwargs...) end end +train(NeuralODECompact) + train(NeuralODE) # We can also change the sensealg and train the model! `GaussAdjoint` allows you to use @@ -173,8 +191,9 @@ train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true) # ## Alternate Implementation using Stateful Layer -# Starting `v0.5.5`, Lux provides a `Lux.Experimental.StatefulLuxLayer` which can be used -# to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276). +# Starting `v0.5.5`, Lux provides a [`StatefulLuxLayer`](@ref) which can be used +# to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276). Using +# the `@compact` API avoids this problem entirely. struct StatefulNeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: Lux.AbstractExplicitContainerLayer{(:model,)} model::M @@ -189,7 +208,7 @@ function StatefulNeuralODE( end function (n::StatefulNeuralODE)(x, ps, st) - st_model = Lux.StatefulLuxLayer(n.model, ps, st) + st_model = StatefulLuxLayer(n.model, ps, st) dudt(u, p, t) = st_model(u, p) prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps) return solve(prob, n.solver; n.kwargs...), st_model.st @@ -219,3 +238,9 @@ x = gpu_device()(ones(Float32, 28, 28, 1, 3)); # Note, that we still recommend using this layer internally and not exposing this as the # default API to the users. + +# Finally checking the compact model + +model_compact, ps_compact, st_compact = create_model(NeuralODECompact) + +@code_warntype model_compact(x, ps_compact, st_compact) diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index 116edd8cc..0f1791408 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -99,6 +99,25 @@ function (s::SpiralClassifier)( return vec(y), st end +# ## Using the `@compact` API + +# We can also define the model using the [`Lux.@compact`](@ref) API, which is a more concise +# way of defining models. 
This macro automatically handles the boilerplate code for you and +# as such we recommend this way of defining custom layers + +function SpiralClassifierCompact(in_dims, hidden_dims, out_dims) + lstm_cell = LSTMCell(in_dims => hidden_dims) + classifier = Dense(hidden_dims => out_dims, sigmoid) + return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T} + x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2))) + y, carry = lstm_cell(x_init) + for x in x_rest + y, carry = lstm_cell((x, carry)) + end + return vec(classifier(y)) + end +end + # ## Defining Accuracy, Loss and Optimiser # Now let's define the binarycrossentropy loss. Typically it is recommended to use @@ -125,12 +144,12 @@ accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred) # ## Training the Model -function main() +function main(model_type) ## Get the dataloaders (train_loader, val_loader) = get_dataloaders() ## Create the model - model = SpiralClassifier(2, 8, 1) + model = model_type(2, 8, 1) rng = Xoshiro(0) dev = gpu_device() @@ -164,7 +183,12 @@ function main() return (train_state.parameters, train_state.states) |> cpu_device() end -ps_trained, st_trained = main() +ps_trained, st_trained = main(SpiralClassifier) +nothing #hide + +# We can also train the compact model with the exact same code! + +ps_trained2, st_trained2 = main(SpiralClassifierCompact) nothing #hide # ## Saving the Model diff --git a/src/Lux.jl b/src/Lux.jl index e7ab1645d..69cabe120 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -26,6 +26,10 @@ using PrecompileTools: @recompile_invalidations inputsize, outputsize, update_state, trainmode, testmode, setup, apply, display_name, replicate using LuxDeviceUtils: get_device + + # @compact specific + using MacroTools: block, combinedef, splitdef + using ConstructionBase: ConstructionBase end @reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers @@ -56,6 +60,7 @@ include("contrib/contrib.jl") # Helpful Functionalities include("helpers/stateful.jl") +include("helpers/compact.jl") # Transform to and from other frameworks include("transform/types.jl") @@ -70,7 +75,8 @@ include("distributed/public_api.jl") include("deprecated.jl") # Layers -export cpu, gpu +export cpu, gpu # deprecated + export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer export Bilinear, Dense, Embedding, Scale export Conv, ConvTranspose, CrossCor, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, @@ -83,6 +89,7 @@ export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell export SamePad, TimeLastIndex, BatchLastIndex export StatefulLuxLayer +export @compact, CompactLuxLayer export f16, f32, f64 diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index b72d545f3..ea65a653b 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -3,15 +3,12 @@ module Experimental import ..Lux using ..Lux, LuxCore, LuxDeviceUtils, Random using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer -import ..Lux: _merge, _pairs, initialstates, initialparameters, apply, NAME_TYPE, - _getproperty +import ..Lux: _merge, _pairs, initialstates, initialparameters, apply using ADTypes: ADTypes import ChainRulesCore as CRC using ConcreteStructs: @concrete -import ConstructionBase: constructorof using Functors: Functors, fmap, functor -using MacroTools: block, combinedef, splitdef using Markdown: @doc_str using Random: AbstractRNG, Random using Setfield: Setfield @@ -21,8 +18,7 @@ include("training.jl") include("freeze.jl") include("share_parameters.jl") 
include("debug.jl") -include("stateful.jl") -include("compact.jl") +include("deprecated.jl") end diff --git a/src/contrib/stateful.jl b/src/contrib/deprecated.jl similarity index 53% rename from src/contrib/stateful.jl rename to src/contrib/deprecated.jl index 2ba8293c0..6496b6ddd 100644 --- a/src/contrib/stateful.jl +++ b/src/contrib/deprecated.jl @@ -1,4 +1,12 @@ -# Deprecated +macro compact(exs...) + Base.depwarn( + "Lux.Experimental.@compact` has been promoted out of `Lux.Experimental` and is now \ + available in `Lux`. In other words this has been deprecated and will be removed \ + in v0.6. Use `Lux.@compact` instead.", + Symbol("@compact")) + return Lux.__compact_macro_impl(exs...) +end + function StatefulLuxLayer(args...; kwargs...) Base.depwarn( "Lux.Experimental.StatefulLuxLayer` has been promoted out of `Lux.Experimental` \ diff --git a/src/deprecated.jl b/src/deprecated.jl index 4507de522..b5cb0aedd 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -4,7 +4,7 @@ Transfer `x` to CPU. -!!! warning +!!! danger This function has been deprecated. Use [`cpu_device`](@ref) instead. """ @@ -19,7 +19,7 @@ end Transfer `x` to GPU determined by the backend set using [`Lux.gpu_backend!`](@ref). -!!! warning +!!! danger This function has been deprecated. Use [`gpu_device`](@ref) instead. Using this function inside performance critical code will cause massive slowdowns due to type inference @@ -41,6 +41,10 @@ end An easy way to update `TruncatedStacktraces.VERBOSE` without having to load it manually. Effectively does `TruncatedStacktraces.VERBOSE[] = disable` + +!!! danger + + This function is now deprecated and will be removed in v0.6. """ function disable_stacktrace_truncation!(; disable::Bool=true) Base.depwarn("`disable_stacktrace_truncation!` is not needed anymore, as \ diff --git a/src/distributed/public_api.jl b/src/distributed/public_api.jl index 945518fd9..ae7eda7c1 100644 --- a/src/distributed/public_api.jl +++ b/src/distributed/public_api.jl @@ -179,7 +179,7 @@ function __reduce! end CRC.@non_differentiable reduce!(::Any...) -# syncronize! +# synchronize! """ synchronize!!(backend::AbstractLuxDistributedBackend, ps; root::Int=0) diff --git a/src/contrib/compact.jl b/src/helpers/compact.jl similarity index 70% rename from src/contrib/compact.jl rename to src/helpers/compact.jl index 4993e51e3..399d80f9e 100644 --- a/src/contrib/compact.jl +++ b/src/helpers/compact.jl @@ -13,6 +13,9 @@ end @compact(kw...) do x ... end + @compact(kw...) do x, p + ... + end @compact(forward::Function; name=nothing, dispatch=nothing, parameters...) Creates a layer by specifying some `parameters`, in the form of keywords, and (usually as a @@ -21,19 +24,29 @@ Creates a layer by specifying some `parameters`, in the form of keywords, and (u be used within the body of the `forward` function. Note that unlike typical Lux models, the forward function doesn't need to explicitly manage states. +Defining the version with `p` allows you to access the parameters in the forward pass. This +is useful when using it with SciML tools which require passing in the parameters explicitly. + ## Reserved Kwargs: 1. `name`: The name of the layer. 2. `dispatch`: The constructed layer has the type `Lux.Experimental.CompactLuxLayer{dispatch}` which can be used for custom dispatches. +!!! tip + + Check the Lux tutorials for more examples of using `@compact`. + +If you are passing in kwargs by splatting them, they will be passed as is to the function +body. 
This means if your splatted kwargs contain a lux layer that won't be registered +in the CompactLuxLayer. + ## Examples Here is a linear model: ```julia using Lux, Random -import Lux.Experimental: @compact r = @compact(w=rand(3)) do x return w .* x @@ -123,6 +136,11 @@ used inside a `Chain`. account for the total number of parameters printed at the bottom. """ macro compact(_exs...) + return __compact_macro_impl(_exs...) +end + +# Needed for the deprecation path +function __compact_macro_impl(_exs...) # check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs if isempty(_exs) msg = "expects at least two expressions: a function and at least one keyword" @@ -157,65 +175,64 @@ macro compact(_exs...) kwexs = (kwexs1..., kwexs2...) # check if user has named layer - name_idx = findfirst(ex -> ex.args[1] == :name, kwexs) - name = nothing - if name_idx !== nothing && kwexs[name_idx].args[2] !== nothing - if length(kwexs) == 1 - throw(LuxCompactModelParsingException("expects keyword arguments")) - end - name = kwexs[name_idx].args[2] - # remove name from kwexs (a tuple) - kwexs = (kwexs[1:(name_idx - 1)]..., kwexs[(name_idx + 1):end]...) - end + name, kwexs = __extract_reserved_kwarg(kwexs, :name) # check if user has provided a custom dispatch - dispatch_idx = findfirst(ex -> ex.args[1] == :dispatch, kwexs) - dispatch = nothing - if dispatch_idx !== nothing && kwexs[dispatch_idx].args[2] !== nothing - if length(kwexs) == 1 - throw(LuxCompactModelParsingException("expects keyword arguments")) - end - dispatch = kwexs[dispatch_idx].args[2] - # remove dispatch from kwexs (a tuple) - kwexs = (kwexs[1:(dispatch_idx - 1)]..., kwexs[(dispatch_idx + 1):end]...) - end + dispatch, kwexs = __extract_reserved_kwarg(kwexs, :dispatch) + + # Extract splatted kwargs + splat_idxs = findall(ex -> ex.head == :..., kwexs) + splatted_kwargs = map(first ∘ Base.Fix2(getproperty, :args), kwexs[splat_idxs]) + kwexs = filter(ex -> ex.head != :..., kwexs) # make strings layer = "@compact" - setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs)) input = try fex_args = fex.args[1] isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ") catch e - @warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments" + @warn "Function stringifying does not yet handle all cases. Falling back to empty \ + string for input arguments" end block = string(Base.remove_linenums!(fex).args[2]) # edit expressions - vars = map(ex -> ex.args[1], kwexs) - fex = supportself(fex, vars) + vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs) + fex = supportself(fex, vars, splatted_kwargs) # assemble - return esc(:($CompactLuxLayer{$dispatch}( - $fex, $name, ($layer, $input, $block), $setup; $(kwexs...)))) + return esc(:($CompactLuxLayer{$dispatch}($fex, $name, ($layer, $input, $block), + (($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...)))) end -function supportself(fex::Expr, vars) +function __extract_reserved_kwarg(kwexs, sym::Symbol) + idx = findfirst(ex -> ex.args[1] == sym, kwexs) + val = nothing + if idx !== nothing && kwexs[idx].args[2] !== nothing + length(kwexs) == 1 && + throw(LuxCompactModelParsingException("expects keyword arguments")) + val = kwexs[idx].args[2] + kwexs = (kwexs[1:(idx - 1)]..., kwexs[(idx + 1):end]...) 
+ end + return val, kwexs +end + +function supportself(fex::Expr, vars, splatted_kwargs) @gensym self ps st curried_f res # To avoid having to manipulate fex's arguments and body explicitly, we split the input # function body and add the required arguments to the function definition. sdef = splitdef(fex) - if length(sdef[:args]) != 1 - throw(LuxCompactModelParsingException("expects exactly 1 argument")) - end - args = [self, sdef[:args]..., ps, st] + custom_param = length(sdef[:args]) == 2 + length(sdef[:args]) > 2 && + throw(LuxCompactModelParsingException("expects at most 2 arguments")) + args = [self, sdef[:args][1], ps, st] calls = [] for var in vars push!(calls, - :($var = Lux.Experimental.__maybe_make_stateful( - Lux._getproperty($self, $(Val(var))), - Lux._getproperty($ps, $(Val(var))), Lux._getproperty($st, $(Val(var)))))) + :($var = $(__maybe_make_stateful)($(_getproperty)($self, $(Val(var))), + $(_getproperty)($ps, $(Val(var))), $(_getproperty)($st, $(Val(var)))))) end + custom_param && push!(calls, :($(sdef[:args][2]) = $ps)) body = Expr(:let, Expr(:block, calls...), sdef[:body]) sdef[:body] = body sdef[:args] = args @@ -223,9 +240,9 @@ function supportself(fex::Expr, vars) end @inline function __maybe_make_stateful(layer::AbstractExplicitLayer, ps, st) - return Lux.StatefulLuxLayer(layer, ps, st) + return StatefulLuxLayer(layer, ps, st) end -@inline __maybe_make_stateful(::Nothing, ps, st) = ps === nothing ? st : ps +@inline __maybe_make_stateful(::Nothing, ps, st) = ifelse(ps === nothing, st, ps) @inline function __maybe_make_stateful(model::Union{AbstractVector, Tuple}, ps, st) return map(i -> __maybe_make_stateful(model[i], ps[i], st[i]), eachindex(model)) end @@ -244,13 +261,13 @@ end function ValueStorage(; kwargs...) ps_init_fns, st_init_fns = [], [] for (key, val) in pairs(kwargs) - push!(val isa AbstractArray ? ps_init_fns : st_init_fns, key => () -> val) + push!(val isa AbstractArray ? ps_init_fns : st_init_fns, key => Returns(val)) end return ValueStorage(NamedTuple(ps_init_fns), NamedTuple(st_init_fns)) end function (v::ValueStorage)(x, ps, st) - throw(ArgumentError("ValueStorage isn't meant to be used as a layer!!!")) + throw(ArgumentError("`ValueStorage` isn't meant to be used as a layer!!!")) end function initialparameters(::AbstractRNG, v::ValueStorage) @@ -269,9 +286,10 @@ end setup_strings layers value_storage + stored_kwargs end -function constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch} +function ConstructionBase.constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch} return CompactLuxLayer{dispatch} end @@ -288,7 +306,7 @@ function __try_make_lux_layer(x::Union{AbstractVector, Tuple}) return __try_make_lux_layer(NamedTuple{Tuple(Symbol.(1:length(x)))}(x)) end function __try_make_lux_layer(x) - function __maybe_convert_layer(l) + __maybe_convert_layer = @closure l -> begin l isa AbstractExplicitLayer && return l l isa Function && return WrappedFunction(l) return l @@ -296,27 +314,47 @@ function __try_make_lux_layer(x) return fmap(__maybe_convert_layer, x) end -function CompactLuxLayer{dispatch}(f::Function, name::NAME_TYPE, str::Tuple, - setup_str::NamedTuple; kws...) where {dispatch} +function CompactLuxLayer{dispatch}( + f::F, name::NAME_TYPE, str::Tuple, splatted_kwargs; kws...) 
where {F, dispatch} layers, others = [], [] + setup_strings = NamedTuple() for (name, val) in pairs(kws) + is_lux_layer = false if val isa AbstractExplicitLayer + is_lux_layer = true push!(layers, name => val) elseif LuxCore.contains_lux_layer(val) # TODO: Rearrange Tuple and Vectors to NamedTuples for proper CA.jl support - # FIXME: This might lead to incorrect constructions? If the function is a closure over the provided keyword arguments? + # FIXME: This might lead to incorrect constructions? If the function is a + # closure over the provided keyword arguments? val = __try_make_lux_layer(val) if LuxCore.check_fmap_condition( !Base.Fix2(isa, AbstractExplicitLayer), nothing, val) - throw(LuxCompactModelParsingException("A container `$(name) = $(val)` is found which combines Lux layers with non-Lux layers. This is not supported.")) + throw(LuxCompactModelParsingException("A container `$(name) = $(val)` is \ + found which combines Lux layers \ + with non-Lux layers. This is not \ + supported.")) end + is_lux_layer = true push!(layers, name => val) else push!(others, name => val) end + + if is_lux_layer + setup_strings = merge(setup_strings, NamedTuple((name => val,))) + else + setup_strings = merge( + setup_strings, NamedTuple((name => __kwarg_descriptor(val),))) + end end - return CompactLuxLayer{dispatch}( - f, name, str, setup_str, NamedTuple((; layers...)), ValueStorage(; others...)) + + for (kw_name, kw_val) in zip(splatted_kwargs[1], splatted_kwargs[2]) + push!(others, kw_name => kw_val) + end + + return CompactLuxLayer{dispatch}(f, name, str, setup_strings, NamedTuple((; layers...)), + ValueStorage(; others...), nothing) end function (m::CompactLuxLayer)(x, ps, st::NamedTuple{fields}) where {fields} @@ -369,3 +407,20 @@ function Lux._big_show(io::IO, obj::CompactLuxLayer, indent::Int=0, name=nothing end return end + +function __kwarg_descriptor(val) + val isa Number && return string(val) + val isa AbstractArray && return sprint(Base.array_summary, val, axes(val)) + val isa Tuple && return "(" * join(map(__kwarg_descriptor, val), ", ") * ")" + if val isa NamedTuple + fields = fieldnames(typeof(val)) + strs = [] + for fname in fields[1:min(length(fields), 3)] + internal_val = getfield(val, fname) + push!(strs, "$fname = $(__kwarg_descriptor(internal_val))") + end + return "@NamedTuple{$(join(strs, ", "))" * (length(fields) > 3 ? ", ..." : "") * "}" + end + val isa Function && return sprint(show, val; context=(:compact => true, :limit => true)) + return lazy"$(nameof(typeof(val)))(...)" +end diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 478468a73..a7384b88e 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -508,7 +508,7 @@ outputsize(c::Chain) = outputsize(c.layers[end]) This contains a number of internal layers, each of which receives the same input. Its output is the elementwise maximum of the the internal layers' outputs. -Maxout over linear dense layers satisfies the univeral approximation theorem. See [1]. +Maxout over linear dense layers satisfies the universal approximation theorem. See [1]. See also [`Parallel`](@ref) to reduce with other operators. diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 38b587c5d..bb22b71d8 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -33,8 +33,8 @@ slice and normalises the input accordingly. - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. 
- + `init_bias`: Controls how the `bias` is initiliazed - + `init_scale`: Controls how the `scale` is initiliazed + + `init_bias`: Controls how the `bias` is initialized + + `init_scale`: Controls how the `scale` is initialized ## Inputs @@ -167,8 +167,8 @@ end - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initiliazed - + `init_scale`: Controls how the `scale` is initiliazed + + `init_bias`: Controls how the `bias` is initialized + + `init_scale`: Controls how the `scale` is initialized ## Inputs @@ -265,8 +265,8 @@ accordingly. - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initiliazed - + `init_scale`: Controls how the `scale` is initiliazed + + `init_bias`: Controls how the `bias` is initialized + + `init_scale`: Controls how the `scale` is initialized ## Inputs @@ -506,8 +506,8 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. - + `init_bias`: Controls how the `bias` is initiliazed - + `init_scale`: Controls how the `scale` is initiliazed + + `init_bias`: Controls how the `bias` is initialized + + `init_scale`: Controls how the `scale` is initialized ## Inputs diff --git a/src/transform/flux.jl b/src/transform/flux.jl index e64d42c7d..430c0652e 100644 --- a/src/transform/flux.jl +++ b/src/transform/flux.jl @@ -5,7 +5,7 @@ Convert a Flux Model to Lux Model. !!! warning - This always ingores the `active` field of some of the Flux layers. This is almost never + This always ignores the `active` field of some of the Flux layers. This is almost never going to be supported. ## Keyword Arguments diff --git a/test/distributed/common_distributedtest.jl b/test/distributed/common_distributedtest.jl index c5a53c5e3..8ac8e5314 100644 --- a/test/distributed/common_distributedtest.jl +++ b/test/distributed/common_distributedtest.jl @@ -19,7 +19,7 @@ nworkers = DistributedUtils.total_workers(backend) @test rank < nworkers # Test the communication primitives -## broacast! +## broadcast! for arrType in (Array, aType) sendbuf = (rank == 0) ? 
arrType(ones(512)) : arrType(zeros(512)) recvbuf = arrType(zeros(512)) diff --git a/test/distributed/synchronize_distributedtest.jl b/test/distributed/synchronize_distributedtest.jl index f29130426..2b49b5b14 100644 --- a/test/distributed/synchronize_distributedtest.jl +++ b/test/distributed/synchronize_distributedtest.jl @@ -80,7 +80,7 @@ gs = DistributedUtils.synchronize!!(backend, gs; root) @test all(gs[1][2] .== 1) @test all(gs[2] .== 1) -# Miscelleneous +# Miscellaneous x = nothing x = DistributedUtils.synchronize!!(backend, x; root) @test x === nothing diff --git a/test/contrib/compact_tests.jl b/test/helpers/compact_tests.jl similarity index 98% rename from test/contrib/compact_tests.jl rename to test/helpers/compact_tests.jl index bc2a97ad5..caa37fe3d 100644 --- a/test/contrib/compact_tests.jl +++ b/test/helpers/compact_tests.jl @@ -1,6 +1,5 @@ @testitem "@compact" setup=[SharedTestSetup] begin using ComponentArrays - import Lux.Experimental: @compact rng = get_stable_rng(12345) @@ -181,7 +180,7 @@ return w(x .* s) end expected_string = """@compact( - x = randn(32), + x = 32-element Vector{Float64}, w = Dense(32 => 32), # 1_056 parameters ) do s return w(x .* s) @@ -198,8 +197,8 @@ end expected_string = """@compact( w1 = Model(32)(), # 1_024 parameters - w2 = randn(32, 32), - w3 = randn(32), + w2 = 32×32 Matrix{Float64}, + w3 = 32-element Vector{Float64}, ) do x return w2 * w1(x) end # Total: 2_080 parameters,
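Finally, besides the move out of `Lux.Experimental`, the reworked `@compact` documented above gains a two-argument `do x, p` form (used by `NeuralODECompact`) in which `p` is bound to the layer's own parameters. A small, self-contained sketch of that form, written for illustration and not taken from this diff:

```julia
using Lux, Random

# `w` becomes a trainable parameter of the layer; inside the body, `p` is the
# layer's parameter collection, so `p.w` and the bound name `w` see the same values.
scale = @compact(; w=ones(Float32, 4)) do x, p
    return p.w .* x
end

ps, st = Lux.setup(Xoshiro(0), scale)
y, _ = scale(2 .* ones(Float32, 4), ps, st)   # y == 2 .* ps.w
```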