Skip to content

Commit

Permalink
Small tweaks (#22)
Browse files Browse the repository at this point in the history
* Remove set_field_to_zero

* Check for zero

* Remove check

* Rename shadow to tangent
  • Loading branch information
willtebbutt authored Oct 16, 2023
1 parent b9210fb commit 4605b44
Show file tree
Hide file tree
Showing 17 changed files with 116 additions and 150 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ Docs incoming soon.
# Known Limitations:

- If a (mutable) data structure `x` contains a circular reference, it will not be possible to construct a `zero_tangent` / `random_tangent` `MutableTangent` to it -- an infinite recursion will occur. It should be possible, however, to differentiate through its construction. If you find this to be a problem in practice, please open an issue.
- `zero_tangent` and `random_tangent` do not work for pointers, because we don't know how large a chunk of memory a given pointer points to, so cannot allocate a corresponding chunk of shadow memory. Your best bet when testing `rrule!!`s for things involving pointers is currently to do integration testing. See the tests for blas functionality for examples.
- `zero_tangent` and `random_tangent` do not work for pointers, because we don't know how large a chunk of memory a given pointer points to, so cannot allocate a corresponding chunk of tangent memory. Your best bet when testing `rrule!!`s for things involving pointers is currently to do integration testing. See the tests for blas functionality for examples.
- If you pass active data through a global variable, AD will fail. Furthermore / worse still, the failures will probably be silent.
3 changes: 1 addition & 2 deletions src/Taped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ include(joinpath("rrules", "unrolled_function.jl"))

export
primal,
shadow,
tangent,
randn_tangent,
increment!!,
increment_field!!,
Expand All @@ -44,7 +44,6 @@ export
MutableTangent,
PossiblyUninitTangent,
set_to_zero!!,
set_field_to_zero!!,
tangent_type,
zero_tangent,
_scale,
Expand Down
15 changes: 11 additions & 4 deletions src/reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ function CoDual(x::Type{P}, dx::NoTangent) where {P}
end

primal(x::CoDual) = x.x
shadow(x::CoDual) = x.dx
Base.copy(x::CoDual) = CoDual(copy(primal(x)), copy(shadow(x)))
tangent(x::CoDual) = x.dx
Base.copy(x::CoDual) = CoDual(copy(primal(x)), copy(tangent(x)))

"""
zero_codual(x)
Equivalent to `CoDual(x, zero_tangent(x))`.
"""
zero_codual(x) = CoDual(x, zero_tangent(x))

"""
uninit_codual(x)
Expand All @@ -27,7 +34,7 @@ See implementation for details, as this function is subject to change.
"""
uninit_codual(x) = CoDual(x, uninit_tangent(x))

set_shadow!!(x::CoDual, dx) = CoDual(primal(x), increment!!(set_to_zero!!(shadow(x)), dx))
set_tangent!!(x::CoDual, dx) = CoDual(primal(x), increment!!(set_to_zero!!(tangent(x)), dx))

function verify_codual_type(::CoDual{P, T}) where {P, T}
Tt = tangent_type(P)
Expand Down Expand Up @@ -61,7 +68,7 @@ rebind(x) = x

rebind_pb!!(ȳ, f̄, x̄) = f̄, increment!!(x̄, ȳ)
function rrule!!(::CoDual{typeof(rebind)}, x::CoDual)
return CoDual(primal(x), rebind_tangent(shadow(x))), rebind_pb!!
return CoDual(primal(x), rebind_tangent(tangent(x))), rebind_pb!!
end

isprimitive(::RMC, ::typeof(rebind), x) = true
2 changes: 1 addition & 1 deletion src/rrules/avoiding_non_differentiable_code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
# https://github.com/JuliaLang/julia/blob/9f9e989f241fad1ae03c3920c20a93d8017a5b8f/base/pointer.jl#L282
isprimitive(::RMC, ::typeof(Base.:(+)), x::Ptr, y::Integer) = true
function rrule!!(::CoDual{typeof(Base.:(+))}, x::CoDual{<:Ptr}, y::CoDual{<:Integer})
return CoDual(primal(x) + primal(y), shadow(x) + primal(y)), NoPullback()
return CoDual(primal(x) + primal(y), tangent(x) + primal(y)), NoPullback()
end
2 changes: 1 addition & 1 deletion src/rrules/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ for (fname, elty) in ((:dscal_, :Float64), (:sscal_, :Float32))
_incx = unsafe_load(primal(incx))
_DA = unsafe_load(primal(DA))
_DX = unsafe_wrap(Vector{$elty}, primal(DX), _n * _incx)
_DX_s = unsafe_wrap(Vector{$elty}, shadow(DX), _n * _incx)
_DX_s = unsafe_wrap(Vector{$elty}, tangent(DX), _n * _incx)

inds = 1:_incx:(_incx * _n)
DX_copy = _DX[inds]
Expand Down
42 changes: 21 additions & 21 deletions src/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ module IntrinsicsWrappers
import Umlaut: isprimitive
using Core: Intrinsics
import ..Taped:
rrule!!, CoDual, primal, shadow, zero_tangent, isprimitive, RMC, NoPullback,
rrule!!, CoDual, primal, tangent, zero_tangent, isprimitive, RMC, NoPullback,
tangent_type, increment!!

# Note: performance is not considered _at_ _all_ in this implementation.
function rrule!!(f::CoDual{<:Core.IntrinsicFunction}, args...)
return rrule!!(CoDual(translate(Val(primal(f))), shadow(f)), args...)
return rrule!!(CoDual(translate(Val(primal(f))), tangent(f)), args...)
end

macro intrinsic(name)
Expand Down Expand Up @@ -89,7 +89,7 @@ function rrule!!(::CoDual{typeof(bitcast)}, T, x)
_x = primal(x)
v = bitcast(_T, _x)
if _T <: Ptr && _x isa Ptr
dv = bitcast(Ptr{tangent_type(eltype(_T))}, shadow(x))
dv = bitcast(Ptr{tangent_type(eltype(_T))}, tangent(x))
else
dv = zero_tangent(v)
end
Expand Down Expand Up @@ -244,7 +244,7 @@ function rrule!!(::CoDual{typeof(pointerref)}, x, y, z)
_x = primal(x)
_y = primal(y)
_z = primal(z)
x_s = shadow(x)
x_s = tangent(x)
a = CoDual(pointerref(_x, _y, _z), pointerref(x_s, _y, _z))
function pointerref_pullback!!(da, df, dx, dy, dz)
dx_v = pointerref(dx, _y, _z)
Expand All @@ -261,15 +261,15 @@ function rrule!!(::CoDual{typeof(pointerset)}, p, x, idx, z)
_idx = primal(idx)
_z = primal(z)
old_value = pointerref(_p, _idx, _z)
old_shadow = pointerref(shadow(p), _idx, _z)
old_tangent = pointerref(tangent(p), _idx, _z)
function pointerset_pullback!!(_, df, dp, dx, didx, dz)
dx_new = increment!!(dx, pointerref(dp, _idx, _z))
pointerset(_p, old_value, _idx, _z)
pointerset(dp, old_shadow, _idx, _z)
pointerset(dp, old_tangent, _idx, _z)
return df, dp, dx_new, didx, dz
end
pointerset(_p, primal(x), _idx, _z)
pointerset(shadow(p), shadow(x), _idx, _z)
pointerset(tangent(p), tangent(x), _idx, _z)
return p, pointerset_pullback!!
end

Expand Down Expand Up @@ -387,7 +387,7 @@ function rrule!!(
return df, dinbounds, dx, dinds...
end
_y = arrayref(_inbounds, primal(x), _inds...)
dy = arrayref(_inbounds, shadow(x), _inds...)
dy = arrayref(_inbounds, tangent(x), _inds...)
return CoDual(_y, dy), arrayref_pullback!!
end

Expand All @@ -402,9 +402,9 @@ function rrule!!(
_inds = map(primal, inds)
to_save = isassigned(primal(A), _inds...)
old_A_v = to_save ? arrayref(_inbounds, primal(A), _inds...) : nothing
old_A_v_t = to_save ? arrayref(_inbounds, shadow(A), _inds...) : nothing
old_A_v_t = to_save ? arrayref(_inbounds, tangent(A), _inds...) : nothing
arrayset(_inbounds, primal(A), primal(v), _inds...)
arrayset(_inbounds, shadow(A), shadow(v), _inds...)
arrayset(_inbounds, tangent(A), tangent(v), _inds...)
function setindex_pullback!!(dA::TdA, df, dinbounds, dA2::TdA, dv, dinds::NoTangent...)
dv_new = increment!!(dv, arrayref(_inbounds, dA, _inds...))
to_save && arrayset(_inbounds, primal(A), old_A_v, _inds...)
Expand Down Expand Up @@ -459,7 +459,7 @@ function rrule!!(::CoDual{typeof(getfield)}, value::CoDual, name::CoDual)
end
y = CoDual(
getfield(primal(value), _name),
_get_shadow_field(primal(value), shadow(value), _name),
_get_tangent_field(primal(value), tangent(value), _name),
)
return y, getfield_pullback
end
Expand All @@ -474,16 +474,16 @@ function rrule!!(::CoDual{typeof(getfield)}, value::CoDual, name::CoDual, order:
_order = _order isa Expr ? true : _order
y = CoDual(
getfield(primal(value), _name, _order),
_get_shadow_field(primal(value), shadow(value), _name, _order),
_get_tangent_field(primal(value), tangent(value), _name, _order),
)
return y, getfield_pullback
end

_get_shadow_field(_, shadow, f...) = getfield(shadow, f...)
function _get_shadow_field(_, shadow::Union{Tangent, MutableTangent}, f...)
return _value(getfield(shadow.fields, f...))
_get_tangent_field(_, tangent, f...) = getfield(tangent, f...)
function _get_tangent_field(_, tangent::Union{Tangent, MutableTangent}, f...)
return _value(getfield(tangent.fields, f...))
end
_get_shadow_field(primal, shadow::NoTangent, f...) = uninit_tangent(getfield(primal, f...))
_get_tangent_field(primal, ::NoTangent, f...) = uninit_tangent(getfield(primal, f...))

_increment_field!!(x, y, f) = increment_field!!(x, y, f)
_increment_field!!(x::NoTangent, y, f) = x
Expand Down Expand Up @@ -529,17 +529,17 @@ function rrule!!(::CoDual{typeof(setfield!)}, value, name, x)
_name = primal(name)
save = isdefined(primal(value), _name)
old_x = save ? getfield(primal(value), _name) : nothing
old_dx = save ? getfield(shadow(value).fields, _name).tangent : nothing
old_dx = save ? getfield(tangent(value).fields, _name).tangent : nothing
function setfield!_pullback(dy, df, dvalue, ::NoTangent, dx)
new_dx = increment!!(dx, getfield(dvalue.fields, _name).tangent)
new_dx = increment!!(new_dx, dy)
old_x !== nothing && setfield!(primal(value), _name, old_x)
old_x !== nothing && _setfield!(shadow(value), _name, old_dx)
old_x !== nothing && _setfield!(tangent(value), _name, old_dx)
return df, dvalue, NoTangent(), new_dx
end
y = CoDual(
setfield!(primal(value), _name, primal(x)),
_setfield!(shadow(value), _name, shadow(x)),
_setfield!(tangent(value), _name, tangent(x)),
)
return y, setfield!_pullback
end
Expand All @@ -548,7 +548,7 @@ end
# throw

function rrule!!(::CoDual{typeof(tuple)}, args...)
y = CoDual(tuple(map(primal, args)...), tuple(map(shadow, args)...))
y = CoDual(tuple(map(primal, args)...), tuple(map(tangent, args)...))
tuple_pullback(dy, ::NoTangent, dargs...) = NoTangent(), map(increment!!, dargs, dy)...
return y, tuple_pullback
end
Expand All @@ -557,7 +557,7 @@ function rrule!!(::CoDual{typeof(typeassert)}, x, type)
function typeassert_pullback(dy, ::NoTangent, dx, ::NoTangent)
return NoTangent(), increment!!(dx, dy), NoTangent()
end
return CoDual(typeassert(primal(x), primal(type)), shadow(x)), typeassert_pullback
return CoDual(typeassert(primal(x), primal(type)), tangent(x)), typeassert_pullback
end

rrule!!(::CoDual{typeof(typeof)}, x) = CoDual(typeof(primal(x)), NoTangent()), NoPullback()
24 changes: 12 additions & 12 deletions src/rrules/foreigncall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ end

isprimitive(::RMC, ::typeof(copy), ::Array) = true
function rrule!!(::CoDual{typeof(copy)}, a::CoDual{<:Array})
y = CoDual(copy(primal(a)), copy(shadow(a)))
y = CoDual(copy(primal(a)), copy(tangent(a)))
copy_pullback!!(dy, df, dx) = df, increment!!(dx, dy)
return y, copy_pullback!!
end
Expand All @@ -63,7 +63,7 @@ function rrule!!(
_d = primal(delta)
_a = primal(a)
Base._growend!(_a, _d)
Base._growend!(shadow(a), _d)
Base._growend!(tangent(a), _d)
function _growend!_pullback!!(dy, df, da, ddelta)
Base._deleteend!(_a, _d)
Base._deleteend!(da, _d)
Expand All @@ -81,7 +81,7 @@ isprimitive(::RMC, ::typeof(pointer_from_objref), x) = true
function rrule!!(::CoDual{typeof(pointer_from_objref)}, x)
y = CoDual(
pointer_from_objref(primal(x)),
bitcast(Ptr{tangent_type(Nothing)}, pointer_from_objref(shadow(x))),
bitcast(Ptr{tangent_type(Nothing)}, pointer_from_objref(tangent(x))),
)
return y, NoPullback()
end
Expand All @@ -105,11 +105,11 @@ function rrule!!(
dest_copy = Vector{T}(undef, _n)
ddest_copy = Vector{T}(undef, _n)
unsafe_copyto!(pointer(dest_copy), primal(dest), _n)
unsafe_copyto!(pointer(ddest_copy), shadow(dest), _n)
unsafe_copyto!(pointer(ddest_copy), tangent(dest), _n)

# Run primal computation.
unsafe_copyto!(primal(dest), primal(src), _n)
unsafe_copyto!(shadow(dest), shadow(src), _n)
unsafe_copyto!(tangent(dest), tangent(src), _n)

function unsafe_copyto!_pb!!(_, df, ddest, dsrc, dn)

Expand All @@ -118,7 +118,7 @@ function rrule!!(

# Restore initial state.
unsafe_copyto!(primal(dest), pointer(dest_copy), _n)
unsafe_copyto!(shadow(dest), pointer(ddest_copy), _n)
unsafe_copyto!(tangent(dest), pointer(ddest_copy), _n)

return df, ddest, dsrc, dn
end
Expand All @@ -141,11 +141,11 @@ function rrule!!(
dest_idx = _doffs:_doffs + _n - 1
_soffs = primal(soffs)
dest_copy = primal(dest)[dest_idx]
ddest_copy = shadow(dest)[dest_idx]
ddest_copy = tangent(dest)[dest_idx]

# Run primal computation.
unsafe_copyto!(primal(dest), _doffs, primal(src), _soffs, _n)
unsafe_copyto!(shadow(dest), _doffs, shadow(src), _soffs, _n)
unsafe_copyto!(tangent(dest), _doffs, tangent(src), _soffs, _n)

function unsafe_copyto_pb!!(_, df, ddest, ddoffs, dsrc, dsoffs, dn)

Expand All @@ -155,7 +155,7 @@ function rrule!!(

# Restore initial state.
primal(dest)[dest_idx] .= dest_copy
shadow(dest)[dest_idx] .= ddest_copy
tangent(dest)[dest_idx] .= ddest_copy

return df, ddest, ddoffs, dsrc, dsoffs, dn
end
Expand All @@ -165,7 +165,7 @@ end

isprimitive(::RMC, ::typeof(Base.unsafe_pointer_to_objref), x::Ptr) = true
function rrule!!(::CoDual{typeof(Base.unsafe_pointer_to_objref)}, x::CoDual{<:Ptr})
y = CoDual(unsafe_pointer_to_objref(primal(x)), unsafe_pointer_to_objref(shadow(x)))
y = CoDual(unsafe_pointer_to_objref(primal(x)), unsafe_pointer_to_objref(tangent(x)))
return y, NoPullback()
end

Expand Down Expand Up @@ -197,7 +197,7 @@ function rrule!!(
) where {T, V}
y = CoDual(
ccall(:jl_array_ptr, Ptr{T}, (Any, ), primal(a)),
ccall(:jl_array_ptr, Ptr{V}, (Any, ), shadow(a)),
ccall(:jl_array_ptr, Ptr{V}, (Any, ), tangent(a)),
)
return y, NoPullback()
end
Expand All @@ -216,7 +216,7 @@ function rrule!!(
d = primal(dims)
y = CoDual(
ccall(:jl_reshape_array, Array{P, M}, (Any, Any, Any), Array{P, M}, primal(a), d),
ccall(:jl_reshape_array, Array{T, M}, (Any, Any, Any), Array{T, M}, shadow(a), d),
ccall(:jl_reshape_array, Array{T, M}, (Any, Any, Any), Array{T, M}, tangent(a), d),
)
return y, NoPullback()
end
Expand Down
4 changes: 2 additions & 2 deletions src/rrules/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32))
N_val = unsafe_load(N)
LDA_val = unsafe_load(LDA)
data_len = LDA_val * N_val
A, dA = primal(_A), shadow(_A)
A, dA = primal(_A), tangent(_A)

@assert M_val === N_val

Expand All @@ -35,7 +35,7 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32))
M, N, A, LDA, IPIV, INFO,
)

# Zero out the shadow.
# Zero out the tangent.
foreach(n -> unsafe_store!(dA, zero($elty), n), 1:data_len)

function getrf_pb!!(
Expand Down
2 changes: 1 addition & 1 deletion src/rrules/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function rrule!!(
::CoDual{typeof(lgetfield)}, x::CoDual, ::CoDual{T}
) where {f, T<:Union{SSym{f}, SInt{f}}}
lgetfield_pb!!(dy, df, dx, dsym) = df, increment_field!!(dx, dy, T()), dsym
y = CoDual(getfield(primal(x), f), _get_shadow_field(primal(x), shadow(x), f))
y = CoDual(getfield(primal(x), f), _get_tangent_field(primal(x), tangent(x), f))
return y, lgetfield_pb!!
end

Expand Down
4 changes: 2 additions & 2 deletions src/rrules/umlaut_internals_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ end
return quote
x_ps = map(primal, xs)
y = $(Expr(:new, P, map(n -> :(x_ps[$n]), 1:N)...))
dy = build_tangent(P, map(shadow, xs)...)
dy = build_tangent(P, map(tangent, xs)...)
return CoDual(y, dy), __new__pullback
end
end
Expand Down Expand Up @@ -66,5 +66,5 @@ function rrule!!(::CoDual{typeof(getindex)}, x::CoDual{<:Tuple}, i::CoDual{Int})
dx = ntuple(n -> n == primal(i) ? increment!!(dx[n], dy) : dx[n], length(dx))
return df, dx, NoTangent()
end
return CoDual(primal(x)[primal(i)], shadow(x)[primal(i)]), getindex_pullback!!
return CoDual(primal(x)[primal(i)], tangent(x)[primal(i)]), getindex_pullback!!
end
Loading

0 comments on commit 4605b44

Please sign in to comment.