From ea032ee6509c0ed131aa462e4a63258715470806 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Apr 2024 12:30:41 -0400 Subject: [PATCH 1/3] Test if ChainRules problem is resolved --- Project.toml | 8 ++------ ext/LuxChainRulesExt.jl | 18 ------------------ test/qa_tests.jl | 3 +-- 3 files changed, 3 insertions(+), 26 deletions(-) delete mode 100644 ext/LuxChainRulesExt.jl diff --git a/Project.toml b/Project.toml index b6da5a9a6..f31e75d73 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.33" +version = "0.5.34" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -29,7 +29,6 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" @@ -43,7 +42,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -LuxChainRulesExt = "ChainRules" LuxComponentArraysExt = "ComponentArrays" LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] LuxFluxExt = "Flux" @@ -63,7 +61,6 @@ Adapt = "4" Aqua = "0.8.4" ArrayInterface = "7.8" CUDA = "5.2" -ChainRules = "1.62" ChainRulesCore = "1.21" ComponentArrays = "0.15.11" ConcreteStructs = "0.2.3" @@ -108,7 +105,6 @@ julia = "1.10" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" @@ -138,4 +134,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Adapt", "Aqua", "ChainRules", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "Flux", "Functors", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "Random", "ReTestItems", "Reexport", "ReverseDiff", "Setfield", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] +test = ["ADTypes", "Adapt", "Aqua", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "Flux", "Functors", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "Random", "ReTestItems", "Reexport", "ReverseDiff", "Setfield", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] diff --git a/ext/LuxChainRulesExt.jl b/ext/LuxChainRulesExt.jl deleted file mode 100644 index 9ecf874d0..000000000 --- a/ext/LuxChainRulesExt.jl +++ /dev/null @@ -1,18 +0,0 @@ -module LuxChainRulesExt - -using ChainRules: ChainRules - -# https://github.com/FluxML/Zygote.jl/pull/1328 broke the RNNs completely. Putting an -# emergency patch here -function ChainRules._setindex_zero( - x::Vector{<:AbstractArray{T}}, dy, inds::Integer...) where {T <: Number} - return [fill!(similar(xᵢ), 0) for xᵢ in x] -end - -function ChainRules.∇getindex!( - dx::Vector{<:AbstractArray{T}}, dy, inds::Integer...) where {T <: Number} - dx[inds...] .+= dy - return dx -end - -end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 16dfc5c70..763f16039 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -8,8 +8,7 @@ end @testitem "Explicit Imports: Quality Assurance" begin # Load all trigger packages - import Lux, ComponentArrays, ReverseDiff, ChainRules, Flux, LuxAMDGPU, SimpleChains, - Tracker, Zygote + import Lux, ComponentArrays, ReverseDiff, Flux, LuxAMDGPU, SimpleChains, Tracker, Zygote using ExplicitImports From 173e01f8a9cff2e9c0223c967cdec09de4cb213f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Apr 2024 13:16:49 -0400 Subject: [PATCH 2/3] Provide an option to output array --- src/transform/simplechains.jl | 41 ++++++++++++++++++++------- test/transform/simple_chains_tests.jl | 11 +++++++ 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/transform/simplechains.jl b/src/transform/simplechains.jl index 3ba625d3c..28dd7f746 100644 --- a/src/transform/simplechains.jl +++ b/src/transform/simplechains.jl @@ -1,5 +1,5 @@ """ - ToSimpleChainsAdaptor() + ToSimpleChainsAdaptor(input_dims, convert_to_array::Bool=false) Adaptor for converting a Lux Model to SimpleChains. The returned model is still a Lux model, and satisfies the `AbstractExplicitLayer` interfacem but all internal calculations are @@ -19,6 +19,9 @@ performed using SimpleChains. - `input_dims`: Tuple of input dimensions excluding the batch dimension. These must be of `static` type as `SimpleChains` expects. + - `convert_to_array`: SimpleChains.jl by default outputs `StrideArraysCore.StrideArray`, + but this might not compose well with other packages. If `convert_to_array` is set to + `true`, the output will be converted to a regular `Array`. ## Example @@ -42,13 +45,14 @@ simple_chains_model(x, ps, st) """ struct ToSimpleChainsAdaptor{ID} <: AbstractFromLuxAdaptor input_dims::ID + convert_to_array::Bool - function ToSimpleChainsAdaptor(input_dims) + function ToSimpleChainsAdaptor(input_dims, convert_to_array::Bool=false) input_dims isa Number && (input_dims = (input_dims,)) if input_dims isa Tuple{Vararg{Integer}} throw(ArgumentError("`input_dims` must be a Tuple of `static` integers.")) end - return new{typeof(input_dims)}(input_dims) + return new{typeof(input_dims)}(input_dims, convert_to_array) end end @@ -62,7 +66,7 @@ function Adapt.adapt(to::ToSimpleChainsAdaptor, L::AbstractExplicitLayer) error("`ToSimpleChainsAdaptor` requires `SimpleChains.jl` to be loaded.") end sc_layer = __fix_input_dims_simplechain(__to_simplechains_adaptor(L), to.input_dims) - return SimpleChainsLayer(sc_layer) + return SimpleChainsLayer{to.convert_to_array}(sc_layer) end function __to_simplechains_adaptor end @@ -82,26 +86,43 @@ function Base.showerror(io::IO, e::SimpleChainsModelConversionError) end """ - SimpleChainsLayer(layer) + SimpleChainsLayer{ToArray}(layer) + SimpleChainsLayer(layer, ToArray::Bool=false) Wraps a `SimpleChains` layer into a `Lux` layer. All operations are performed using `SimpleChains` but the layer satisfies the `AbstractExplicitLayer` interface. +`ToArray` is a boolean flag that determines whether the output should be converted to a +regular `Array` or not. Default is `false`. + ## Arguments - `layer`: SimpleChains layer + +!!! note + + Using the 2nd constructor makes the generation of the model struct type unstable. """ -@concrete struct SimpleChainsLayer <: AbstractExplicitLayer - layer +struct SimpleChainsLayer{ToArray, L} <: AbstractExplicitLayer + layer::L end -initialstates(rng::AbstractRNG, layer::SimpleChainsLayer) = (;) +@inline function SimpleChainsLayer{ToArray}(layer) where {ToArray} + return SimpleChainsLayer{ToArray, typeof(layer)}(layer) +end +@inline SimpleChainsLayer(layer, ToArray::Bool=false) = SimpleChainsLayer{ToArray}(layer) -function (sc::SimpleChainsLayer)(x, ps, st) +@inline initialstates(::AbstractRNG, ::SimpleChainsLayer) = (;) + +@inline function (sc::SimpleChainsLayer{false})(x, ps, st) return __apply_simple_chain(sc.layer, x, ps.params, get_device(x)), st end -__apply_simple_chain(layer, x, ps, ::LuxCPUDevice) = layer(x, ps) +@inline function (sc::SimpleChainsLayer{true})(x, ps, st) + return convert(Array, __apply_simple_chain(sc.layer, x, ps.params, get_device(x))), st +end + +@inline __apply_simple_chain(layer, x, ps, ::LuxCPUDevice) = layer(x, ps) function __apply_simple_chain(layer, x, ps, dev) throw(ArgumentError(lazy"`SimpleChains.jl` only supports CPU operations. Current device detected as $(dev).")) diff --git a/test/transform/simple_chains_tests.jl b/test/transform/simple_chains_tests.jl index ac2a7af21..e8aa133c0 100644 --- a/test/transform/simple_chains_tests.jl +++ b/test/transform/simple_chains_tests.jl @@ -33,6 +33,17 @@ @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) + @testset "Array Output" begin + adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)), true) + simple_chains_model = adaptor(lux_model) + + ps, st = Lux.setup(Random.default_rng(), simple_chains_model) + x = randn(Float32, 28, 28, 1, 1) + + @test size(first(simple_chains_model(x, ps, st))) == (10, 1) + @test first(simple_chains_model(x, ps, st)) isa Array + end + lux_model = Chain( FlattenLayer(3), Dense(784 => 20, tanh), Dropout(0.5), Dense(20 => 10)) From acb3b62a749a3479c8cae2df424af3fcee8154e7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Apr 2024 14:21:56 -0400 Subject: [PATCH 3/3] Add rules for ReverseDiff and Tracker --- docs/src/manual/debugging.md | 2 +- ext/LuxReverseDiffExt.jl | 12 ++++++++--- ext/LuxTrackerExt.jl | 29 ++++++++++++++++++++++++--- src/transform/simplechains.jl | 11 ++++++---- test/transform/simple_chains_tests.jl | 26 +++++++++++++++++++----- 5 files changed, 64 insertions(+), 16 deletions(-) diff --git a/docs/src/manual/debugging.md b/docs/src/manual/debugging.md index 6ab3ca387..8b28e7edf 100644 --- a/docs/src/manual/debugging.md +++ b/docs/src/manual/debugging.md @@ -6,7 +6,7 @@ useful tools that ship with Lux, that can help you debug your models. !!! tip "TL;DR" - Simply wrap your model with `Lux.Experimental.@debug`!! + Simply wrap your model with `Lux.Experimental.@debug_mode`!! !!! warning "Don't Forget" diff --git a/ext/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt.jl index bc52c0e5a..19f081a72 100644 --- a/ext/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt.jl @@ -3,8 +3,8 @@ module LuxReverseDiffExt using ADTypes: AutoReverseDiff using ArrayInterface: ArrayInterface using Functors: fmap -using Lux: Lux -using ReverseDiff: ReverseDiff +using Lux: Lux, LuxCPUDevice +using ReverseDiff: ReverseDiff, TrackedArray, @grad_from_chainrules using Setfield: @set! function Lux.Experimental.compute_gradients(::AutoReverseDiff, objective_function::F, data, @@ -33,6 +33,12 @@ function Lux.apply( end ## Prevent an infinite loop -Lux.apply(m::Lux.AbstractExplicitLayer, x::ReverseDiff.TrackedArray, ps, st) = m(x, ps, st) +Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) + +# Handle SimpleChains +@grad_from_chainrules Lux.__apply_simple_chain(layer, x::TrackedArray, ps, ::LuxCPUDevice) +@grad_from_chainrules Lux.__apply_simple_chain(layer, x, ps::TrackedArray, ::LuxCPUDevice) +@grad_from_chainrules Lux.__apply_simple_chain( + layer, x::TrackedArray, ps::TrackedArray, ::LuxCPUDevice) end diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl index 960169ddb..4e1d65c17 100644 --- a/ext/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt.jl @@ -2,10 +2,14 @@ module LuxTrackerExt using ADTypes: AutoTracker using ArrayInterface: ArrayInterface +using ChainRulesCore: ChainRulesCore +using FastClosures: @closure using Functors: fmap -using Lux: Lux +using Lux: Lux, LuxCPUDevice using Setfield: @set! -using Tracker: Tracker +using Tracker: Tracker, TrackedArray + +const CRC = ChainRulesCore # Type Piracy: Need to upstream Tracker.param(nt::NamedTuple) = fmap(Tracker.param, nt) @@ -20,7 +24,7 @@ Tracker.data(nt::NamedTuple) = fmap(Tracker.data, nt) Tracker.data(t::Tuple) = map(Tracker.data, t) # Weight Norm Patch -@inline Lux._norm(x::Tracker.TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims)) +@inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims)) # multigate chain rules @inline Lux._gate(x::Tracker.TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)] @@ -49,4 +53,23 @@ function Lux.apply( return Lux.apply(m, ArrayInterface.aos_to_soa(x), ps, st) end +# Handle SimpleChains +for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) + T1 === :AbstractArray && T2 === :AbstractArray && continue + + @eval function Lux.__apply_simple_chain(layer, x::$(T1), ps::$(T2), dev::LuxCPUDevice) + return Tracker.track(Lux.__apply_simple_chain, layer, x, ps, dev) + end +end + +Tracker.@grad function Lux.__apply_simple_chain(layer, x, ps, ::LuxCPUDevice) + y, pb_f = CRC.rrule(layer, Tracker.data(x), Tracker.data(ps)) + __∇apply_simple_chain = @closure Δ -> begin + _, ∂x, ∂ps = pb_f(convert(Array, Δ)) + return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing)) + end + # Tracker is not great at handling arbitrary types, so we convert to Array + return Array(y), __∇apply_simple_chain +end + end diff --git a/src/transform/simplechains.jl b/src/transform/simplechains.jl index 28dd7f746..b27a4f77c 100644 --- a/src/transform/simplechains.jl +++ b/src/transform/simplechains.jl @@ -102,6 +102,10 @@ regular `Array` or not. Default is `false`. !!! note Using the 2nd constructor makes the generation of the model struct type unstable. + +!!! note + + If using `Tracker.jl`, the output will always be a regular `Array`. """ struct SimpleChainsLayer{ToArray, L} <: AbstractExplicitLayer layer::L @@ -129,10 +133,9 @@ function __apply_simple_chain(layer, x, ps, dev) end # Workaround for SimpleChains not being able to handle some input types -function CRC.rrule(cfg::CRC.RuleConfig{>:HasReverseMode}, - ::typeof(__apply_simple_chain), layer, x, ps, dev) - res, pb = CRC.rrule_via_ad(cfg, layer, x, ps) - function __∇apply_simple_chain(Δ) +function CRC.rrule(::typeof(__apply_simple_chain), layer, x, ps, ::LuxCPUDevice) + res, pb = CRC.rrule(layer, x, ps) + __∇apply_simple_chain = @closure Δ -> begin # Safety measure to prevent errors from weird Array types that SimpleChains doesn't # support ∂layer, ∂x, ∂ps = pb(convert(Array, Δ)) diff --git a/test/transform/simple_chains_tests.jl b/test/transform/simple_chains_tests.jl index e8aa133c0..7f0f7366d 100644 --- a/test/transform/simple_chains_tests.jl +++ b/test/transform/simple_chains_tests.jl @@ -22,17 +22,25 @@ x = randn(Float32, 28, 28, 1, 1) @test size(first(simple_chains_model(x, ps, st))) == (10, 1) - gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps) + __f = (x, p) -> sum(first(simple_chains_model(x, p, st))) + + gs = Zygote.gradient(__f, x, ps) @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 + x = randn(Float32, 28, 28, 1, 15) @test size(first(simple_chains_model(x, ps, st))) == (10, 15) - gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps) + __f = (x, p) -> sum(first(simple_chains_model(x, p, st))) + + gs = Zygote.gradient(__f, x, ps) @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 + @testset "Array Output" begin adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)), true) simple_chains_model = adaptor(lux_model) @@ -56,14 +64,18 @@ x = randn(Float32, 28, 28, 1, 1) @test size(first(simple_chains_model(x, ps, st))) == (10, 1) - gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps) + __f = (x, p) -> sum(first(simple_chains_model(x, p, st))) + + gs = Zygote.gradient(__f, x, ps) @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) x = randn(Float32, 28, 28, 1, 15) @test size(first(simple_chains_model(x, ps, st))) == (10, 15) - gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps) + __f = (x, p) -> sum(first(simple_chains_model(x, p, st))) + + gs = Zygote.gradient(__f, x, ps) @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) @@ -80,9 +92,13 @@ x = randn(Float32, 10, 3) @test size(first(simple_chains_model(x, ps, st))) == (5, 3) - gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps) + __f = (x, p) -> sum(first(simple_chains_model(x, p, st))) + + gs = Zygote.gradient(__f, x, ps) @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) + + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 end end