Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Disable more unsafe casting #162

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.15"
version = "0.2.16"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
36 changes: 26 additions & 10 deletions src/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ end
# atomic_pointerswap

@intrinsic bitcast
function rrule!!(f::CoDual{typeof(bitcast)}, t::CoDual{Type{T}}, x) where {T}
function rrule!!(f::CoDual{typeof(bitcast)}, t::CoDual{Type{T}}, x::CoDual{V}) where {T, V}
if T <: IEEEFloat
msg = "It is not permissible to bitcast to a differentiable type during AD, as " *
"this risks dropping tangents, and therefore risks silently giving the wrong " *
"answer. If this call to bitcast appears as part of the implementation of a " *
"differentiable function, you should write a rule for this function, or modify " *
msg = "It is not permissible to bitcast to a differentiable type during " *
"AD, as this risks dropping tangents, and therefore risks silently giving the " *
" wrong answer. If this call to bitcast appears as part of the implementation of " *
"a differentiable function, you should write a rule for this function, or modify " *
"its implementation to avoid the bitcast."
throw(ArgumentError(msg))
end
Expand Down Expand Up @@ -193,8 +193,26 @@ end
# fpext -- maybe interesting

@inactive_intrinsic fpiseq
@inactive_intrinsic fptosi
@inactive_intrinsic fptoui

@intrinsic fptosi
function rrule!!(::CoDual{typeof(fptosi)}, ::CoDual...)
msg = "It is not permissible to cast a float to a signed integer in " *
"AD, as this risks dropping tangents, and therefore risks silently giving the " *
" wrong answer. If this call to Core.Intrinsics.fptosi appears as part of the " *
"implementation of a differentiable function, you should write a rule for this " *
"function, or modify its implementation to avoid this call."
throw(ArgumentError(msg))
end

@intrinsic fptoui
function rrule!!(::CoDual{typeof(fptoui)}, ::CoDual...)
msg = "It is not permissible to cast a float to an unsigned integer in " *
"AD, as this risks dropping tangents, and therefore risks silently giving the " *
" wrong answer. If this call to Core.Intrinsics.fptoui appears as part of the " *
"implementation of a differentiable function, you should write a rule for this " *
"function, or modify its implementation to avoid this call."
throw(ArgumentError(msg))
end

# fptrunc -- maybe interesting

Expand Down Expand Up @@ -739,7 +757,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins})
# atomic_pointerreplace -- NEEDS IMPLEMENTING AND TESTING
# atomic_pointerset -- NEEDS IMPLEMENTING AND TESTING
# atomic_pointerswap -- NEEDS IMPLEMENTING AND TESTING
(false, :stability, nothing, IntrinsicsWrappers.bitcast, Int64, 5.0),
(false, :stability, nothing, IntrinsicsWrappers.bitcast, UInt64, 5),
(false, :stability, nothing, IntrinsicsWrappers.bswap_int, 5),
(false, :stability, nothing, IntrinsicsWrappers.ceil_llvm, 4.1),
(
Expand Down Expand Up @@ -778,8 +796,6 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins})
(false, :stability, nothing, IntrinsicsWrappers.fma_float, 5.0, 4.0, 3.0),
# fpext -- NEEDS IMPLEMENTING AND TESTING
(false, :stability, nothing, IntrinsicsWrappers.fpiseq, 4.1, 4.0),
(false, :stability, nothing, IntrinsicsWrappers.fptosi, UInt32, 4.1),
(false, :stability, nothing, IntrinsicsWrappers.fptoui, Int32, 4.1),
# fptrunc -- maybe interesting
(true, :stability, nothing, IntrinsicsWrappers.have_fma, Float64),
(false, :stability, nothing, IntrinsicsWrappers.le_float, 4.1, 4.0),
Expand Down
18 changes: 17 additions & 1 deletion test/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,26 @@

TestUtils.run_rrule!!_test_cases(StableRNG, Val(:builtins))

@testset "Disable bitcast to differentiable type" begin
@testset "Disable casting to / from floats inside AD" begin
@test_throws(
ArgumentError,
rrule!!(zero_fcodual(bitcast), zero_fcodual(Float64), zero_fcodual(5))
)
@test_throws(
ArgumentError,
rrule!!(
zero_fcodual(Tapir.IntrinsicsWrappers.fptosi),
zero_fcodual(Int),
zero_fcodual(1.0),
),
)
@test_throws(
ArgumentError,
rrule!!(
zero_fcodual(Tapir.IntrinsicsWrappers.fptoui),
zero_fcodual(UInt),
zero_fcodual(1.0),
),
)
end
end
Loading