diff --git a/Project.toml b/Project.toml index 919cb0c0..40b29d60 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.3" +version = "0.4.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/rrules/low_level_maths.jl b/src/rrules/low_level_maths.jl index 33b53c13..9569ccfe 100644 --- a/src/rrules/low_level_maths.jl +++ b/src/rrules/low_level_maths.jl @@ -5,7 +5,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) continue # Skip rules for methods not defined in the current scope end (f == :rem2pi || f == :ldexp) && continue # not designed for Float64s - (f in [:+, :*, :sin, :cos, :exp]) && continue # use other functionality to implement these + (f in [:+, :*, :sin, :cos, :exp, :-, :abs2, :inv, :abs, :/, :\]) && continue # use other functionality to implement these if arity == 1 dx = DiffRules.diffrule(M, f, :x) pb_name = Symbol("$(M).$(f)_pb!!") @@ -84,11 +84,19 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:low_level_mat end arity > 2 && return (f == :rem2pi || f == :ldexp || f == :(^)) && return - (f == :+ || f == :*) && return # use intrinsics instead + (f in [:+, :*, :sin, :cos, :exp, :-, :abs2, :inv, :abs, :/, :\]) && return # use other functionality to implement these f = @eval $M.$f push!(test_cases, (false, :stability, nothing, f, rand_inputs(rng, Float64, f, arity)...)) push!(test_cases, (true, :stability, nothing, f, rand_inputs(rng, Float32, f, arity)...)) end + + # test cases for additional rules written in this file. + push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.1)) + push!(test_cases, (true, :stability_and_allocs, nothing, sin, Float32(1.1))) + push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.1)) + push!(test_cases, (true, :stability_and_allocs, nothing, cos, Float32(1.1))) + push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.1)) + push!(test_cases, (true, :stability_and_allocs, nothing, exp, Float32(1.1))) memory = Any[] return test_cases, memory end diff --git a/test/front_matter.jl b/test/front_matter.jl index d738d2c5..9b32e431 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -61,7 +61,9 @@ using Mooncake: InvalidFDataException, InvalidRDataException, verify_fdata_value, - verify_rdata_value + verify_rdata_value, + is_primitive, + MinimalCtx using .TestUtils: test_rule, diff --git a/test/rrules/low_level_maths.jl b/test/rrules/low_level_maths.jl index 32c696d4..ce6834a4 100644 --- a/test/rrules/low_level_maths.jl +++ b/test/rrules/low_level_maths.jl @@ -1,3 +1,20 @@ @testset "low_level_maths" begin TestUtils.run_rrule!!_test_cases(StableRNG, Val(:low_level_maths)) + + # These are all examples of signatures which we do _not_ want to make primitives, + # because they are very shallow wrappers around lower-level primitives for which we + # already have rules. + @testset "$T, $C" for T in [Float16, Float32, Float64], C in [DefaultCtx, MinimalCtx] + @test !is_primitive(C, Tuple{typeof(+), T}) + @test !is_primitive(C, Tuple{typeof(-), T}) + @test !is_primitive(C, Tuple{typeof(abs2), T}) + @test !is_primitive(C, Tuple{typeof(inv), T}) + @test !is_primitive(C, Tuple{typeof(abs), T}) + + @test !is_primitive(C, Tuple{typeof(+), T, T}) + @test !is_primitive(C, Tuple{typeof(-), T, T}) + @test !is_primitive(C, Tuple{typeof(*), T, T}) + @test !is_primitive(C, Tuple{typeof(/), T, T}) + @test !is_primitive(C, Tuple{typeof(\), T, T}) + end end