Skip to content

Commit

Permalink
Remove redundant rule (#39)
Browse files Browse the repository at this point in the history
* Remove redundant code and modify tests

* Bump Umlaut compat bound

* Bump Umlaut version again

* Fix splatting
  • Loading branch information
willtebbutt authored Nov 21, 2023
1 parent 8ee37f5 commit 2f319eb
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ PDMats = "0.11"
Setfield = "1"
SpecialFunctions = "2"
StableRNGs = "1"
Umlaut = "0.5.7"
Umlaut = "0.6.1"
julia = "1"

[extras]
Expand Down
4 changes: 2 additions & 2 deletions src/Taped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ using
Setfield,
Umlaut

import Umlaut: isprimitive, Frame, Tracer, __foreigncall__
import Umlaut: isprimitive, Frame, Tracer, __foreigncall__, __to_tuple__

using Base:
IEEEFloat, unsafe_convert, unsafe_pointer_to_objref, pointer_from_objref, arrayref,
arrayset
using Core: Intrinsics, bitcast
using Core: Intrinsics, bitcast, SimpleVector, svec
using FunctionWrappers: FunctionWrapper
using LinearAlgebra.BLAS: @blasfunc, BlasInt, trsm!

Expand Down
31 changes: 22 additions & 9 deletions src/rrules/umlaut_internals_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,27 @@ function rrule!!(::CoDual{typeof(Umlaut.check_variable_length)}, args::Vararg{An
return CoDual(v, zero_tangent(v)), NoPullback()
end

# Umlaut occassionally pushes `getindex` onto the tape.
# Easiest just to handle it like this.
# Might remove at a later date when `Umlaut.primitivize` works properly.
isprimitive(::RMC, ::typeof(getindex), ::Tuple, ::Int) = true
function rrule!!(::CoDual{typeof(getindex)}, x::CoDual{<:Tuple}, i::CoDual{Int})
function getindex_pullback!!(dy, df, dx, ::NoTangent)
dx = ntuple(n -> n == primal(i) ? increment!!(dx[n], dy) : dx[n], length(dx))
return df, dx, NoTangent()
# This is the thing that Umlaut uses in order to splat. Must be a primitive.
isprimitive(::RMC, ::typeof(__to_tuple__), x) = true
function rrule!!(::CoDual{typeof(__to_tuple__)}, x::CoDual{<:Tuple})
__to_tuple_pb!!(dy, df, dx) = df, increment!!(dx, dy)
return x, __to_tuple_pb!!
end
function rrule!!(::CoDual{typeof(__to_tuple__)}, x::CoDual{<:NamedTuple{A}}) where {A}
__to_tuple_named_tuple_pb!!(dy, df, dx) = df, increment!!(dx, NamedTuple{A}(dy))
return CoDual(Tuple(primal(x)), Tuple(tangent(x))), __to_tuple_named_tuple_pb!!
end
function rrule!!(::CoDual{typeof(__to_tuple__)}, x::CoDual{<:Vector, <:Vector{T}}) where {T}
__to_tuple_vec_pb!!(dy, df, dx) = df, increment!!(dx, T[a for a in dy])
return CoDual(__to_tuple__(primal(x)), __to_tuple__(tangent(x))), __to_tuple_vec_pb!!
end
function rrule!!(::CoDual{typeof(__to_tuple__)}, x::CoDual{Int})
return zero_codual((primal(x), )), NoPullback()
end
function rrule!!(::CoDual{typeof(__to_tuple__)}, x::CoDual{Core.SimpleVector})
function __to_tuple_svec_pb!!(dy, df, dx)
return df, increment!!(dx, Any[a for a in dy])
end
return CoDual(primal(x)[primal(i)], tangent(x)[primal(i)]), getindex_pullback!!
y = CoDual(__to_tuple__(primal(x)), __to_tuple__(tangent(x)))
return y, __to_tuple_svec_pb!!
end
19 changes: 19 additions & 0 deletions src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ tangent_type(::Type{Module}) = NoTangent

tangent_type(::Type{Nothing}) = NoTangent

tangent_type(::Type{SimpleVector}) = Vector{Any}

tangent_type(::Type{P}) where {P<:Union{UInt8, UInt16, UInt32, UInt64, UInt128}} = NoTangent

tangent_type(::Type{P}) where {P<:Union{Int8, Int16, Int32, Int64, Int128}} = NoTangent
Expand Down Expand Up @@ -206,6 +208,11 @@ anything other than that which this function returns.
zero_tangent(x)
@inline zero_tangent(::Union{Int8, Int16, Int32, Int64, Int128}) = NoTangent()
@inline zero_tangent(x::IEEEFloat) = zero(x)
function zero_tangent(x::SimpleVector)
return map!(Vector{Any}(undef, length(x)), eachindex(x)) do n
return zero_tangent(x[n])
end
end
@inline function zero_tangent(x::Array{P, N}) where {P, N}
y = Array{tangent_type(P), N}(undef, size(x)...)
v = _map_if_assigned!(zero_tangent, y, x)
Expand Down Expand Up @@ -261,6 +268,11 @@ function randn_tangent(rng::AbstractRNG, x::Array{T, N}) where {T, N}
dx = Array{tangent_type(T), N}(undef, size(x)...)
return _map_if_assigned!(Base.Fix1(randn_tangent, rng), dx, x)
end
function randn_tangent(rng::AbstractRNG, x::SimpleVector)
return map!(Vector{Any}(undef, length(x)), eachindex(x)) do n
return randn_tangent(rng, x[n])
end
end
function randn_tangent(rng::AbstractRNG, x::Union{Tuple, NamedTuple})
return map(x -> randn_tangent(rng, x), x)
end
Expand Down Expand Up @@ -482,6 +494,9 @@ function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}) where {P, N}
x′ = Array{P, N}(undef, size(x)...)
return _map_if_assigned!(_add_to_primal, x′, x, t)
end
function _add_to_primal(x::SimpleVector, t::Vector{Any})
return svec(map(n -> _add_to_primal(x[n], t[n]), eachindex(x))...)
end
_add_to_primal(x::Tuple, t::Tuple) = _map(_add_to_primal, x, t)
_add_to_primal(x::NamedTuple, t::NamedTuple) = _map(_add_to_primal, x, t)
_add_to_primal(x, ::Tangent{NamedTuple{(), Tuple{}}}) = x
Expand Down Expand Up @@ -519,6 +534,9 @@ function _diff(p::P, q::P) where {V, N, P<:Array{V, N}}
t = Array{tangent_type(V), N}(undef, size(p))
return _map_if_assigned!(_diff, t, p, q)
end
function _diff(p::P, q::P) where {P<:SimpleVector}
return Any[_diff(a, b) for (a, b) in zip(p, q)]
end
_diff(p::P, q::P) where {P<:Union{Tuple, NamedTuple}} = _map(_diff, p, q)

function _containerlike_diff(p::P, q::P) where {P}
Expand All @@ -541,3 +559,4 @@ end
@generated function might_be_active(::Type{<:Array{P}}) where {P}
return :(return $(might_be_active(P)))
end
might_be_active(::Type{SimpleVector}) = true
32 changes: 31 additions & 1 deletion src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ function has_equal_data(x::T, y::T; equal_undefs=true) where {T<:Array}
end
return all(equality)
end
has_equal_data(x::T, y::T) where {T<:Union{Float16, Float32, Float64}} = isapprox(x, y)
function has_equal_data(x::T, y::T; equal_undefs=true) where {T}
isprimitivetype(T) && return isequal(x, y)
return all(map(
Expand Down Expand Up @@ -78,6 +79,15 @@ function populate_address_map!(m::AddressMap, p::Array, t::Array)
return m
end

function populate_address_map!(m::AddressMap, p::Core.SimpleVector, t::Vector{Any})
k = pointer_from_objref(p)
v = pointer_from_objref(t)
haskey(m, k) && (@assert m[k] == v)
m[k] = v
foreach(n -> populate_address_map!(m, p[n], t[n]), eachindex(p))
return m
end

populate_address_map!(m::AddressMap, p::Union{Core.TypeName, Type, Symbol, String}, t) = m

"""
Expand Down Expand Up @@ -305,7 +315,27 @@ function test_taped_rrule!!(rng::AbstractRNG, f, x...; interface_only=false, kwa

# Check that f_t remains a faithful representation of the original function.
if !interface_only
@test has_equal_data(f(deepcopy(x)...), play!(f_t.tape, f, deepcopy(x)...))
xs_lhs = deepcopy(x)
xs_rhs = deepcopy(x)
y_lhs = f(xs_lhs...)
y_rhs = play!(f_t.tape, f, xs_rhs...)
if !has_equal_data(y_lhs, y_rhs) || !has_equal_data(xs_lhs, xs_rhs)
println("y_lhs")
display(y_lhs)
println()
println("y_rhs")
display(y_rhs)
println()

println("xs_lhs")
display(xs_lhs)
println()
println("xs_rhs")
display(xs_rhs)
println()
end
@test has_equal_data(y_lhs, y_rhs)
@test has_equal_data(xs_lhs, xs_rhs)
end
end

Expand Down
28 changes: 23 additions & 5 deletions test/rrules/umlaut_internals_rules.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
multiarg_fn(x) = only(x)
multiarg_fn(x, y) = only(x) + only(y)
multiarg_fn(x, y, z) = only(x) + only(y) + only(z)

vararg_fn(x) = multiarg_fn(x...)

@testset "umlaut_internals_rules" begin

@testset "misc utility" begin
Expand All @@ -10,7 +16,7 @@

_x = Ref(5.0) # data used in tests which aren't protected by GC.
_dx = Ref(4.0)
@testset "$f, $(typeof(x))" for (interface_only, perf_flag, f, x...) in [
@testset "$f, $(typeof(x))" for (interface_only, perf_flag, f, x...) in Any[

# IR-node workarounds:
(false, :stability, __new__, UnitRange{Int}, 5, 9),
Expand Down Expand Up @@ -44,10 +50,12 @@
),
(false, :stability, __new__, Tuple{Float64, Float64}, 5.0, 4.0),

# Umlaut internals -- getindex occassionally gets pushed onto the tape.
(false, :none, getindex, (5.0, 5.0), 2),
(false, :none, getindex, (randn(5), 2), 1),
(false, :none, getindex, (2, randn(5)), 1),
# Splatting primitives:
(false, :stability, __to_tuple__, (5.0, 4)),
(false, :stability, __to_tuple__, (a=5.0, b=4)),
(false, :stability, __to_tuple__, 5),
(false, :none, __to_tuple__, svec(5.0)),
(false, :none, __to_tuple__, [5.0, 4.0]),

# Umlaut limitations:
(false, :none, eltype, randn(5)),
Expand All @@ -57,4 +65,14 @@
]
test_rrule!!(Xoshiro(123456), f, x...; interface_only, perf_flag)
end
@testset for (interface_only, f, x...) in Any[
(false, x -> multiarg_fn(x...), 1),
(false, x -> multiarg_fn(x...), [1.0, 2.0]),
(false, x -> multiarg_fn(x...), [5.0, 4]),
(false, x -> multiarg_fn(x...), (5.0, 4)),
(false, x -> multiarg_fn(x...), (a=5.0, b=4)),
(false, x -> multiarg_fn(x...), svec(5.0, 4.0)),
]
test_taped_rrule!!(Xoshiro(123456), f, map(deepcopy, x)...; interface_only)
end
end
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using

using Base: unsafe_load, pointer_from_objref
using Base.Iterators: product
using Core: bitcast
using Core: bitcast, svec
using Core.Intrinsics: pointerref, pointerset
using FunctionWrappers: FunctionWrapper

Expand Down Expand Up @@ -44,7 +44,7 @@ using Taped:
rebind,
build_tangent

using Taped.Umlaut: __new__
using Taped.Umlaut: __new__, __to_tuple__

using .TestUtils:
test_rrule!!,
Expand Down
1 change: 1 addition & 0 deletions test/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
map(Float16, (5.0, 4.0, 3.1, 7.1)),
(5f0, 4f0, 3f0, 7f0),
(5.1, 4.0, 3.0, 7.0),
(svec(5.0), Any[4.0], Any[3.0], Any[7.0]),
([3.0, 2.0], [1.0, 2.0], [2.0, 3.0], [3.0, 5.0]),
(
[1, 2],
Expand Down

0 comments on commit 2f319eb

Please sign in to comment.