From b3e9a5f06c096bb9ed4d7a266e866e48b32c8b3e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Apr 2024 11:16:57 -0400 Subject: [PATCH] Move compact out of experimental --- docs/src/api/Lux/contrib.md | 16 +++++++++++----- docs/src/api/Lux/utilities.md | 5 +++++ docs/src/introduction/index.md | 7 +++---- src/Lux.jl | 9 ++++++++- src/contrib/contrib.jl | 8 ++------ src/contrib/{stateful.jl => deprecated.jl} | 10 +++++++++- src/deprecated.jl | 8 ++++++-- src/{contrib => helpers}/compact.jl | 22 +++++++++++++--------- 8 files changed, 57 insertions(+), 28 deletions(-) rename src/contrib/{stateful.jl => deprecated.jl} (53%) rename src/{contrib => helpers}/compact.jl (95%) diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md index b4cd1ec24..bf9c9c238 100644 --- a/docs/src/api/Lux/contrib.md +++ b/docs/src/api/Lux/contrib.md @@ -6,8 +6,8 @@ CurrentModule = Lux All features listed on this page are **experimental** which means: -1. No SemVer Guarantees. We use code here to iterate fast and most users should wait for - these features to be marked non-experimental. +1. No SemVer Guarantees. We use code here to iterate fast. That said, historically we have + never broken any code in this module and have always provided a deprecation period. 2. Expect edge-cases and report them. It will help us move these features out of experimental sooner. 3. None of the features are exported. @@ -74,8 +74,14 @@ Lux.Experimental.DebugLayer Lux.Experimental.share_parameters ``` +## StatefulLuxLayer + +[`Lux.StatefulLuxLayer`](@ref) used to be part of experimental features, but has been +promoted to stable API. It is now available via `Lux.StatefulLuxLayer`. Change all uses of +`Lux.Experimental.StatefulLuxLayer` to `Lux.StatefulLuxLayer`. + ## Compact Layer API -```@docs -Lux.Experimental.@compact -``` +[`Lux.@compact`](@ref) used to be part of experimental features, but has been promoted to +stable API. It is now available via `Lux.@compact`. Change all uses of +`Lux.Experimental.@compact` to `Lux.@compact`. diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index acfd4e245..0c8aa7dc5 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -54,6 +54,11 @@ Lux.f64 StatefulLuxLayer ``` +## Compact Layer + +```@docs +@compact +``` ## Truncated Stacktraces diff --git a/docs/src/introduction/index.md b/docs/src/introduction/index.md index a17d6ebd1..59e54141a 100644 --- a/docs/src/introduction/index.md +++ b/docs/src/introduction/index.md @@ -48,13 +48,13 @@ standard AD and Optimisers API. ```@example quickstart # Get the device determined by Lux -device = gpu_device() +dev = gpu_device() # Parameter and State Variables -ps, st = Lux.setup(rng, model) .|> device +ps, st = Lux.setup(rng, model) .|> dev # Dummy Input -x = rand(rng, Float32, 128, 2) |> device +x = rand(rng, Float32, 128, 2) |> dev # Run the model y, st = Lux.apply(model, x, ps, st) @@ -74,7 +74,6 @@ st_opt, ps = Optimisers.update(st_opt, ps, gs) ```@example custom_compact using Lux, Random, Optimisers, Zygote # using LuxCUDA, LuxAMDGPU, Metal # Optional packages for GPU support -import Lux.Experimental: @compact using Printf # For pretty printing ``` diff --git a/src/Lux.jl b/src/Lux.jl index e7ab1645d..69cabe120 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -26,6 +26,10 @@ using PrecompileTools: @recompile_invalidations inputsize, outputsize, update_state, trainmode, testmode, setup, apply, display_name, replicate using LuxDeviceUtils: get_device + + # @compact specific + using MacroTools: block, combinedef, splitdef + using ConstructionBase: ConstructionBase end @reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers @@ -56,6 +60,7 @@ include("contrib/contrib.jl") # Helpful Functionalities include("helpers/stateful.jl") +include("helpers/compact.jl") # Transform to and from other frameworks include("transform/types.jl") @@ -70,7 +75,8 @@ include("distributed/public_api.jl") include("deprecated.jl") # Layers -export cpu, gpu +export cpu, gpu # deprecated + export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer export Bilinear, Dense, Embedding, Scale export Conv, ConvTranspose, CrossCor, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, @@ -83,6 +89,7 @@ export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell export SamePad, TimeLastIndex, BatchLastIndex export StatefulLuxLayer +export @compact, CompactLuxLayer export f16, f32, f64 diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index b72d545f3..ea65a653b 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -3,15 +3,12 @@ module Experimental import ..Lux using ..Lux, LuxCore, LuxDeviceUtils, Random using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer -import ..Lux: _merge, _pairs, initialstates, initialparameters, apply, NAME_TYPE, - _getproperty +import ..Lux: _merge, _pairs, initialstates, initialparameters, apply using ADTypes: ADTypes import ChainRulesCore as CRC using ConcreteStructs: @concrete -import ConstructionBase: constructorof using Functors: Functors, fmap, functor -using MacroTools: block, combinedef, splitdef using Markdown: @doc_str using Random: AbstractRNG, Random using Setfield: Setfield @@ -21,8 +18,7 @@ include("training.jl") include("freeze.jl") include("share_parameters.jl") include("debug.jl") -include("stateful.jl") -include("compact.jl") +include("deprecated.jl") end diff --git a/src/contrib/stateful.jl b/src/contrib/deprecated.jl similarity index 53% rename from src/contrib/stateful.jl rename to src/contrib/deprecated.jl index 2ba8293c0..6496b6ddd 100644 --- a/src/contrib/stateful.jl +++ b/src/contrib/deprecated.jl @@ -1,4 +1,12 @@ -# Deprecated +macro compact(exs...) + Base.depwarn( + "Lux.Experimental.@compact` has been promoted out of `Lux.Experimental` and is now \ + available in `Lux`. In other words this has been deprecated and will be removed \ + in v0.6. Use `Lux.@compact` instead.", + Symbol("@compact")) + return Lux.__compact_macro_impl(exs...) +end + function StatefulLuxLayer(args...; kwargs...) Base.depwarn( "Lux.Experimental.StatefulLuxLayer` has been promoted out of `Lux.Experimental` \ diff --git a/src/deprecated.jl b/src/deprecated.jl index 4507de522..b5cb0aedd 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -4,7 +4,7 @@ Transfer `x` to CPU. -!!! warning +!!! danger This function has been deprecated. Use [`cpu_device`](@ref) instead. """ @@ -19,7 +19,7 @@ end Transfer `x` to GPU determined by the backend set using [`Lux.gpu_backend!`](@ref). -!!! warning +!!! danger This function has been deprecated. Use [`gpu_device`](@ref) instead. Using this function inside performance critical code will cause massive slowdowns due to type inference @@ -41,6 +41,10 @@ end An easy way to update `TruncatedStacktraces.VERBOSE` without having to load it manually. Effectively does `TruncatedStacktraces.VERBOSE[] = disable` + +!!! danger + + This function is now deprecated and will be removed in v0.6. """ function disable_stacktrace_truncation!(; disable::Bool=true) Base.depwarn("`disable_stacktrace_truncation!` is not needed anymore, as \ diff --git a/src/contrib/compact.jl b/src/helpers/compact.jl similarity index 95% rename from src/contrib/compact.jl rename to src/helpers/compact.jl index 4993e51e3..28ae3c2ed 100644 --- a/src/contrib/compact.jl +++ b/src/helpers/compact.jl @@ -33,7 +33,6 @@ Here is a linear model: ```julia using Lux, Random -import Lux.Experimental: @compact r = @compact(w=rand(3)) do x return w .* x @@ -123,6 +122,11 @@ used inside a `Chain`. account for the total number of parameters printed at the bottom. """ macro compact(_exs...) + return __compact_macro_impl(_exs...) +end + +# Needed for the deprecation path +function __compact_macro_impl(_exs...) # check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs if isempty(_exs) msg = "expects at least two expressions: a function and at least one keyword" @@ -187,12 +191,13 @@ macro compact(_exs...) fex_args = fex.args[1] isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ") catch e - @warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments" + @warn "Function stringifying does not yet handle all cases. Falling back to empty \ + string for input arguments" end block = string(Base.remove_linenums!(fex).args[2]) # edit expressions - vars = map(ex -> ex.args[1], kwexs) + vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs) fex = supportself(fex, vars) # assemble @@ -212,9 +217,8 @@ function supportself(fex::Expr, vars) calls = [] for var in vars push!(calls, - :($var = Lux.Experimental.__maybe_make_stateful( - Lux._getproperty($self, $(Val(var))), - Lux._getproperty($ps, $(Val(var))), Lux._getproperty($st, $(Val(var)))))) + :($var = $(__maybe_make_stateful)($(_getproperty)($self, $(Val(var))), + $(_getproperty)($ps, $(Val(var))), $(_getproperty)($st, $(Val(var)))))) end body = Expr(:let, Expr(:block, calls...), sdef[:body]) sdef[:body] = body @@ -223,7 +227,7 @@ function supportself(fex::Expr, vars) end @inline function __maybe_make_stateful(layer::AbstractExplicitLayer, ps, st) - return Lux.StatefulLuxLayer(layer, ps, st) + return StatefulLuxLayer(layer, ps, st) end @inline __maybe_make_stateful(::Nothing, ps, st) = ps === nothing ? st : ps @inline function __maybe_make_stateful(model::Union{AbstractVector, Tuple}, ps, st) @@ -271,7 +275,7 @@ end value_storage end -function constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch} +function ConstructionBase.constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch} return CompactLuxLayer{dispatch} end @@ -288,7 +292,7 @@ function __try_make_lux_layer(x::Union{AbstractVector, Tuple}) return __try_make_lux_layer(NamedTuple{Tuple(Symbol.(1:length(x)))}(x)) end function __try_make_lux_layer(x) - function __maybe_convert_layer(l) + __maybe_convert_layer = @closure l -> begin l isa AbstractExplicitLayer && return l l isa Function && return WrappedFunction(l) return l