diff --git a/Project.toml b/Project.toml index 28d8f5b4..62c8e76e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt and contributors"] -version = "0.1.1" +version = "0.1.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/test_utils.jl b/src/test_utils.jl index fe942a30..8541377f 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -550,6 +550,9 @@ function test_tangent_consistency(rng::AbstractRNG, p::P; interface_only=false) test_equality_comparison(p) test_equality_comparison(t) + # Check that zero_tangent isn't obviously non-deterministic. + @test has_equal_data(z, Tapir.zero_tangent(p)) + # Check that ismutabletype(P) => ismutabletype(T) if ismutabletype(P) && !(T == NoTangent) @test ismutabletype(T) @@ -774,99 +777,28 @@ necessary but not sufficient conditions for the correctness of your code. function test_tangent(rng::AbstractRNG, p::P, z_target::T, x::T, y::T) where {P, T} @nospecialize rng p z_target x y - # This basic functionality must run in order to be able to check everything else. - @test tangent_type(P) isa Type - @test tangent_type(P) == T - @test zero_tangent(p) isa T - @test randn_tangent(rng, p) isa T - test_equality_comparison(p) - test_equality_comparison(x) - - # Verify that interface `tangent_type` runs. - Tt = tangent_type(P) - t = randn_tangent(rng, p) - z = zero_tangent(p) - - # Check that user-provided tangents have the same type as `tangent_type` expects. - @test T == Tt + # Check the interface. + test_tangent_consistency(rng, p; interface_only=false) - # Check that ismutabletype(P) => ismutabletype(T) - if ismutabletype(P) && !(Tt == NoTangent) - @test ismutabletype(Tt) - end - - # Check that tangents are of the correct type. - @test Tt == _typeof(t) - @test Tt == _typeof(z) - - # Check that zero_tangent is deterministic. - @test has_equal_data(z, Tapir.zero_tangent(p)) + # Is the tangent_type of `P` what we expected? + @test tangent_type(P) == T # Check that zero_tangent infers. - @test has_equal_data(z, @inferred Tapir.zero_tangent(p)) + @inferred Tapir.zero_tangent(p) # Verify that the zero tangent is zero via its action. - zc = deepcopy(z) - tc = deepcopy(t) - @test has_equal_data(@inferred(increment!!(zc, zc)), zc) - @test has_equal_data(increment!!(zc, tc), tc) - @test has_equal_data(increment!!(tc, zc), tc) - - if ismutabletype(P) - @test increment!!(zc, zc) === zc - @test increment!!(tc, zc) === tc - @test increment!!(zc, tc) === zc - @test increment!!(tc, tc) === tc - end + z = zero_tangent(p) + t = randn_tangent(rng, p) + @test has_equal_data(@inferred(increment!!(z, z)), z) + @test has_equal_data(increment!!(z, t), t) + @test has_equal_data(increment!!(t, z), t) + # Verify that adding together `x` and `y` gives the value the user expected. z_pred = increment!!(x, y) @test has_equal_data(z_pred, z_target) if ismutabletype(P) @test z_pred === x end - - # If t isn't the zero element, then adding it to itself must change its value. - if !has_equal_data(t, z) - if !ismutabletype(P) - tc′ = increment!!(tc, tc) - @test tc === tc′ || !has_equal_data(tc′, tc) - end - end - - # Adding things preserves types. - @test increment!!(zc, zc) isa Tt - @test increment!!(zc, tc) isa Tt - @test increment!!(tc, zc) isa Tt - - # Setting to zero equals zero. - @test has_equal_data(set_to_zero!!(tc), z) - if ismutabletype(P) - @test set_to_zero!!(tc) === tc - end - - z = zero_tangent(p) - r = randn_tangent(rng, p) - - # Verify that operations required for finite difference testing to run, and produce the - # correct output type. - @test _add_to_primal(p, t) isa P - @test _diff(p, p) isa T - @test _dot(t, t) isa Float64 - @test _scale(11.0, t) isa T - @test populate_address_map(p, t) isa AddressMap - - # Run some basic numerical sanity checks on the output the functions required for finite - # difference testing. These are necessary but insufficient conditions. - @test has_equal_data(_add_to_primal(p, z), p) - if !has_equal_data(z, r) - @test !has_equal_data(_add_to_primal(p, r), p) - end - @test has_equal_data(_diff(p, p), zero_tangent(p)) - @test _dot(t, t) >= 0.0 - @test _dot(t, zero_tangent(p)) == 0.0 - @test _dot(t, increment!!(deepcopy(t), t)) ≈ 2 * _dot(t, t) - @test has_equal_data(_scale(1.0, t), t) - @test has_equal_data(_scale(2.0, t), increment!!(deepcopy(t), t)) end function test_equality_comparison(x)