Skip to content

Commit

Permalink
Some very basic Float32 support (#227)
Browse files Browse the repository at this point in the history
* Some basics for float32

* Improve error message

* Bump patch version

* Tweak error

* Make CI not fail when codecov fails
  • Loading branch information
willtebbutt authored Aug 14, 2024
1 parent 905b958 commit 9d248cd
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
fail_ci_if_error: false
perf:
name: "Performance (${{ matrix.perf_group }})"
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.35"
version = "0.2.36"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
19 changes: 18 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ __verify_sig(rule::SafeRRule, fx) = __verify_sig(rule.rule, fx)
# check here.
__verify_sig(::typeof(rrule!!), fx::Tuple) = nothing

"""
    ValueAndGradientReturnTypeError(msg::String)

Exception thrown by `__value_and_gradient!!` when the primal computation returns a value
whose type is not a subtype of `IEEEFloat` — i.e. when a gradient is requested of a
function which is not scalar-valued. `msg` carries the human-readable explanation,
including the offending return type.
"""
struct ValueAndGradientReturnTypeError <: Exception
    msg::String
end

"""
__value_and_gradient!!(rule, f::CoDual, x::CoDual...)
Expand Down Expand Up @@ -63,7 +67,20 @@ Tapir.__value_and_gradient!!(
```
"""
# Compute the value of `fx` under `rule`, and the gradient w.r.t. each element of `fx`.
# The primal must return a scalar (`IEEEFloat`); otherwise a
# `ValueAndGradientReturnTypeError` is thrown. Returns `(y, grads)` where `grads` is a
# tuple with one entry per element of `fx`.
function __value_and_gradient!!(rule::R, fx::Vararg{CoDual, N}) where {R, N}
    # NOTE: removed the stale early `return __value_and_pullback!!(rule, 1.0, fx...)`,
    # which made everything below unreachable, and the `@assert y isa IEEEFloat` that
    # was redundant immediately after the guard below.

    # Convert arguments to their forwards-data representation and verify that `rule`
    # accepts this signature before running it.
    fx_fwds = tuple_map(to_fwds, fx)
    __verify_sig(rule, fx_fwds)

    # Run the forwards-pass.
    out, pb!! = rule(fx_fwds...)
    y = primal(out)

    # Gradients are only defined for scalar-valued primals.
    if !(y isa IEEEFloat)
        throw(ValueAndGradientReturnTypeError(
            "When calling __value_and_gradient!!, return value of primal must be a " *
            "subtype of IEEEFloat. Instead, found value of type $(typeof(y))."
        ))
    end
    @assert tangent(out) isa NoFData

    # Seed the reverse-pass with `one(y)` (unit cotangent at matching precision) and
    # assemble the gradient w.r.t. each input from its fdata and the returned rdata.
    return y, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(one(y)))
end

