From 74b034d3ea64218c7f66771f9e0044bceaa33b6e Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sun, 15 Sep 2024 22:27:18 -0400 Subject: [PATCH] Fix tests --- test/chainrules.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 00fd1b0af..3d5fcb035 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -422,11 +422,11 @@ end end @testset "ChainRules translation" begin - @test Zygote.wrap_chainrules_input(nothing) == ChainRules.ZeroTangent() - @test Zygote.wrap_chainrules_input((nothing,)) == ChainRules.ZeroTangent() - @test Zygote.wrap_chainrules_input([nothing]) == ChainRules.ZeroTangent() - @test Zygote.wrap_chainrules_input(((1.0, 2.0), 3.0)) == ChainRules.Tangent{Any}(ChainRules.Tangent{Any}(1.0, 2.0), 3.0) - @test Zygote.wrap_chainrules_input((; a = 1.0, b = 2.0)) == ChainRules.Tangent{Any}(a = 1.0, b = 2.0) + @test Zygote.wrap_chainrules_input(nothing) == ZeroTangent() + @test Zygote.wrap_chainrules_input((nothing,)) == ZeroTangent() + @test Zygote.wrap_chainrules_input([nothing]) == ZeroTangent() + @test Zygote.wrap_chainrules_input(((1.0, 2.0), 3.0)) == Tangent{Any}(Tangent{Any}(1.0, 2.0), 3.0) + @test Zygote.wrap_chainrules_input((; a = 1.0, b = 2.0)) == Tangent{Any}(a = 1.0, b = 2.0) @test Zygote.wrap_chainrules_input(Ref(1)) == 1 @test Zygote.wrap_chainrules_input([2.0; 4.0]) == [2.0; 4.0] @test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]]