diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index 94a4aebc8..96124a706 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - version: ['1.10'] + version: ['1'] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index 44a88b1bb..45f2ea3d0 100644 --- a/Project.toml +++ b/Project.toml @@ -60,6 +60,7 @@ ComponentArrays = "0.15.11" ConcreteStructs = "0.2.3" ConstructionBase = "1.5" FastClosures = "0.3.2" +ExplicitImports = "1.1.1" Flux = "0.14.11" Functors = "0.4.4" GPUArraysCore = "0.1.6" @@ -68,7 +69,7 @@ Logging = "1.10" LuxAMDGPU = "0.2.2" LuxCUDA = "0.3.2" LuxCore = "0.1.12" -LuxDeviceUtils = "0.1.15" +LuxDeviceUtils = "0.1.16" LuxLib = "0.3.10" LuxTestUtils = "0.1.15" MacroTools = "0.5.13" @@ -94,8 +95,10 @@ julia = "1.10" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -119,4 +122,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Adapt", "Aqua", "ChainRulesCore", "ComponentArrays", "Flux", "Functors", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "Optimisers", "Random", "ReTestItems", "Reexport", "ReverseDiff", "Setfield", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] +test = ["ADTypes", "Adapt", "Aqua", "ChainRules", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "Flux", "Functors", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "Optimisers", "Random", "ReTestItems", "Reexport", "ReverseDiff", "Setfield", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] diff --git a/docs/src/introduction/index.md b/docs/src/introduction/index.md index d49c2472b..bf4cf39a7 100644 --- a/docs/src/introduction/index.md +++ b/docs/src/introduction/index.md @@ -2,8 +2,9 @@ ## Installation -Install [Julia v1.10 or above](https://julialang.org/downloads/). Lux.jl is available through -the Julia package manager. You can enter it by pressing `]` in the REPL and then typing +Install [Julia v1.10 or above](https://julialang.org/downloads/). Lux.jl is available +through the Julia package manager. You can enter it by pressing `]` in the REPL and then +typing ```julia pkg> add Lux diff --git a/ext/LuxChainRulesExt.jl b/ext/LuxChainRulesExt.jl index f349f9806..9ecf874d0 100644 --- a/ext/LuxChainRulesExt.jl +++ b/ext/LuxChainRulesExt.jl @@ -1,6 +1,6 @@ module LuxChainRulesExt -using ChainRules, ChainRulesCore, Lux +using ChainRules: ChainRules # https://github.com/FluxML/Zygote.jl/pull/1328 broke the RNNs completely. Putting an # emergency patch here diff --git a/ext/LuxComponentArraysExt.jl b/ext/LuxComponentArraysExt.jl index ade957687..4a7f5c3ce 100644 --- a/ext/LuxComponentArraysExt.jl +++ b/ext/LuxComponentArraysExt.jl @@ -1,6 +1,7 @@ module LuxComponentArraysExt -using ComponentArrays, Lux +using ComponentArrays: ComponentArrays, ComponentArray, FlatAxis +using Lux: Lux # Empty NamedTuple: Hack to avoid breaking precompilation function ComponentArrays.ComponentArray(data::Vector{Any}, axes::Tuple{FlatAxis}) diff --git a/ext/LuxComponentArraysReverseDiffExt.jl b/ext/LuxComponentArraysReverseDiffExt.jl index 52d61b5a9..930f337aa 100644 --- a/ext/LuxComponentArraysReverseDiffExt.jl +++ b/ext/LuxComponentArraysReverseDiffExt.jl @@ -1,9 +1,10 @@ module LuxComponentArraysReverseDiffExt -using ComponentArrays, ReverseDiff, Lux +using ComponentArrays: ComponentArray +using Lux: Lux +using ReverseDiff: TrackedArray -const TCA{V, D, N, DA, A, Ax} = ReverseDiff.TrackedArray{ - V, D, N, ComponentArray{V, N, A, Ax}, DA} +const TCA{V, D, N, DA, A, Ax} = TrackedArray{V, D, N, ComponentArray{V, N, A, Ax}, DA} Lux.__named_tuple(x::TCA) = NamedTuple(x) diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index eb6403e74..dd8ba3cbe 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -1,28 +1,28 @@ module LuxFluxExt import Flux - -using Lux, Random +using Lux: Lux +using Random: AbstractRNG import Lux: __from_flux_adaptor, FluxLayer, FluxModelConversionError __copy_anonymous_closure(x) = (args...) -> x -function FluxLayer(l) +function Lux.FluxLayer(l) if isdefined(Flux, :destructure) p, re = Flux.destructure(l) p_ = copy(p) - return FluxLayer(l, re, () -> p_) + return Lux.FluxLayer(l, re, () -> p_) else error("`Flux.destructure` not found. Please open an issue on LuxDL/Lux.jl with a \ MWE") end end -Lux.initialparameters(::AbstractRNG, l::FluxLayer) = (p=l.init_parameters(),) +Lux.initialparameters(::AbstractRNG, l::Lux.FluxLayer) = (p=l.init_parameters(),) (l::FluxLayer)(x, ps, st) = l.re(ps.p)(x), st -Base.show(io::IO, l::FluxLayer) = print(io, "FluxLayer($(l.layer))") +Base.show(io::IO, l::Lux.FluxLayer) = print(io, "FluxLayer($(l.layer))") function __from_flux_adaptor(l::T; preserve_ps_st::Bool=false, kwargs...) where {T} @warn lazy"Transformation for type $T not implemented. Using `FluxLayer` as a fallback." maxlog=1 @@ -33,18 +33,18 @@ function __from_flux_adaptor(l::T; preserve_ps_st::Bool=false, kwargs...) where argument." maxlog=1 end - return FluxLayer(l) + return Lux.FluxLayer(l) end -__from_flux_adaptor(l::Function; kwargs...) = WrappedFunction(l) +__from_flux_adaptor(l::Function; kwargs...) = Lux.WrappedFunction(l) function __from_flux_adaptor(l::Flux.Chain; kwargs...) fn = x -> __from_flux_adaptor(x; kwargs...) layers = map(fn, l.layers) if layers isa NamedTuple - return Chain(layers; disable_optimizations=true) + return Lux.Chain(layers; disable_optimizations=true) else - return Chain(layers...; disable_optimizations=true) + return Lux.Chain(layers...; disable_optimizations=true) end end @@ -52,42 +52,42 @@ function __from_flux_adaptor(l::Flux.Dense; preserve_ps_st::Bool=false, kwargs.. out_dims, in_dims = size(l.weight) if preserve_ps_st bias = l.bias isa Bool ? nothing : reshape(copy(l.bias), out_dims, 1) - return Dense( + return Lux.Dense( in_dims => out_dims, l.σ; init_weight=__copy_anonymous_closure(copy(l.weight)), init_bias=__copy_anonymous_closure(bias), use_bias=!(l.bias isa Bool)) else - return Dense(in_dims => out_dims, l.σ; use_bias=!(l.bias isa Bool)) + return Lux.Dense(in_dims => out_dims, l.σ; use_bias=!(l.bias isa Bool)) end end function __from_flux_adaptor(l::Flux.Scale; preserve_ps_st::Bool=false, kwargs...) if preserve_ps_st - return Scale( + return Lux.Scale( size(l.scale), l.σ; init_weight=__copy_anonymous_closure(copy(l.scale)), init_bias=__copy_anonymous_closure(copy(l.bias)), use_bias=!(l.bias isa Bool)) else - return Scale(size(l.scale), l.σ; use_bias=!(l.bias isa Bool)) + return Lux.Scale(size(l.scale), l.σ; use_bias=!(l.bias isa Bool)) end end function __from_flux_adaptor(l::Flux.Maxout; kwargs...) - return Maxout(__from_flux_adaptor.(l.layers; kwargs...)...) + return Lux.Maxout(__from_flux_adaptor.(l.layers; kwargs...)...) end function __from_flux_adaptor(l::Flux.SkipConnection; kwargs...) connection = l.connection isa Function ? l.connection : __from_flux_adaptor(l.connection; kwargs...) - return SkipConnection(__from_flux_adaptor(l.layers; kwargs...), connection) + return Lux.SkipConnection(__from_flux_adaptor(l.layers; kwargs...), connection) end function __from_flux_adaptor(l::Flux.Bilinear; preserve_ps_st::Bool=false, kwargs...) out, in1, in2 = size(l.weight) if preserve_ps_st - return Bilinear( + return Lux.Bilinear( (in1, in2) => out, l.σ; init_weight=__copy_anonymous_closure(copy(l.weight)), init_bias=__copy_anonymous_closure(copy(l.bias)), use_bias=!(l.bias isa Bool)) else - return Bilinear((in1, in2) => out, l.σ; use_bias=!(l.bias isa Bool)) + return Lux.Bilinear((in1, in2) => out, l.σ; use_bias=!(l.bias isa Bool)) end end @@ -95,25 +95,25 @@ function __from_flux_adaptor(l::Flux.Parallel; kwargs...) fn = x -> __from_flux_adaptor(x; kwargs...) layers = map(fn, l.layers) if layers isa NamedTuple - return Parallel(l.connection; layers...) + return Lux.Parallel(l.connection; layers...) else - return Parallel(l.connection, layers...) + return Lux.Parallel(l.connection, layers...) end end function __from_flux_adaptor(l::Flux.PairwiseFusion; kwargs...) @warn "Flux.PairwiseFusion and Lux.PairwiseFusion are semantically different. Using \ `FluxLayer` as a fallback." maxlog=1 - return FluxLayer(l) + return Lux.FluxLayer(l) end function __from_flux_adaptor(l::Flux.Embedding; preserve_ps_st::Bool=true, kwargs...) out_dims, in_dims = size(l.weight) if preserve_ps_st - return Embedding( + return Lux.Embedding( in_dims => out_dims; init_weight=__copy_anonymous_closure(copy(l.weight))) else - return Embedding(in_dims => out_dims) + return Lux.Embedding(in_dims => out_dims) end end @@ -125,11 +125,12 @@ function __from_flux_adaptor(l::Flux.Conv; preserve_ps_st::Bool=false, kwargs... if preserve_ps_st _bias = l.bias isa Bool ? nothing : reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1) - return Conv(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation, groups, + return Lux.Conv( + k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation, groups, init_weight=__copy_anonymous_closure(Lux._maybe_flip_conv_weight(l.weight)), init_bias=__copy_anonymous_closure(_bias), use_bias=!(l.bias isa Bool)) else - return Conv(k, in_chs * groups => out_chs, l.σ; l.stride, pad, + return Lux.Conv(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation, groups, use_bias=!(l.bias isa Bool)) end end @@ -142,12 +143,12 @@ function __from_flux_adaptor(l::Flux.ConvTranspose; preserve_ps_st::Bool=false, if preserve_ps_st _bias = l.bias isa Bool ? nothing : reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1) - return ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, + return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation, groups, use_bias=!(l.bias isa Bool), init_weight=__copy_anonymous_closure(Lux._maybe_flip_conv_weight(l.weight)), init_bias=__copy_anonymous_closure(_bias)) else - return ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad, + return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation, groups, use_bias=!(l.bias isa Bool)) end end @@ -159,49 +160,49 @@ function __from_flux_adaptor(l::Flux.CrossCor; preserve_ps_st::Bool=false, kwarg if preserve_ps_st _bias = l.bias isa Bool ? nothing : reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1) - return CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation, + return Lux.CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation, init_weight=__copy_anonymous_closure(copy(l.weight)), init_bias=__copy_anonymous_closure(_bias), use_bias=!(l.bias isa Bool)) else - return CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, + return Lux.CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation, use_bias=!(l.bias isa Bool)) end end -__from_flux_adaptor(l::Flux.AdaptiveMaxPool; kwargs...) = AdaptiveMaxPool(l.out) +__from_flux_adaptor(l::Flux.AdaptiveMaxPool; kwargs...) = Lux.AdaptiveMaxPool(l.out) -__from_flux_adaptor(l::Flux.AdaptiveMeanPool; kwargs...) = AdaptiveMeanPool(l.out) +__from_flux_adaptor(l::Flux.AdaptiveMeanPool; kwargs...) = Lux.AdaptiveMeanPool(l.out) -__from_flux_adaptor(::Flux.GlobalMaxPool; kwargs...) = GlobalMaxPool() +__from_flux_adaptor(::Flux.GlobalMaxPool; kwargs...) = Lux.GlobalMaxPool() -__from_flux_adaptor(::Flux.GlobalMeanPool; kwargs...) = GlobalMeanPool() +__from_flux_adaptor(::Flux.GlobalMeanPool; kwargs...) = Lux.GlobalMeanPool() function __from_flux_adaptor(l::Flux.MaxPool; kwargs...) pad = l.pad isa Flux.SamePad ? SamePad() : l.pad - return MaxPool(l.k; l.stride, pad) + return Lux.MaxPool(l.k; l.stride, pad) end function __from_flux_adaptor(l::Flux.MeanPool; kwargs...) pad = l.pad isa Flux.SamePad ? SamePad() : l.pad - return MeanPool(l.k; l.stride, pad) + return Lux.MeanPool(l.k; l.stride, pad) end -__from_flux_adaptor(l::Flux.Dropout; kwargs...) = Dropout(l.p; l.dims) +__from_flux_adaptor(l::Flux.Dropout; kwargs...) = Lux.Dropout(l.p; l.dims) function __from_flux_adaptor(l::Flux.LayerNorm; kwargs...) @warn "Flux.LayerNorm and Lux.LayerNorm are semantically different specifications. \ Using `FluxLayer` as a fallback." maxlog=1 - return FluxLayer(l) + return Lux.FluxLayer(l) end -__from_flux_adaptor(::typeof(identity); kwargs...) = NoOpLayer() +__from_flux_adaptor(::typeof(identity); kwargs...) = Lux.NoOpLayer() -__from_flux_adaptor(::typeof(Flux.flatten); kwargs...) = FlattenLayer() +__from_flux_adaptor(::typeof(Flux.flatten); kwargs...) = Lux.FlattenLayer() -__from_flux_adaptor(l::Flux.PixelShuffle; kwargs...) = PixelShuffle(l.r) +__from_flux_adaptor(l::Flux.PixelShuffle; kwargs...) = Lux.PixelShuffle(l.r) function __from_flux_adaptor(l::Flux.Upsample{mode}; kwargs...) where {mode} - return Upsample(mode; l.scale, l.size) + return Lux.Upsample(mode; l.scale, l.size) end function __from_flux_adaptor( @@ -212,11 +213,11 @@ function __from_flux_adaptor( throw(FluxModelConversionError(lazy"Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `RNNCell`.")) end @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.RNNCell` is ambiguous in Lux and hence not supported. Ignoring these parameters." maxlog=1 - return RNNCell( + return Lux.RNNCell( in_dims => out_dims, l.σ; init_bias=__copy_anonymous_closure(copy(l.b)), init_state=__copy_anonymous_closure(copy(l.state0))) else - return RNNCell(in_dims => out_dims, l.σ) + return Lux.RNNCell(in_dims => out_dims, l.σ) end end @@ -232,11 +233,11 @@ function __from_flux_adaptor( and hence not supported. Ignoring these parameters." maxlog=1 bs = Lux.multigate(l.b, Val(4)) _s, _m = copy.(l.state0) - return LSTMCell(in_dims => out_dims; init_bias=__copy_anonymous_closure.(bs), + return Lux.LSTMCell(in_dims => out_dims; init_bias=__copy_anonymous_closure.(bs), init_state=__copy_anonymous_closure(_s), init_memory=__copy_anonymous_closure(_m)) else - return LSTMCell(in_dims => out_dims) + return Lux.LSTMCell(in_dims => out_dims) end end @@ -251,10 +252,10 @@ function __from_flux_adaptor( @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.GRUCell` is ambiguous in Lux \ and hence not supported. Ignoring these parameters." maxlog=1 bs = Lux.multigate(l.b, Val(3)) - return GRUCell(in_dims => out_dims; init_bias=_const_return_anon_function.(bs), + return Lux.GRUCell(in_dims => out_dims; init_bias=_const_return_anon_function.(bs), init_state=__copy_anonymous_closure(copy(l.state0))) else - return GRUCell(in_dims => out_dims) + return Lux.GRUCell(in_dims => out_dims) end end @@ -262,38 +263,39 @@ function __from_flux_adaptor( l::Flux.BatchNorm; preserve_ps_st::Bool=false, force_preserve::Bool=false) if preserve_ps_st if l.track_stats - force_preserve && return FluxLayer(l) + force_preserve && return Lux.FluxLayer(l) @warn "Preserving the state of `Flux.BatchNorm` is currently not supported. \ Ignoring the state." maxlog=1 end if l.affine - return BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, + return Lux.BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum, init_bias=__copy_anonymous_closure(copy(l.β)), init_scale=__copy_anonymous_closure(copy(l.γ))) else - return BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum) + return Lux.BatchNorm( + l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum) end end - return BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum) + return Lux.BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum) end function __from_flux_adaptor( l::Flux.GroupNorm; preserve_ps_st::Bool=false, force_preserve::Bool=false) if preserve_ps_st if l.track_stats - force_preserve && return FluxLayer(l) + force_preserve && return Lux.FluxLayer(l) @warn "Preserving the state of `Flux.GroupNorm` is currently not supported. \ Ignoring the state." maxlog=1 end if l.affine - return GroupNorm(l.chs, l.G, l.λ; l.affine, epsilon=l.ϵ, + return Lux.GroupNorm(l.chs, l.G, l.λ; l.affine, epsilon=l.ϵ, init_bias=__copy_anonymous_closure(copy(l.β)), init_scale=__copy_anonymous_closure(copy(l.γ))) else - return GroupNorm(l.chs, l.G, l.λ; l.affine, epsilon=l.ϵ) + return Lux.GroupNorm(l.chs, l.G, l.λ; l.affine, epsilon=l.ϵ) end end - return GroupNorm(l.chs, l.G, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum) + return Lux.GroupNorm(l.chs, l.G, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum) end const _INVALID_TRANSFORMATION_TYPES = Union{<:Flux.Recur} diff --git a/ext/LuxLuxAMDGPUExt.jl b/ext/LuxLuxAMDGPUExt.jl index 684772fcc..9c89591dd 100644 --- a/ext/LuxLuxAMDGPUExt.jl +++ b/ext/LuxLuxAMDGPUExt.jl @@ -1,9 +1,10 @@ module LuxLuxAMDGPUExt -using Lux, LuxAMDGPU +import LuxAMDGPU: AMDGPU +import Lux: _maybe_flip_conv_weight # Flux modifies Conv weights while mapping to AMD GPU -function Lux._maybe_flip_conv_weight(x::AMDGPU.AnyROCArray) +function _maybe_flip_conv_weight(x::AMDGPU.AnyROCArray) # This is a very rare operation, hence we dont mind allowing scalar operations return AMDGPU.@allowscalar reverse(x; dims=ntuple(identity, ndims(x) - 2)) end diff --git a/ext/LuxOptimisersExt.jl b/ext/LuxOptimisersExt.jl index ed857a4f5..8ba6c90d2 100644 --- a/ext/LuxOptimisersExt.jl +++ b/ext/LuxOptimisersExt.jl @@ -1,6 +1,9 @@ module LuxOptimisersExt -using Lux, Random, Optimisers +using Lux: Lux +using LuxDeviceUtils: AbstractLuxDevice, gpu_device +using Optimisers: Optimisers +using Random: Random """ TrainState(rng::Random.AbstractRNG, model::Lux.AbstractExplicitLayer, @@ -24,7 +27,7 @@ Constructor for [`TrainState`](@ref). function Lux.Experimental.TrainState( rng::Random.AbstractRNG, model::Lux.AbstractExplicitLayer, optimizer::Optimisers.AbstractRule; - transform_variables::Union{Function, Lux.AbstractLuxDevice}=gpu_device()) + transform_variables::Union{Function, AbstractLuxDevice}=gpu_device()) ps, st = Lux.setup(rng, model) .|> transform_variables st_opt = Optimisers.setup(optimizer, ps) return Lux.Experimental.TrainState(model, ps, st, st_opt, 0) diff --git a/ext/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt.jl index ab2b212ae..bc52c0e5a 100644 --- a/ext/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt.jl @@ -1,7 +1,11 @@ module LuxReverseDiffExt -using ADTypes, Lux, Functors, ReverseDiff, Setfield +using ADTypes: AutoReverseDiff using ArrayInterface: ArrayInterface +using Functors: fmap +using Lux: Lux +using ReverseDiff: ReverseDiff +using Setfield: @set! function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} diff --git a/ext/LuxSimpleChainsExt.jl b/ext/LuxSimpleChainsExt.jl index 725b3255c..446a8942d 100644 --- a/ext/LuxSimpleChainsExt.jl +++ b/ext/LuxSimpleChainsExt.jl @@ -1,9 +1,10 @@ module LuxSimpleChainsExt -using Lux, Random +using Lux import SimpleChains import Lux: SimpleChainsModelConversionError, __to_simplechains_adaptor, __fix_input_dims_simplechain +import Random: AbstractRNG function __fix_input_dims_simplechain(layers::Vector, input_dims) L = Tuple(layers) diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl index 7a2f33fba..960169ddb 100644 --- a/ext/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt.jl @@ -1,7 +1,11 @@ module LuxTrackerExt -using ADTypes, ChainRulesCore, Functors, Lux, Setfield, Tracker +using ADTypes: AutoTracker using ArrayInterface: ArrayInterface +using Functors: fmap +using Lux: Lux +using Setfield: @set! +using Tracker: Tracker # Type Piracy: Need to upstream Tracker.param(nt::NamedTuple) = fmap(Tracker.param, nt) @@ -16,18 +20,18 @@ Tracker.data(nt::NamedTuple) = fmap(Tracker.data, nt) Tracker.data(t::Tuple) = map(Tracker.data, t) # Weight Norm Patch -@inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims)) +@inline Lux._norm(x::Tracker.TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims)) # multigate chain rules -@inline Lux._gate(x::TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)] -@inline Lux._gate(x::TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :] +@inline Lux._gate(x::Tracker.TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)] +@inline Lux._gate(x::Tracker.TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :] # Lux.Training function Lux.Experimental.compute_gradients(::AutoTracker, objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} - ps_tracked = fmap(param, ts.parameters) + ps_tracked = fmap(Tracker.param, ts.parameters) loss, st, stats = objective_function(ts.model, ps_tracked, ts.states, data) - back!(loss) + Tracker.back!(loss) @set! ts.states = st grads = fmap(Tracker.grad, ps_tracked) return grads, loss, stats, ts diff --git a/ext/LuxZygoteExt.jl b/ext/LuxZygoteExt.jl index 466dcb01a..fef20faf2 100644 --- a/ext/LuxZygoteExt.jl +++ b/ext/LuxZygoteExt.jl @@ -1,7 +1,9 @@ module LuxZygoteExt -using ADTypes, Lux, Setfield, Zygote -using Zygote: Pullback +using ADTypes: AutoZygote +using Lux: Lux +using Setfield: @set! +using Zygote: Zygote function Lux.Experimental.compute_gradients(::AutoZygote, objective_function::F, data, ts::Lux.Experimental.TrainState) where {F} diff --git a/src/Lux.jl b/src/Lux.jl index 33ed4cb16..8eb58f509 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -3,19 +3,28 @@ module Lux import PrecompileTools PrecompileTools.@recompile_invalidations begin - using Reexport - using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers - using LinearAlgebra, Markdown, Random, Statistics - using Adapt, ConcreteStructs, Functors, Setfield - using ChainRulesCore - using ArrayInterface, GPUArraysCore + using Adapt: Adapt, adapt + using ArrayInterface: ArrayInterface + using ChainRulesCore: ChainRulesCore, AbstractZero, HasReverseMode, NoTangent, + ProjectTo, RuleConfig, ZeroTangent + using ConcreteStructs: @concrete using FastClosures: @closure + using Functors: Functors, fmap + using GPUArraysCore: GPUArraysCore + using LinearAlgebra: LinearAlgebra + using Markdown: @doc_str + using Random: Random, AbstractRNG + using Reexport: @reexport + using Setfield: Setfield, @set! + using Statistics: Statistics, mean + using WeightInitializers: WeightInitializers, glorot_uniform, ones32, randn32, zeros32 + using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, initialstates, parameterlength, statelength, inputsize, outputsize, update_state, trainmode, testmode, setup, apply, display_name, replicate - import LuxDeviceUtils: AbstractLuxDevice, AbstractLuxGPUDevice, AbstractLuxDeviceAdaptor + import LuxDeviceUtils: get_device end @reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers diff --git a/src/contrib/compact.jl b/src/contrib/compact.jl index ff8040e83..e8da68341 100644 --- a/src/contrib/compact.jl +++ b/src/contrib/compact.jl @@ -1,6 +1,3 @@ -using MacroTools -import ConstructionBase: constructorof - # This functionality is based off of the implementation in Fluxperimental.jl # https://github.com/FluxML/Fluxperimental.jl/blob/cc0e36fdd542cc6028bc69449645dc0390dd980b/src/compact.jl struct LuxCompactModelParsingException <: Exception diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index ae675a196..b72d545f3 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -2,12 +2,19 @@ module Experimental import ..Lux using ..Lux, LuxCore, LuxDeviceUtils, Random -import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer +using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer import ..Lux: _merge, _pairs, initialstates, initialparameters, apply, NAME_TYPE, _getproperty + +using ADTypes: ADTypes import ChainRulesCore as CRC -import ConcreteStructs: @concrete -import Functors: fmap +using ConcreteStructs: @concrete +import ConstructionBase: constructorof +using Functors: Functors, fmap, functor +using MacroTools: block, combinedef, splitdef +using Markdown: @doc_str +using Random: AbstractRNG, Random +using Setfield: Setfield include("map.jl") include("training.jl") @@ -22,8 +29,8 @@ end # Deprecations for v0.6 module Training -using ..Experimental, Reexport -@reexport using ADTypes +using ADTypes: AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote +using ..Experimental: Experimental for f in (:TrainState, :apply_gradients, :compute_gradients) msg = lazy"`Lux.Training.$(f)` has been deprecated in favor of `Lux.Experimental.$(f)`" @@ -35,6 +42,8 @@ for f in (:TrainState, :apply_gradients, :compute_gradients) end end +export AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote + end macro layer_map(f, l, ps, st) diff --git a/src/contrib/map.jl b/src/contrib/map.jl index 9d706bc1e..ee788018c 100644 --- a/src/contrib/map.jl +++ b/src/contrib/map.jl @@ -1,6 +1,3 @@ -using Markdown -using Functors: functor - @doc doc""" @layer_map func layer ps st diff --git a/src/contrib/share_parameters.jl b/src/contrib/share_parameters.jl index d38f67c8c..e0b6dc332 100644 --- a/src/contrib/share_parameters.jl +++ b/src/contrib/share_parameters.jl @@ -82,7 +82,7 @@ function _parameter_structure(ps) end function _assert_disjoint_sharing_list(sharing) - for i in 1:length(sharing), j in (i + 1):length(sharing) + for i in eachindex(sharing), j in (i + 1):length(sharing) if !isdisjoint(sharing[i], sharing[j]) throw(ArgumentError(lazy"sharing[$i] ($(sharing[i])) and sharing[$j] ($(sharing[j])) must be disjoint")) end diff --git a/src/contrib/training.jl b/src/contrib/training.jl index 7fbc33a72..3963bdb77 100644 --- a/src/contrib/training.jl +++ b/src/contrib/training.jl @@ -1,5 +1,3 @@ -using ADTypes, ConcreteStructs, Random, Setfield - """ TrainState diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 27aaa8bb2..b37b54cae 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -54,15 +54,11 @@ Flattens the passed array into a matrix. end @inline function (f::FlattenLayer)(x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N} + f.N === nothing && return reshape(x, :, size(x, N)), st @assert f.N < N return reshape(x, :, size(x)[(f.N + 1):end]...), st end -@inline function (f::FlattenLayer{Nothing})( - x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N} - return reshape(x, :, size(x, N)), st -end - """ SelectDim(dim, i) diff --git a/test/aqua_tests.jl b/test/aqua_tests.jl deleted file mode 100644 index dc6adbf75..000000000 --- a/test/aqua_tests.jl +++ /dev/null @@ -1,7 +0,0 @@ -@testitem "Aqua: Quality Assurance" begin - using Aqua, ChainRulesCore - - Aqua.test_all(Lux; piracies=false) - Aqua.test_piracies( - Lux; treat_as_own=[ChainRulesCore.frule, ChainRulesCore.rrule, Core.kwcall]) -end diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl index 1dea1ca34..713454063 100644 --- a/test/layers/recurrent_tests.jl +++ b/test/layers/recurrent_tests.jl @@ -280,7 +280,7 @@ end end end -@testitem "Recurrence" setup=[SharedTestSetup] begin +@testitem "Recurrence" timeout=3000 setup=[SharedTestSetup] begin rng = get_stable_rng(12345) @testset "$mode" for (mode, aType, device, ongpu) in MODES diff --git a/test/qa_tests.jl b/test/qa_tests.jl new file mode 100644 index 000000000..16dfc5c70 --- /dev/null +++ b/test/qa_tests.jl @@ -0,0 +1,24 @@ +@testitem "Aqua: Quality Assurance" begin + using Aqua, ChainRulesCore + + Aqua.test_all(Lux; piracies=false) + Aqua.test_piracies( + Lux; treat_as_own=[ChainRulesCore.frule, ChainRulesCore.rrule, Core.kwcall]) +end + +@testitem "Explicit Imports: Quality Assurance" begin + # Load all trigger packages + import Lux, ComponentArrays, ReverseDiff, ChainRules, Flux, LuxAMDGPU, SimpleChains, + Tracker, Zygote + + using ExplicitImports + + # Skip our own packages + @test check_no_implicit_imports(Lux; + skip=(LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers, Base, Core, Lux)) === + nothing + ## AbstractRNG seems to be a spurious detection in LuxFluxExt + @test check_no_stale_explicit_imports(Lux; + ignore=(:inputsize, :setup, :testmode, :trainmode, :update_state, :AbstractRNG)) === + nothing +end