From a896e0d5a64593ce2cfe9ab1487597c25385358c Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 12 Aug 2024 14:31:43 +0100 Subject: [PATCH] Handle undefined type parameters in fcodual_type and implement rule for fpext (#225) * Fix undefined type parameter problem * Add support for fpext * Bump patch version * Remove redundant test case --- Project.toml | 2 +- src/codual.jl | 8 ++++++-- src/rrules/builtins.jl | 12 +++++++++--- test/codual.jl | 8 ++++++++ 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 17b34597..88e30518 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.33" +version = "0.2.34" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/codual.jl b/src/codual.jl index f785be16..a9b52a9d 100644 --- a/src/codual.jl +++ b/src/codual.jl @@ -35,7 +35,9 @@ function codual_type(::Type{P}) where {P} return isconcretetype(P) ? CoDual{P, tangent_type(P)} : CoDual end -codual_type(::Type{Type{P}}) where {P} = CoDual{Type{P}, NoTangent} +function codual_type(p::Type{Type{P}}) where {P} + return @isdefined(P) ? CoDual{Type{P}, NoTangent} : CoDual{_typeof(p), NoTangent} +end struct NoPullback{R<:Tuple} r::R @@ -86,6 +88,8 @@ function fcodual_type(::Type{P}) where {P} return isconcretetype(P) ? CoDual{P, fdata_type(tangent_type(P))} : CoDual end -fcodual_type(::Type{Type{P}}) where {P} = CoDual{Type{P}, NoFData} +function fcodual_type(p::Type{Type{P}}) where {P} + return @isdefined(P) ? CoDual{Type{P}, NoFData} : CoDual{_typeof(p), NoFData} +end zero_rdata(x::CoDual) = zero_rdata(primal(x)) diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index b72267b4..3bcd9985 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -241,7 +241,13 @@ function rrule!!(::CoDual{typeof(fma_float)}, x, y, z) return CoDual(fma_float(_x, _y, primal(z)), NoFData()), fma_float_pullback!! end -# fpext -- maybe interesting +@intrinsic fpext +function rrule!!( + ::CoDual{typeof(fpext)}, ::CoDual{Type{Pext}}, x::CoDual{P} +) where {Pext<:IEEEFloat, P<:IEEEFloat} + fpext_adjoint!!(dy::Pext) = NoRData(), NoRData(), fptrunc(P, dy) + return zero_fcodual(fpext(Pext, primal(x))), fpext_adjoint!! +end @inactive_intrinsic fpiseq @inactive_intrinsic fptosi @@ -251,7 +257,7 @@ end function rrule!!( ::CoDual{typeof(fptrunc)}, ::CoDual{Type{Ptrunc}}, x::CoDual{P} ) where {Ptrunc<:IEEEFloat, P<:IEEEFloat} - fptrunc_adjoint!!(dy) = NoRData(), NoRData(), convert(P, dy) + fptrunc_adjoint!!(dy::Ptrunc) = NoRData(), NoRData(), convert(P, dy) return zero_fcodual(fptrunc(Ptrunc, primal(x))), fptrunc_adjoint!! end @@ -823,7 +829,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (false, :stability, nothing, IntrinsicsWrappers.flipsign_int, 4, -3), (false, :stability, nothing, IntrinsicsWrappers.floor_llvm, 4.1), (false, :stability, nothing, IntrinsicsWrappers.fma_float, 5.0, 4.0, 3.0), - # fpext -- NEEDS IMPLEMENTING AND TESTING + (true, :stability_and_allocs, nothing, IntrinsicsWrappers.fpext, Float64, 5f0), (false, :stability, nothing, IntrinsicsWrappers.fpiseq, 4.1, 4.0), (false, :stability, nothing, IntrinsicsWrappers.fptosi, UInt32, 4.1), (false, :stability, nothing, IntrinsicsWrappers.fptoui, Int32, 4.1), diff --git a/test/codual.jl b/test/codual.jl index 0f3508fe..45671996 100644 --- a/test/codual.jl +++ b/test/codual.jl @@ -7,6 +7,14 @@ @test codual_type(Real) == CoDual @test codual_type(Any) == CoDual @test codual_type(Type{UnitRange{Int}}) == CoDual{Type{UnitRange{Int}}, NoTangent} + @test ==( + codual_type(Type{Tuple{T}} where {T}), + CoDual{Type{Type{Tuple{T}} where {T}}, NoTangent}, + ) + @test ==( + Tapir.fcodual_type(Type{Tuple{T}} where {T}), + CoDual{Type{Type{Tuple{T}} where {T}}, NoFData}, + ) @test(==( codual_type(Union{Float64, Int}), Union{CoDual{Float64, Float64}, CoDual{Int, NoTangent}},