Skip to content

Commit

Permalink
Move compact out of experimental
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 12, 2024
1 parent fc591bd commit b3e9a5f
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 28 deletions.
16 changes: 11 additions & 5 deletions docs/src/api/Lux/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ CurrentModule = Lux

All features listed on this page are **experimental** which means:

1. No SemVer Guarantees. We use code here to iterate fast and most users should wait for
these features to be marked non-experimental.
1. No SemVer Guarantees. We use code here to iterate fast. That said, historically we have
never broken any code in this module and have always provided a deprecation period.
2. Expect edge-cases and report them. It will help us move these features out of
experimental sooner.
3. None of the features are exported.
Expand Down Expand Up @@ -74,8 +74,14 @@ Lux.Experimental.DebugLayer
Lux.Experimental.share_parameters
```

## StatefulLuxLayer

[`Lux.StatefulLuxLayer`](@ref) used to be part of experimental features, but has been
promoted to stable API. It is now available via `Lux.StatefulLuxLayer`. Change all uses of
`Lux.Experimental.StatefulLuxLayer` to `Lux.StatefulLuxLayer`.

## Compact Layer API

```@docs
Lux.Experimental.@compact
```
[`Lux.@compact`](@ref) used to be part of experimental features, but has been promoted to
stable API. It is now available via `Lux.@compact`. Change all uses of
`Lux.Experimental.@compact` to `Lux.@compact`.
5 changes: 5 additions & 0 deletions docs/src/api/Lux/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ Lux.f64
StatefulLuxLayer
```

## Compact Layer

```@docs
@compact
```

## Truncated Stacktraces

Expand Down
7 changes: 3 additions & 4 deletions docs/src/introduction/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ standard AD and Optimisers API.

```@example quickstart
# Get the device determined by Lux
device = gpu_device()
dev = gpu_device()
# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> device
ps, st = Lux.setup(rng, model) .|> dev
# Dummy Input
x = rand(rng, Float32, 128, 2) |> device
x = rand(rng, Float32, 128, 2) |> dev
# Run the model
y, st = Lux.apply(model, x, ps, st)
Expand All @@ -74,7 +74,6 @@ st_opt, ps = Optimisers.update(st_opt, ps, gs)
```@example custom_compact
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, LuxAMDGPU, Metal # Optional packages for GPU support
import Lux.Experimental: @compact
using Printf # For pretty printing
```

Expand Down
9 changes: 8 additions & 1 deletion src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ using PrecompileTools: @recompile_invalidations
inputsize, outputsize, update_state, trainmode, testmode, setup, apply,
display_name, replicate
using LuxDeviceUtils: get_device

# @compact specific
using MacroTools: block, combinedef, splitdef
using ConstructionBase: ConstructionBase
end

@reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers
Expand Down Expand Up @@ -56,6 +60,7 @@ include("contrib/contrib.jl")

# Helpful Functionalities
include("helpers/stateful.jl")
include("helpers/compact.jl")

# Transform to and from other frameworks
include("transform/types.jl")
Expand All @@ -70,7 +75,8 @@ include("distributed/public_api.jl")
include("deprecated.jl")

# Layers
export cpu, gpu
export cpu, gpu # deprecated

export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer
export Bilinear, Dense, Embedding, Scale
export Conv, ConvTranspose, CrossCor, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool,
Expand All @@ -83,6 +89,7 @@ export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell
export SamePad, TimeLastIndex, BatchLastIndex

export StatefulLuxLayer
export @compact, CompactLuxLayer

export f16, f32, f64

Expand Down
8 changes: 2 additions & 6 deletions src/contrib/contrib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@ module Experimental
import ..Lux
using ..Lux, LuxCore, LuxDeviceUtils, Random
using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer
import ..Lux: _merge, _pairs, initialstates, initialparameters, apply, NAME_TYPE,
_getproperty
import ..Lux: _merge, _pairs, initialstates, initialparameters, apply

using ADTypes: ADTypes
import ChainRulesCore as CRC
using ConcreteStructs: @concrete
import ConstructionBase: constructorof
using Functors: Functors, fmap, functor
using MacroTools: block, combinedef, splitdef
using Markdown: @doc_str
using Random: AbstractRNG, Random
using Setfield: Setfield
Expand All @@ -21,8 +18,7 @@ include("training.jl")
include("freeze.jl")
include("share_parameters.jl")
include("debug.jl")
include("stateful.jl")
include("compact.jl")
include("deprecated.jl")

end

Expand Down
10 changes: 9 additions & 1 deletion src/contrib/stateful.jl → src/contrib/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# Deprecated
macro compact(exs...)
Base.depwarn(
"Lux.Experimental.@compact` has been promoted out of `Lux.Experimental` and is now \
available in `Lux`. In other words this has been deprecated and will be removed \
in v0.6. Use `Lux.@compact` instead.",
Symbol("@compact"))
return Lux.__compact_macro_impl(exs...)
end

