From 66b369c8a0e51dc10224dc61ad5c73ba8c4265d8 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 28 May 2024 09:46:24 +0100 Subject: [PATCH 1/3] Basic Lux integration --- .github/workflows/CI.yml | 1 + Project.toml | 4 ++- test/integration_testing/lux.jl | 54 +++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 4 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 test/integration_testing/lux.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 38c946c7..f1acbe4e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -31,6 +31,7 @@ jobs: - 'integration_testing/array' - 'integration_testing/turing' - 'integration_testing/temporalgps' + - 'integration_testing/lux' - 'interface' steps: - uses: actions/checkout@v4 diff --git a/Project.toml b/Project.toml index 8185e713..9ddb2332 100644 --- a/Project.toml +++ b/Project.toml @@ -41,6 +41,7 @@ FillArrays = "1" Graphs = "1" JET = "0.9" LogDensityProblemsAD = "1" +Lux = "0.5" PDMats = "0.11" Setfield = "1" SpecialFunctions = "2" @@ -59,6 +60,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -68,4 +70,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "FillArrays", "KernelFunctions", "LogDensityProblemsAD", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] +test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "FillArrays", "KernelFunctions", 
"LogDensityProblemsAD", "Lux", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] diff --git a/test/integration_testing/lux.jl b/test/integration_testing/lux.jl new file mode 100644 index 00000000..f78fa295 --- /dev/null +++ b/test/integration_testing/lux.jl @@ -0,0 +1,54 @@ +using Lux + +@testset "lux" begin + interp = Tapir.TapirInterpreter() + @testset "$(typeof(f))" for (f, x_f32) in Any[ + (Dense(2, 4), randn(Float32, 2, 3)), + (Dense(2, 4, gelu), randn(Float32, 2, 3)), + (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)), + (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)), + (Scale(2), randn(Float32, 2, 3)), + (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), + (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), + (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), + (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), + (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), + (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), + (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), + (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), + (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), + (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), + (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), + (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), + (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), + (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), + 
(StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), + (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), + (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)), + (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + ] + @info "$(_typeof((f, x_f32...)))" + ps, st = f64(Lux.setup(rng, f)) + x = f64(x_f32) + TestUtils.test_derived_rule( + Xoshiro(123456), f, x, ps, st; + safety_on=false, + interp, + perf_flag=:none, + interface_only=false, + is_primitive=false, + ) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 79791e1d..e77af286 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,8 @@ include("front_matter.jl") include(joinpath("integration_testing", "turing.jl")) elseif test_group == "integration_testing/temporalgps" include(joinpath("integration_testing", "temporalgps.jl")) + elseif test_group == "integration_testing/lux" + include(joinpath("integration_testing", "lux.jl")) elseif test_group == "interface" include("interface.jl") elseif test_group == "gpu" From eff940c2e1038a547d6b59ae4af74dc1db3d2c6f Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 28 May 2024 
17:48:18 +0100 Subject: [PATCH 2/3] Add rules for get_num_threads and set_num_threads --- src/rrules/blas.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 4014b5d0..9f57b14e 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -20,6 +20,29 @@ function tri!(A, u::Char, d::Char) end +# +# Utility +# + +@is_primitive MinimalCtx Tuple{typeof(BLAS.get_num_threads)} +function rrule!!(f::CoDual{typeof(BLAS.get_num_threads)}) + return zero_fcodual(BLAS.get_num_threads()), NoPullback(f) +end + +@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_get_num_threads)} +function rrule!!(f::CoDual{typeof(BLAS.lbt_get_num_threads)}) + return zero_fcodual(BLAS.lbt_get_num_threads()), NoPullback(f) +end + +@is_primitive MinimalCtx Tuple{typeof(BLAS.set_num_threads), Union{Integer, Nothing}} +function rrule!!(f::CoDual{typeof(BLAS.set_num_threads)}, x::CoDual) + return zero_fcodual(BLAS.set_num_threads(primal(x))), NoPullback(f, x) +end + +@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads), Any} +function rrule!!(f::CoDual{typeof(BLAS.lbt_set_num_threads)}, x::CoDual) + return zero_fcodual(BLAS.lbt_set_num_threads(primal(x))), NoPullback(f, x) +end # # LEVEL 1 @@ -604,6 +627,12 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) betas = [0.0, 0.33] test_cases = vcat( + # utility + (false, :stability, nothing, BLAS.get_num_threads), + (false, :stability, nothing, BLAS.lbt_get_num_threads), + (false, :stability, nothing, BLAS.set_num_threads, 1), + (false, :stability, nothing, BLAS.lbt_set_num_threads, 1), + # gemm! 
vec(reduce( vcat, From 312aea855becc1e1d5153862b2918dcd0a8e2694 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 28 May 2024 17:48:37 +0100 Subject: [PATCH 3/3] Improve Lux testing notes --- test/integration_testing/lux.jl | 46 ++++++++++++++++----------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/test/integration_testing/lux.jl b/test/integration_testing/lux.jl index f78fa295..be90299f 100644 --- a/test/integration_testing/lux.jl +++ b/test/integration_testing/lux.jl @@ -8,31 +8,31 @@ using Lux (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)), (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)), (Scale(2), randn(Float32, 2, 3)), - (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), - (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), - (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), - (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), - (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), - (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), + (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule + (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule + (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule + (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule + (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. 
needs rule + (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), - (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), + (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), # missing intrinsic atomic_pointerref. Also might just need a rule (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), - (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), - (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), - (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), - (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), - (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), - (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), - (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), - (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), - (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), - (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), - (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), - (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), - (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), # uses a task, so has recurrence problem. 
needs rule + (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow + (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), # fpext getting used here somehow + (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), # fpext getting used here somehow + (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow + (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow + (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow + (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), # fpext getting used here somehow + (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), # fpext getting used here somehow + (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # stack overflow. 
Probably task again (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), - (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), - (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), # fpext again + (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), # fpext again (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)), @@ -40,7 +40,7 @@ using Lux (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), ] @info "$(_typeof((f, x_f32...)))" - ps, st = f64(Lux.setup(rng, f)) + ps, st = f64(Lux.setup(sr(123456), f)) x = f64(x_f32) TestUtils.test_derived_rule( Xoshiro(123456), f, x, ps, st;