"""
Expand Down
68 changes: 35 additions & 33 deletions src/rrules/low_level_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,66 +5,67 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
continue # Skip rules for methods not defined in the current scope
end
(f == :rem2pi || f == :ldexp) && continue # not designed for Float64s
(f in [:+, :*, :sin, :cos]) && continue # use intrinsics instead
P = Float64
(f in [:+, :*, :sin, :cos]) && continue # use other functionality to implement these
if arity == 1
dx = DiffRules.diffrule(M, f, :x)
pb_name = Symbol("$(M).$(f)_pb!!")
@eval begin
@is_primitive MinimalCtx Tuple{typeof($M.$f), $P}
function rrule!!(::CoDual{typeof($M.$f)}, _x::CoDual{$P})
@is_primitive MinimalCtx Tuple{typeof($M.$f), P} where {P<:IEEEFloat}
function rrule!!(::CoDual{typeof($M.$f)}, _x::CoDual{P}) where {P<:IEEEFloat}
x = primal(_x) # needed for dx expression
$pb_name(ȳ) = NoRData(), ȳ * $dx
$pb_name(ȳ::P) = NoRData(), ȳ * $dx
return CoDual(($M.$f)(x), NoFData()), $pb_name
end
end
elseif arity == 2
da, db = DiffRules.diffrule(M, f, :a, :b)
pb_name = Symbol("$(M).$(f)_pb!!")
@eval begin
@is_primitive MinimalCtx Tuple{typeof($M.$f), $P, $P}
function rrule!!(::CoDual{typeof($M.$f)}, _a::CoDual{$P}, _b::CoDual{$P})
@is_primitive MinimalCtx Tuple{typeof($M.$f), P, P} where {P<:IEEEFloat}
function rrule!!(
::CoDual{typeof($M.$f)}, _a::CoDual{P}, _b::CoDual{P}
) where {P<:IEEEFloat}
a = primal(_a)
b = primal(_b)
$pb_name(ȳ) = NoRData(), ȳ * $da, ȳ * $db
$pb_name(ȳ::P) = NoRData(), ȳ * $da, ȳ * $db
return CoDual(($M.$f)(a, b), NoFData()), $pb_name
end
end
end
end

# NOTE: removed the superseded Float64-only `@is_primitive` / `rrule!!` signature pair
# that the diff left interleaved with the generic version below.
@is_primitive MinimalCtx Tuple{typeof(sin), <:IEEEFloat}

# rrule for `sin` at any IEEE float precision: d/dx sin(x) = cos(x).
function rrule!!(::CoDual{typeof(sin), NoFData}, x::CoDual{P, NoFData}) where {P<:IEEEFloat}
    # `sincos` computes both values in a single call; `c` is captured by the pullback.
    s, c = sincos(primal(x))
    sin_pullback!!(dy::P) = NoRData(), dy * c
    return CoDual(s, NoFData()), sin_pullback!!
end

# NOTE: removed the superseded Float64-only `@is_primitive` / `rrule!!` signature pair
# that the diff left interleaved with the generic version below.
@is_primitive MinimalCtx Tuple{typeof(cos), <:IEEEFloat}

# rrule for `cos` at any IEEE float precision: d/dx cos(x) = -sin(x).
function rrule!!(::CoDual{typeof(cos), NoFData}, x::CoDual{P, NoFData}) where {P<:IEEEFloat}
    # `sincos` computes both values in a single call; `s` is captured by the pullback.
    s, c = sincos(primal(x))
    cos_pullback!!(dy::P) = NoRData(), -dy * s
    return CoDual(c, NoFData()), cos_pullback!!
end

# Generate random test inputs lying inside each function's domain. Fallback: `arity`
# standard-normal Float64 draws, valid for functions defined on all of ℝ.
rand_inputs(rng, f, arity) = randn(rng, arity)
# acosh requires x ≥ 1: shift a uniform draw into (1, 2) with a small safety margin.
rand_inputs(rng, ::typeof(acosh), _) = (rand(rng) + 1 + 1e-3, )
# asech requires x ∈ (0, 1]: scale a uniform draw into [0, 0.9).
rand_inputs(rng, ::typeof(asech), _) = (rand(rng) * 0.9, )
# log requires x > 0: offset keeps the draw away from zero.
rand_inputs(rng, ::typeof(log), _) = (rand(rng) + 1e-3, )
# asin requires |x| ≤ 1: scale into [0, 0.9) to avoid the boundary.
rand_inputs(rng, ::typeof(asin), _) = (rand(rng) * 0.9, )
# asecd requires |x| ≥ 1: shift into [1, 2).
rand_inputs(rng, ::typeof(asecd), _) = (rand(rng) + 1, )
# log2 requires x > 0.
rand_inputs(rng, ::typeof(log2), _) = (rand(rng) + 1e-3, )
# log10 requires x > 0.
rand_inputs(rng, ::typeof(log10), _) = (rand(rng) + 1e-3, )
# acscd requires |x| ≥ 1: shift strictly above 1.
rand_inputs(rng, ::typeof(acscd), _) = (rand(rng) + 1 + 1e-3, )
# log1p requires x > -1; a draw in (1e-3, 1 + 1e-3) is comfortably inside.
rand_inputs(rng, ::typeof(log1p), _) = (rand(rng) + 1e-3, )
# acsc requires |x| ≥ 1: shift strictly above 1.
rand_inputs(rng, ::typeof(acsc), _) = (rand(rng) + 1 + 1e-3, )
# atanh requires |x| < 1: map a uniform draw into (-0.9, 0.9).
rand_inputs(rng, ::typeof(atanh), _) = (2 * 0.9 * rand(rng) - 0.9, )
# acoth requires |x| > 1: shift strictly above 1.
rand_inputs(rng, ::typeof(acoth), _) = (rand(rng) + 1 + 1e-3, )
# asind requires |x| ≤ 1: scale into [0, 0.9).
rand_inputs(rng, ::typeof(asind), _) = (0.9 * rand(rng), )
# asec requires |x| ≥ 1: shift strictly above 1.
rand_inputs(rng, ::typeof(asec), _) = (rand(rng) + 1.001, )
# acosd requires |x| ≤ 1: map into (-0.9, 0.9).
rand_inputs(rng, ::typeof(acosd), _) = (2 * 0.9 * rand(rng) - 0.9, )
# acos requires |x| ≤ 1: map into (-0.9, 0.9).
rand_inputs(rng, ::typeof(acos), _) = (2 * 0.9 * rand(rng) - 0.9, )
# sqrt requires x ≥ 0 (derivative undefined at 0): keep the draw strictly positive.
rand_inputs(rng, ::typeof(sqrt), _) = (rand(rng) + 1e-3, )
# Generate random test inputs of precision `P` lying inside each function's domain.
# Fallback: `arity` standard-normal draws of eltype `P`, valid on all of ℝ.
rand_inputs(rng, P::Type{<:IEEEFloat}, f, arity) = randn(rng, P, arity)
# BUGFIX: the special cases below previously called `rand(rng)` with untyped offsets,
# which always produces Float64 and silently ignored `P` (so e.g. requesting Float32
# inputs yielded Float64). Drawing `rand(rng, P)` and converting the offsets to `P`
# guarantees the returned values really have precision `P`. Domain shifts mirror the
# untyped 3-arg methods: e.g. acosh needs x ≥ 1, log needs x > 0, atanh needs |x| < 1.
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acosh), _) = (rand(rng, P) + 1 + P(1e-3), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asech), _) = (rand(rng, P) * P(0.9), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log), _) = (rand(rng, P) + P(1e-3), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asin), _) = (rand(rng, P) * P(0.9), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asecd), _) = (rand(rng, P) + 1, )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log2), _) = (rand(rng, P) + P(1e-3), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log10), _) = (rand(rng, P) + P(1e-3), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acscd), _) = (rand(rng, P) + 1 + P(1e-3), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(log1p), _) = (rand(rng, P) + P(1e-3), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acsc), _) = (rand(rng, P) + 1 + P(1e-3), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(atanh), _) = (2 * P(0.9) * rand(rng, P) - P(0.9), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acoth), _) = (rand(rng, P) + 1 + P(1e-3), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asind), _) = (P(0.9) * rand(rng, P), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(asec), _) = (rand(rng, P) + P(1.001), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acosd), _) = (2 * P(0.9) * rand(rng, P) - P(0.9), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(acos), _) = (2 * P(0.9) * rand(rng, P) - P(0.9), )
rand_inputs(rng, P::Type{<:IEEEFloat}, ::typeof(sqrt), _) = (rand(rng, P) + P(1e-3), )

function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:low_level_maths})
rng = Xoshiro(123)
Expand All @@ -78,7 +79,8 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:low_level_mat
(f == :rem2pi || f == :ldexp || f == :(^)) && return
(f == :+ || f == :*) && return # use intrinsics instead
f = @eval $M.$f
push!(test_cases, Any[false, :stability, nothing, f, rand_inputs(rng, f, arity)...])
push!(test_cases, Any[false, :stability, nothing, f, rand_inputs(rng, Float64, f, arity)...])
push!(test_cases, Any[true, :stability, nothing, f, rand_inputs(rng, Float32, f, arity)...])
end
memory = Any[]
return test_cases, memory
Expand Down
24 changes: 24 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,28 @@
rule = build_rrule(foo, 5.0)
@test_throws ArgumentError value_and_pullback!!(rule, 1.0, foo, CoDual(5.0, 0.0))
end
@testset "value_and_gradient!!" begin
    # Scalar-valued functions at both supported precisions. BUGFIX: the testset label
    # previously read "($(typeof(fargs))" — an unbalanced parenthesis in the name.
    @testset "$(typeof(fargs))" for fargs in Any[
        (sin, randn(Float64)),
        (sin, randn(Float32)),
        (x -> sin(cos(x)), randn(Float64)),
        (x -> sin(cos(x)), randn(Float32)),
        ((x, y) -> x + sin(y), randn(Float64), randn(Float64)),
        ((x, y) -> x + sin(y), randn(Float32), randn(Float32)),
    ]
        rule = build_rrule(fargs...)
        f, args... = fargs
        v, dfargs = value_and_gradient!!(rule, fargs...)
        # Value must agree with the plain primal computation.
        @test v == f(args...)
        # One gradient entry per element of fargs (including f itself), with the
        # tangent type corresponding to each argument's type.
        for (arg, darg) in zip(fargs, dfargs)
            @test tangent_type(typeof(arg)) == typeof(darg)
        end
    end

    # A non-scalar (Tuple) primal return value must raise the informative error.
    rule = build_rrule(identity, (5.0, 4.0))
    @test_throws(
        Tapir.ValueAndGradientReturnTypeError,
        value_and_gradient!!(rule, identity, (5.0, 4.0)),
    )
end
end

2 comments on commit 9d248cd

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/113135

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.36 -m "<description of version>" 9d248cdc2eed5956d623f008d9c4e022d83dca16
git push origin v0.2.36

Please sign in to comment.