Skip to content

Commit

Permalink
refactor: static fields in layers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 4, 2024
1 parent ea332be commit cc33ada
Show file tree
Hide file tree
Showing 18 changed files with 596 additions and 585 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.5.68-DEV"
version = "0.5.68"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -11,7 +11,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Expand Down Expand Up @@ -78,7 +77,6 @@ ChainRulesCore = "1.24"
Compat = "4.15"
ComponentArrays = "0.15.16"
ConcreteStructs = "0.2.3"
ConstructionBase = "1.5"
DispatchDoctor = "0.4.12"
DynamicExpressions = "0.16, 0.17, 0.18, 0.19"
Enzyme = "0.12.26"
Expand Down
4 changes: 2 additions & 2 deletions ext/LuxSimpleChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ end
# Translate an activation function into the one handed to SimpleChains: `NNlib.relu` is
# swapped for `SimpleChains.relu`, every other function is passed through unchanged.
function equivalent_simplechains_fn(::typeof(NNlib.relu))
    return SimpleChains.relu
end
function equivalent_simplechains_fn(f::F) where {F}
    return f
end

function Lux.make_simplechain_network(layer::Dense{use_bias}) where {use_bias}
return SimpleChains.TurboDense{use_bias}(
function Lux.make_simplechain_network(layer::Dense)
return SimpleChains.TurboDense{Lux.has_bias(layer)}(
equivalent_simplechains_fn(layer.activation), layer.out_dims)
end

Expand Down
3 changes: 1 addition & 2 deletions src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@ using ArrayInterface: ArrayInterface
using ChainRulesCore: ChainRulesCore, NoTangent, @thunk
using Compat: @compat
using ConcreteStructs: @concrete
using ConstructionBase: ConstructionBase
using FastClosures: @closure
using Functors: Functors, fmap
using GPUArraysCore: @allowscalar
using LossFunctions: LossFunctions
using Markdown: @doc_str
using Optimisers: Optimisers
using Random: Random, AbstractRNG
using Static: StaticBool, True, False, static
using Static: StaticBool, StaticInt, StaticSymbol, True, False, static, known, dynamic
using Reexport: @reexport
using Statistics: mean
using UnrolledUtilities: unrolled_map, unrolled_mapreduce
Expand Down
1 change: 1 addition & 0 deletions src/contrib/contrib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using Markdown: @doc_str
using Optimisers: Optimisers
using Random: AbstractRNG, Random
using Setfield: Setfield
using Static: StaticSymbol, StaticBool, True, known, static, dynamic

const CRC = ChainRulesCore

Expand Down
47 changes: 26 additions & 21 deletions src/contrib/debug.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
DebugLayer(layer::AbstractExplicitLayer; nan_check::Symbol=:both,
error_check::Bool=true, location::KeyPath=KeyPath())
DebugLayer(layer::AbstractExplicitLayer;
nan_check::Union{Symbol, StaticSymbol, Val}=static(:both),
error_check::Union{StaticBool, Bool, Val{true}, Val{false}}=True(),
location::Union{KeyPath, String}=KeyPath())
A wrapper over Lux layers that adds checks for NaNs and errors. This is useful for
debugging.
Expand Down Expand Up @@ -41,15 +43,18 @@ track where the error originates.
See [`Lux.Experimental.@debug_mode`](@ref) to construct this layer.
"""
@concrete struct DebugLayer{NaNCheck, ErrorCheck} <:
AbstractExplicitContainerLayer{(:layer,)}
@concrete struct DebugLayer <: AbstractExplicitContainerLayer{(:layer,)}
nan_check <: StaticSymbol
error_check <: StaticBool
layer <: AbstractExplicitLayer
location::KeyPath
end

function DebugLayer(layer::AbstractExplicitLayer; nan_check::Symbol=:both,
error_check::Bool=true, location::Union{KeyPath, String}=KeyPath())
@argcheck nan_check in (:both, :forward, :backward, :none)
function DebugLayer(layer::AbstractExplicitLayer;
nan_check::Union{Symbol, StaticSymbol, Val}=static(:both),
error_check::Union{StaticBool, Bool, Val{true}, Val{false}}=True(),
location::Union{KeyPath, String}=KeyPath())
@argcheck dynamic(nan_check) in (:both, :forward, :backward, :none)

if location isa String
Base.depwarn(
Expand All @@ -58,23 +63,23 @@ function DebugLayer(layer::AbstractExplicitLayer; nan_check::Symbol=:both,
location = KeyPath(Symbol.(split(location, "."))...)
end

return DebugLayer{nan_check, error_check}(layer, location)
return DebugLayer(static(nan_check), static(error_check), layer, location)
end

function (d::DebugLayer{NaNCheck, ErrorCheck})(x, ps, st) where {NaNCheck, ErrorCheck}
function (d::DebugLayer)(x, ps, st)
CRC.ignore_derivatives() do
@info lazy"Input Type: $(typeof(x)) | Input Structure: $(Utils.structure(x))."
@info lazy"Running Layer: $(d.layer) at location $(d.location)!"
if NaNCheck ∈ (:both, :forward)
if known(d.nan_check) ∈ (:both, :forward)
check_nan_and_throw(x, "input", d.layer, d.location)
check_nan_and_throw(ps, "parameters", d.layer, d.location)
check_nan_and_throw(st, "states", d.layer, d.location)
end
end
y, stₙ = debug_layer_impl(
d.layer, x, ps, st, d.location, ErrorCheck, NaNCheck ∈ (:both, :backward))
y, stₙ = debug_layer_impl(d.layer, x, ps, st, d.location, known(d.error_check),
known(d.nan_check) ∈ (:both, :backward))
CRC.ignore_derivatives() do
if NaNCheck ∈ (:both, :forward)
if known(d.nan_check) ∈ (:both, :forward)
check_nan_and_throw(y, "output", d.layer, d.location)
check_nan_and_throw(stₙ, "states", d.layer, d.location)
end
Expand All @@ -99,34 +104,34 @@ function check_nan_and_throw(x, str::AbstractString, layer, location::KeyPath)
return fmap_with_path(nan_check, x)
end

function debug_layer_impl(layer, x, ps, st, location, EC, NC)
function debug_layer_impl(layer, x, ps, st, location, error_check, _)
y, stₙ = try
apply(layer, x, ps, st)
catch
EC &&
error_check &&
@error "Layer $(layer) failed!! This layer is present at location $(location)."
rethrow()
end
return y, stₙ
end

function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
::typeof(debug_layer_impl), layer, x, ps, st, location, EC, NC)
function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(debug_layer_impl),
layer, x, ps, st, location, error_check, nan_check_backward)
result, ∇debug_layer_internal = CRC.rrule_via_ad(cfg, apply, layer, x, ps, st)
syms = ("LuxCore.apply", "layer", "x", "ps", "st")
function ∇debug_layer_internal_with_checks(Δ)
NC && check_nan_and_throw(Δ, "pullback input", layer, location)
nan_check_backward && check_nan_and_throw(Δ, "pullback input", layer, location)

gs = try
∇debug_layer_internal(Δ)
catch
EC &&
error_check &&
@error "Backward Pass for Layer $(layer) failed!! This layer is present at location $(location)."
rethrow()
end

if NC
for (i, g) in enumerate(gs)
if nan_check_backward
foreach(enumerate(gs)) do (i, g)
check_nan_and_throw(g, "pullback output ($(syms[i]))", layer, location)
end
end
Expand Down
22 changes: 17 additions & 5 deletions src/extended_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ using Compat: @compat
using EnzymeCore: EnzymeCore
using FastClosures: @closure
using MLDataDevices: get_device_type, AbstractGPUDevice, AbstractDevice
using Static: StaticBool, known
using Static: StaticBool, StaticSymbol, known

using ..Utils: Utils

const CRC = ChainRulesCore

const KnownSymbolType{v} = Union{Val{v}, StaticSymbol{v}}

# `xlogx` and `xlogy`
## We don't use `LogExpFunctions` since they don't support GPU broadcasting. See
## https://github.com/LuxDL/Lux.jl/pull/796. Additionally we have special broadcast rrules.
Expand Down Expand Up @@ -79,14 +81,15 @@ end

"""
getproperty(x, ::Val{v})
getproperty(x, ::StaticSymbol{v})
Similar to `Base.getproperty` but requires a `Val`. Additionally if `v` is not present in
`x`, then `nothing` is returned.
Similar to `Base.getproperty` but requires a `Val` (or `Static.StaticSymbol`). Additionally,
if `v` is not present in `x`, then `nothing` is returned.
"""
function getproperty(x, ::Val{v}) where {v}
function getproperty(x, ::KnownSymbolType{v}) where {v}
return v ∈ Base.propertynames(x) ? Base.getproperty(x, v) : nothing
end
@generated function getproperty(x::NamedTuple{names}, ::Val{v}) where {names, v}
@generated function getproperty(x::NamedTuple{names}, ::KnownSymbolType{v}) where {names, v}
return v ∈ names ? :(x.$v) : :(nothing)
end

Expand Down Expand Up @@ -228,3 +231,12 @@ const safe_eachslice = LuxOps.eachslice
const private_xlogx = LuxOps.xlogx
const private_xlogy = LuxOps.xlogy
const private_foldl_init = LuxOps.foldl_init

# Generate the `has_bias`, `has_affine`, `has_track_stats` and `has_train_state` queries.
# These are defined here to avoid a circular dependency among modules.
for (suffix, fieldname) in ((:bias, :use_bias), (:affine, :affine),
    (:track_stats, :track_stats), (:train_state, :train_state))
    @eval function $(Symbol(:has_, suffix))(l::AbstractExplicitLayer)
        # The field name is passed to `safe_getproperty` as a `Val`, and `known` is applied
        # to whatever comes back; a `nothing` outcome is reported as `false`.
        # NOTE(review): presumably `safe_getproperty` yields `nothing` for layers that lack
        # the field -- confirm against its definition.
        flag = known(safe_getproperty(l, Val($(Meta.quot(fieldname)))))
        flag === nothing && return false
        return flag
    end
end
32 changes: 16 additions & 16 deletions src/helpers/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,19 +292,16 @@ macro non_trainable(x)
return esc(:($(CompactMacroImpl.NonTrainable)($(x))))
end

@concrete struct CompactLuxLayer{dispatch} <:
AbstractExplicitContainerLayer{(:layers, :value_storage)}
f
name
struct CompactLuxLayer{dispatch, F, N, L, V, SK} <:
AbstractExplicitContainerLayer{(:layers, :value_storage)}
d::StaticSymbol{dispatch}
f::F
name::N
strings::NTuple{3, String}
setup_strings
layers
value_storage
stored_kwargs
end

function ConstructionBase.constructorof(::Type{<:CompactLuxLayer{dispatch}}) where {dispatch}
return CompactLuxLayer{dispatch}
setup_strings::Any
layers::L
value_storage::V
stored_kwargs::SK
end

function initialparameters(rng::AbstractRNG, m::CompactLuxLayer)
Expand All @@ -320,8 +317,8 @@ function initialstates(rng::AbstractRNG, m::CompactLuxLayer)
base_states, (; ₋₋₋kwargs₋₋₋=NamedTuple{m.stored_kwargs[1]}(m.stored_kwargs[2])))
end

function CompactLuxLayer{dispatch}(
f::F, name::NAME_TYPE, str::Tuple, splatted_kwargs; kws...) where {F, dispatch}
function CompactLuxLayer(dispatch::StaticSymbol, f::F, name::NAME_TYPE,
str::Tuple, splatted_kwargs; kws...) where {F}
layers, others = [], []
setup_strings = NamedTuple()
for (name, val) in pairs(kws)
Expand Down Expand Up @@ -353,7 +350,7 @@ function CompactLuxLayer{dispatch}(
NamedTuple((name => CompactMacroImpl.kwarg_descriptor(val),)))
end
end
return CompactLuxLayer{dispatch}(f, name, str, setup_strings, NamedTuple((; layers...)),
return CompactLuxLayer(dispatch, f, name, str, setup_strings, NamedTuple((; layers...)),
CompactMacroImpl.ValueStorage(; others...), splatted_kwargs)
end

Expand Down Expand Up @@ -423,6 +420,7 @@ using ChainRulesCore: @non_differentiable
using ConcreteStructs: @concrete
using MacroTools: MacroTools, @capture, combinedef, splitdef
using Random: AbstractRNG
using Static: static

using LuxCore: LuxCore, AbstractExplicitLayer
using ..Lux: Lux, CompactLuxLayer, LuxCompactModelParsingException, StatefulLuxLayer,
Expand Down Expand Up @@ -467,6 +465,7 @@ function compact_macro_impl(_exs...)

# check if user has provided a custom dispatch
dispatch, kwexs = extract_reserved_kwarg(kwexs, :dispatch)
dispatch === nothing && (dispatch = QuoteNode(:₋₋₋no_special_dispatch₋₋₋))

# Extract splatted kwargs
splat_idxs = findall(ex -> ex.head == :..., kwexs)
Expand Down Expand Up @@ -495,7 +494,8 @@ function compact_macro_impl(_exs...)
fex = supportself(fex, vars, splatted_kwargs)

# assemble
return esc(:($CompactLuxLayer{$dispatch}($fex, $name, ($layer, $input, $block),
return esc(:($CompactLuxLayer(
$(static)($(dispatch)), $fex, $name, ($layer, $input, $block),
(($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...))))
end

Expand Down
Loading

0 comments on commit cc33ada

Please sign in to comment.