diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 38c946c7..f1acbe4e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -31,6 +31,7 @@ jobs: - 'integration_testing/array' - 'integration_testing/turing' - 'integration_testing/temporalgps' + - 'integration_testing/lux' - 'interface' steps: - uses: actions/checkout@v4 diff --git a/Project.toml b/Project.toml index 8185e713..9ddb2332 100644 --- a/Project.toml +++ b/Project.toml @@ -41,6 +41,7 @@ FillArrays = "1" Graphs = "1" JET = "0.9" LogDensityProblemsAD = "1" +Lux = "0.5" PDMats = "0.11" Setfield = "1" SpecialFunctions = "2" @@ -59,6 +60,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -68,4 +70,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "FillArrays", "KernelFunctions", "LogDensityProblemsAD", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] +test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "FillArrays", "KernelFunctions", "LogDensityProblemsAD", "Lux", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 4014b5d0..9f57b14e 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -20,6 +20,29 @@ function tri!(A, u::Char, d::Char) end +# +# Utility +# + +@is_primitive MinimalCtx Tuple{typeof(BLAS.get_num_threads)} +function 
rrule!!(f::CoDual{typeof(BLAS.get_num_threads)}) + return zero_fcodual(BLAS.get_num_threads()), NoPullback(f) +end + +@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_get_num_threads)} +function rrule!!(f::CoDual{typeof(BLAS.lbt_get_num_threads)}) + return zero_fcodual(BLAS.lbt_get_num_threads()), NoPullback(f) +end + +@is_primitive MinimalCtx Tuple{typeof(BLAS.set_num_threads), Union{Integer, Nothing}} +function rrule!!(f::CoDual{typeof(BLAS.set_num_threads)}, x::CoDual) + return zero_fcodual(BLAS.set_num_threads(primal(x))), NoPullback(f, x) +end + +@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads), Any} +function rrule!!(f::CoDual{typeof(BLAS.lbt_set_num_threads)}, x::CoDual) + return zero_fcodual(BLAS.lbt_set_num_threads(primal(x))), NoPullback(f, x) +end # # LEVEL 1 @@ -604,6 +627,12 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) betas = [0.0, 0.33] test_cases = vcat( + # utility + (false, :stability, nothing, BLAS.get_num_threads), + (false, :stability, nothing, BLAS.lbt_get_num_threads), + (false, :stability, nothing, BLAS.set_num_threads, 1), + (false, :stability, nothing, BLAS.lbt_set_num_threads, 1), + # gemm! vec(reduce( vcat, diff --git a/test/integration_testing/lux.jl b/test/integration_testing/lux.jl new file mode 100644 index 00000000..be90299f --- /dev/null +++ b/test/integration_testing/lux.jl @@ -0,0 +1,54 @@ +using Lux + +@testset "lux" begin + interp = Tapir.TapirInterpreter() + @testset "$(typeof(f))" for (f, x_f32) in Any[ + (Dense(2, 4), randn(Float32, 2, 3)), + (Dense(2, 4, gelu), randn(Float32, 2, 3)), + (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)), + (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)), + (Scale(2), randn(Float32, 2, 3)), + (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule + (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. 
needs rule + (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule + (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule + (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule + (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule + (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), + (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), # missing intrinsic atomic_pointerref. Also might just need a rule + (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), + (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), # uses a task, so has recurrence problem. needs rule + (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow + (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), # fpext getting used here somehow + (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), # fpext getting used here somehow + (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow + (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow + (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow + (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), # fpext getting used here somehow + (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), # fpext getting used here somehow + (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 
2, 3)), # something about begin_optional -- this is an unrecognised expression + (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # stack overflow. Probably task again + (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), # fpext again + (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), # fpext again + (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)), + (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + ] + @info "$(_typeof((f, x_f32...)))" + ps, st = f64(Lux.setup(sr(123456), f)) + x = f64(x_f32) + TestUtils.test_derived_rule( + Xoshiro(123456), f, x, ps, st; + safety_on=false, + interp, + perf_flag=:none, + interface_only=false, + is_primitive=false, + ) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 79791e1d..e77af286 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,8 @@ include("front_matter.jl") include(joinpath("integration_testing", "turing.jl")) elseif test_group == "integration_testing/temporalgps" include(joinpath("integration_testing", "temporalgps.jl")) + elseif test_group == "integration_testing/lux" + include(joinpath("integration_testing", "lux.jl")) elseif test_group == "interface" include("interface.jl") elseif test_group == "gpu"