From 1ed7b400e0013cee5ba79e21226ea0469737f6c8 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 2 Sep 2024 22:33:52 +0100 Subject: [PATCH] NoPullback for hash (#240) * Add failing tests * Make tests pass * Bump patch * Tweak CoDual _copy method * Add failing test * Fix rdata_type method * Tidy up + fix fdata_type --- Project.toml | 2 +- src/codual.jl | 2 +- src/fwds_rvs_data.jl | 12 ++++++++---- src/rrules/foreigncall.jl | 11 ++++++----- test/fwds_rvs_data.jl | 15 +++++++++++++++ 5 files changed, 31 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index d3590d17c..6d64d6deb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.43" +version = "0.2.44" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/codual.jl b/src/codual.jl index ec167d1a0..dac1b2c1b 100644 --- a/src/codual.jl +++ b/src/codual.jl @@ -15,7 +15,7 @@ end primal(x::CoDual) = x.x tangent(x::CoDual) = x.dx Base.copy(x::CoDual) = CoDual(copy(primal(x)), copy(tangent(x))) -_copy(x::P) where {P<:CoDual} = P(_copy(x.x), _copy(x.dx)) +_copy(x::P) where {P<:CoDual} = x """ zero_codual(x) diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl index 1c4adf4f3..6f7a382b7 100644 --- a/src/fwds_rvs_data.jl +++ b/src/fwds_rvs_data.jl @@ -200,8 +200,10 @@ fdata_type(::Type{T}) where {T<:Ptr} = T isa(P, Union) && return Union{fdata_type(P.a), fdata_type(P.b)} isempty(P.parameters) && return NoFData isa(last(P.parameters), Core.TypeofVararg) && return Any - all(p -> fdata_type(p) == NoFData, P.parameters) && return NoFData - return Tuple{map(fdata_type, fieldtypes(P))...} + nofdata_tt = Tuple{Vararg{NoFData, length(P.parameters)}} + fdata_tt = Tuple{map(fdata_type, fieldtypes(P))...} + fdata_tt <: nofdata_tt && return NoFData + return nofdata_tt <: fdata_tt ? Union{NoFData, fdata_tt} : fdata_tt end @generated function fdata_type(::Type{NamedTuple{names, T}}) where {names, T<:Tuple} @@ -444,8 +446,10 @@ rdata_type(::Type{<:Ptr}) = NoRData isa(P, Union) && return Union{rdata_type(P.a), rdata_type(P.b)} isempty(P.parameters) && return NoRData isa(last(P.parameters), Core.TypeofVararg) && return Any - all(p -> rdata_type(p) == NoRData, P.parameters) && return NoRData - return Tuple{map(rdata_type, fieldtypes(P))...} + nordata_tt = Tuple{Vararg{NoRData, length(P.parameters)}} + rdata_tt = Tuple{map(rdata_type, fieldtypes(P))...} + rdata_tt <: nordata_tt && return NoRData + return nordata_tt <: rdata_tt ? Union{NoRData, rdata_tt} : rdata_tt end function rdata_type(::Type{NamedTuple{names, T}}) where {names, T<:Tuple} diff --git a/src/rrules/foreigncall.jl b/src/rrules/foreigncall.jl index 575132d2c..337225b4d 100644 --- a/src/rrules/foreigncall.jl +++ b/src/rrules/foreigncall.jl @@ -489,13 +489,12 @@ function rrule!!(f::CoDual{<:Type{UnionAll}}, x::CoDual{<:TypeVar}, y::CoDual{<: return zero_fcodual(UnionAll(primal(x), primal(y))), NoPullback(f, x, y) end -@is_primitive MinimalCtx Tuple{typeof(hash), Union{String, SubString{String}}, UInt} -function rrule!!( - f::CoDual{typeof(hash)}, s::CoDual{P}, h::CoDual{UInt} -) where {P<:Union{String, SubString{String}}} - return zero_fcodual(hash(primal(s), primal(h))), NoPullback(f, s, h) +@is_primitive MinimalCtx Tuple{typeof(hash), Vararg} +function rrule!!(f::CoDual{typeof(hash)}, x::Vararg{CoDual, N}) where {N} + return zero_fcodual(hash(map(primal, x)...)), NoPullback(f, x...) end + function rrule!!( ::CoDual{typeof(_foreigncall_)}, ::CoDual{Val{:jl_string_ptr}}, args::Vararg{CoDual, N} ) where {N} @@ -613,6 +612,8 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) (false, :stability, nothing, deepcopy, (a=5.0, b=randn(5))), (false, :none, nothing, UnionAll, TypeVar(:a), Real), (false, :none, nothing, hash, "5", UInt(3)), + (false, :none, nothing, hash, Float64, UInt(5)), + (false, :none, nothing, hash, Float64), ( true, :none, nothing, _foreigncall_, diff --git a/test/fwds_rvs_data.jl b/test/fwds_rvs_data.jl index 2434b8e06..5a5e90bfa 100644 --- a/test/fwds_rvs_data.jl +++ b/test/fwds_rvs_data.jl @@ -3,6 +3,21 @@ module FwdsRvsDataTestResources end @testset "fwds_rvs_data" begin + @testset "fdata_type / rdata_type($P)" for (P, F, R) in Any[ + ( + Tuple{Any, Vector{Float64}}, + Tuple{Any, Vector{Float64}}, + Union{NoRData, Tuple{Any, NoRData}}, + ), + ( + Tuple{Any, Float64}, + Union{NoFData, Tuple{Any, NoFData}}, + Tuple{Any, Float64}, + ), + ] + @test fdata_type(tangent_type(P)) == F + @test rdata_type(tangent_type(P)) == R + end @testset "$(typeof(p))" for (_, p, _...) in Tapir.tangent_test_cases() TestUtils.test_fwds_rvs_data(Xoshiro(123456), p) end