diff --git a/test/runtests.jl b/test/runtests.jl index ce1e44ebd..7b49ec60c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,13 +44,8 @@ using Lux # Type Stability tests fail if run with DispatchDoctor enabled if "all" in LUX_TEST_GROUP || "core_layers" in LUX_TEST_GROUP - try - # Run in a separate process to load the updated preferences - run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) - --startup-file=no --code-coverage=user $(@__DIR__)/zygote_type_stability.jl`) - @test true - catch - @test false + @testset "Zygote Type Stability" begin + include("zygote_type_stability.jl") end end @@ -76,9 +71,11 @@ Lux.set_dispatch_doctor_preferences!(; luxcore="error", luxlib="error") @testset "Load Packages Tests" begin @test_throws ErrorException FromFluxAdaptor()(1) showerror(stdout, Lux.FluxModelConversionException("cannot convert")) + println() @test_throws ErrorException ToSimpleChainsAdaptor(nothing)(Dense(2 => 2)) showerror(stdout, Lux.SimpleChainsModelConversionException(Dense(2 => 2))) + println() @test_throws ErrorException vector_jacobian_product( x -> x, AutoZygote(), rand(2), rand(2)) diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 85abd3204..aba3646de 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -1,8 +1,5 @@ @testsetup module SharedTestSetup -using Enzyme -Enzyme.API.runtimeActivity!(true) - include("setup_modes.jl") import Reexport: @reexport diff --git a/test/zygote_type_stability.jl b/test/zygote_type_stability.jl index b8d0d22c3..1338ca229 100644 --- a/test/zygote_type_stability.jl +++ b/test/zygote_type_stability.jl @@ -75,8 +75,8 @@ include("setup_modes.jl") ps, st = Lux.setup(rng, model) |> dev x = input |> dev - @test @inferred(model(x, ps, st)) isa Any - @test @inferred(loss_function(model, x, ps, st)) isa Any + @test @inferred(model(x, ps, Lux.testmode(st))) isa Any + @test @inferred(loss_function(model, x, ps, Lux.testmode(st))) isa Number if mode == "amdgpu" && model isa Conv @test_broken @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa Any