function StatefulLuxLayer(args...; kwargs...)
Base.depwarn(
"Lux.Experimental.StatefulLuxLayer` has been promoted out of `Lux.Experimental` \
Expand Down
8 changes: 6 additions & 2 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Transfer `x` to CPU.
!!! warning
!!! danger
This function has been deprecated. Use [`cpu_device`](@ref) instead.
"""
Expand All @@ -19,7 +19,7 @@ end
Transfer `x` to GPU determined by the backend set using [`Lux.gpu_backend!`](@ref).
!!! warning
!!! danger
This function has been deprecated. Use [`gpu_device`](@ref) instead. Using this function
inside performance critical code will cause massive slowdowns due to type inference
Expand All @@ -41,6 +41,10 @@ end
An easy way to update `TruncatedStacktraces.VERBOSE` without having to load it manually.
Effectively does `TruncatedStacktraces.VERBOSE[] = disable`
!!! danger
This function is now deprecated and will be removed in v0.6.
"""
function disable_stacktrace_truncation!(; disable::Bool=true)
Base.depwarn("`disable_stacktrace_truncation!` is not needed anymore, as \
Expand Down
22 changes: 13 additions & 9 deletions src/contrib/compact.jl → src/helpers/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ Here is a linear model:
```julia
using Lux, Random
import Lux.Experimental: @compact
r = @compact(w=rand(3)) do x
return w .* x
Expand Down Expand Up @@ -123,6 +122,11 @@ used inside a `Chain`.
account for the total number of parameters printed at the bottom.
"""
macro compact(_exs...)
return __compact_macro_impl(_exs...)
end

# Needed for the deprecation path
function __compact_macro_impl(_exs...)
# check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
if isempty(_exs)
msg = "expects at least two expressions: a function and at least one keyword"
Expand Down Expand Up @@ -187,12 +191,13 @@ macro compact(_exs...)
fex_args = fex.args[1]
isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ")
catch e
@warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments"
@warn "Function stringifying does not yet handle all cases. Falling back to empty \
string for input arguments"
end
block = string(Base.remove_linenums!(fex).args[2])

# edit expressions
vars = map(ex -> ex.args[1], kwexs)
vars = map(first Base.Fix2(getproperty, :args), kwexs)
fex = supportself(fex, vars)

# assemble
Expand All @@ -212,9 +217,8 @@ function supportself(fex::Expr, vars)
calls = []
for var in vars
push!(calls,
:($var = Lux.Experimental.__maybe_make_stateful(
Lux._getproperty($self, $(Val(var))),
Lux._getproperty($ps, $(Val(var))), Lux._getproperty($st, $(Val(var))))))
:($var = $(__maybe_make_stateful)($(_getproperty)($self, $(Val(var))),
$(_getproperty)($ps, $(Val(var))), $(_getproperty)($st, $(Val(var))))))
end
body = Expr(:let, Expr(:block, calls...), sdef[:body])
sdef[:body] = body
Expand All @@ -223,7 +227,7 @@ function supportself(fex::Expr, vars)
end

@inline function __maybe_make_stateful(layer::AbstractExplicitLayer, ps, st)
return Lux.StatefulLuxLayer(layer, ps, st)
return StatefulLuxLayer(layer, ps, st)
end
@inline __maybe_make_stateful(::Nothing, ps, st) = ps === nothing ? st : ps
@inline function __maybe_make_stateful(model::Union{AbstractVector, Tuple}, ps, st)
Expand Down Expand Up @@ -271,7 +275,7 @@ end
value_storage
end

function constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch}
function ConstructionBase.constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch}
return CompactLuxLayer{dispatch}
end

Expand All @@ -288,7 +292,7 @@ function __try_make_lux_layer(x::Union{AbstractVector, Tuple})
return __try_make_lux_layer(NamedTuple{Tuple(Symbol.(1:length(x)))}(x))
end
function __try_make_lux_layer(x)
function __maybe_convert_layer(l)
__maybe_convert_layer = @closure l -> begin
l isa AbstractExplicitLayer && return l
l isa Function && return WrappedFunction(l)
return l
Expand Down

0 comments on commit b3e9a5f

Please sign in to comment.