Skip to content

Commit

Permalink
More fixes for SimpleVarInfo for large models (#160)
Browse files Browse the repository at this point in the history
* Avoid infinite recursion in _apply_iterate rule

* Fix up types in DynamicDerivedRule

* Fix signature computation

* Some work

* Add large function test to Turing

* Improve tangent_type

* Improve fwds rvs data

* Make tuple splat not generated

* _apply_iterate implementation

* Improve safe mode and fix rdata construction

* Make CuArray testable

* Enable all tests

* Fix errors

* Bump patch version
  • Loading branch information
willtebbutt authored May 23, 2024
1 parent 0e46680 commit 1ddee6a
Show file tree
Hide file tree
Showing 17 changed files with 357 additions and 120 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.18"
version = "0.2.19"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
6 changes: 6 additions & 0 deletions ext/TapirCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ module TapirCUDAExt
m[k] = v
return m
end
function Tapir._verify_fdata_value(p::CuArray, f::CuArray)
if size(p) != size(f)
throw(InvalidFDataException("p has size $(size(p)) but f has size $(size(f))"))
end
return nothing
end

# Basic rules for operating on CuArrays.

Expand Down
181 changes: 180 additions & 1 deletion src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,107 @@ end

uninit_fdata(p) = fdata(uninit_tangent(p))

"""
InvalidFDataException(msg::String)
Exception indicating that there is a problem with the fdata associated to a primal.
"""
struct InvalidFDataException <: Exception
msg::String
end

"""
verify_fdata_type(P::Type, F::Type)::Nothing
Check that `F` is a valid type for fdata associated to a primal of type `P`. Returns
`nothing` if valid, throws an `InvalidFDataException` if a problem is found.
This applies to both concrete and non-concrete `P`. For example, if `P` is the type inferred
for a primal `q::Q`, such that `Q <: P`, then this method is still applicable.
"""
function verify_fdata_type(P::Type, F::Type)::Nothing
_F = fdata_type(tangent_type(P))
F <: _F && return nothing
throw(InvalidFDataException("Type $P has fdata type $_F, but got $F."))
end

"""
verify_fdata_value(p, f)::Nothing
Check that `f` cannot be proven to be invalid fdata for `p`.
This method attempts to provide some confidence that `f` is valid fdata for `p` by checking
a collection of necessary conditions. We do not guarantee that these amount to a sufficient
condition, just that they rule out a variety of common problems.
Put differently, we cannot prove that `f` is valid fdata, only that it is not obviously
invalid.
"""
function verify_fdata_value(p, f)::Nothing
verify_fdata_type(_typeof(p), typeof(f))
_verify_fdata_value(p, f)
end

_verify_fdata_value(::IEEEFloat, ::NoFData) = nothing

_verify_fdata_value(::Ptr, ::Ptr) = nothing

function _verify_fdata_value(p::Array, f::Array)
if size(p) != size(f)
throw(InvalidFDataException("p has size $(size(p)) but f has size $(size(f))"))
end

# If the element type is `NoFData` then stop here.
eltype(f) == NoFData && return nothing

# Recurse into each element and check that it is correct. Note that the elements of an
# Array contain the tangents, so we must check that the fdata and rdata components are
# correct separately.
for n in eachindex(p)
if isassigned(p, n)
t = f[n]
verify_fdata_value(p[n], fdata(t))
verify_rdata_value(p[n], rdata(t))
end
end

return nothing
end

function _verify_fdata_value(p, f)

# If f is a NoFData then there are no checks needed, because we have already verified
# that NoFData is the correct type for fdata for p, and NoFData is a singleton type.
f isa NoFData && return nothing

# When a primitive is encountered here, it means that we don't have a method of
# _verify_fdata_value which is specific to it, and its fdata type is not NoFData.
# The rest of this method assumes p is an instance of a struct type, so we must error.
P = _typeof(p)
isprimitivetype(P) && error("Encountered primitive $p with fdata $f")

# (mutable) structs, Tuples, and NamedTuples all have slightly different storage.
_get_fdata_field(f::NamedTuple, name) = getfield(f, name)
_get_fdata_field(f::Tuple, name) = getfield(f, name)
_get_fdata_field(f::FData, name) = val(getfield(f.data, name))
_get_fdata_field(f::MutableTangent, name) = fdata(val(getfield(f.fields, name)))

# Having excluded primitive types, we must have a (mutable) struct type. Recurse into
# its fields and verify each of them.
for name in fieldnames(P)
if isdefined(p, name)
_p = getfield(p, name)
t = _get_fdata_field(f, name)
verify_fdata_value(_p, t)
if f isa MutableTangent
verify_rdata_value(_p, rdata(val(getfield(f.fields, name))))
end
end
end

return nothing
end

"""
NoRData()
Expand Down Expand Up @@ -324,7 +425,9 @@ obtained from `P` alone.
@generated function can_produce_zero_rdata_from_type(::Type{P}) where {P}
R = rdata_type(tangent_type(P))
R == NoRData && return true
isconcretetype(P) || return false
isabstracttype(P) && return false
(isconcretetype(P) || P <: Tuple) || return false
(P <: Tuple && Base.datatype_fieldcount(P) === nothing) && return false

# For general structs, just look at their fields.
return isstructtype(P) ? all(can_produce_zero_rdata_from_type, fieldtypes(P)) : false
Expand Down Expand Up @@ -407,6 +510,82 @@ end

zero_rdata_from_type(::Type{P}) where {P<:IEEEFloat} = zero(P)


"""
InvalidRDataException(msg::String)
Exception indicating that there is a problem with the rdata associated to a primal.
"""
struct InvalidRDataException <: Exception
msg::String
end

"""
verify_rdata_type(P::Type, R::Type)::Nothing
Check that `R` is a valid type for rdata associated to a primal of type `P`. Returns
`nothing` if valid, throws an `InvalidRDataException` if a problem is found.
This applies to both concrete and non-concrete `P`. For example, if `P` is the type inferred
for a primal `q::Q`, such that `Q <: P`, then this method is still applicable.
"""
function verify_rdata_type(P::Type, R::Type)::Nothing
_R = rdata_type(tangent_type(P))
R <: _R && return nothing
throw(InvalidRDataException("Type $P has rdata type $_R, but got $R."))
end

"""
verify_rdata_value(p, r)::Nothing
Check that `r` cannot be proven to be invalid rdata for `p`.
This method attempts to provide some confidence that `r` is valid rdata for `p` by checking
a collection of necessary conditions. We do not guarantee that these amount to a sufficient
condition, just that they rule out a variety of common problems.
Put differently, we cannot prove that `r` is valid rdata, only that it is not obviously
invalid.
"""
function verify_rdata_value(p, r)::Nothing
r isa ZeroRData && return nothing
verify_rdata_type(_typeof(p), typeof(r))
_verify_rdata_value(p, r)
end

_verify_rdata_value(::P, ::P) where {P<:IEEEFloat} = nothing

_verify_rdata_value(::Array, ::NoRData) = nothing

function _verify_rdata_value(p, r)

# If f is a NoFData then there are no checks needed, because we have already verified
# that NoFData is the correct type for fdata for p, and NoFData is a singleton type.
r isa NoRData && return nothing

# When a primitive is encountered here, it means that we don't have a method of
# _verify_rdata_value which is specific to it, and its fdata type is not NoFData.
# The rest of this method assumes p is an instance of a struct type, so we must error.
P = _typeof(p)
isprimitivetype(P) && error("Encountered primitive $p with rdata $r")

# (mutable) structs, Tuples, and NamedTuples all have slightly different storage.
_get_rdata_field(r::NamedTuple, name) = getfield(r, name)
_get_rdata_field(r::Tuple, name) = getfield(r, name)
_get_rdata_field(r::RData, name) = val(getfield(r.data, name))

# Having excluded primitive types, we must have a (mutable) struct type. Recurse into
# its fields and verify each of them.
for name in fieldnames(P)
if isdefined(p, name)
verify_rdata_value(getfield(p, name), _get_rdata_field(r, name))
end
end

return nothing
end


"""
LazyZeroRData{P, Tdata}()
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
_type(x) = x
_type(x::CC.Const) = _typeof(x.val)
_type(x::CC.PartialStruct) = x.typ
_type(x::CC.Conditional) = Union{x.thentype, x.elsetype}
_type(x::CC.Conditional) = Union{_type(x.thentype), _type(x.elsetype)}

function CC.inlining_policy(
interp::TapirInterpreter{C},
Expand Down
6 changes: 4 additions & 2 deletions src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ end

lift_intrinsic(x...) = x
function lift_intrinsic(x::GlobalRef, args...)
val = getglobal(x.mod, x.name)
return val isa Core.IntrinsicFunction ? lift_intrinsic(val, args...) : (x, args...)
return lift_intrinsic(getglobal(x.mod, x.name), args...)
end
function lift_intrinsic(x::Core.IntrinsicFunction, v, args...)
if x === cglobal
Expand All @@ -148,6 +147,9 @@ function lift_intrinsic(x::Core.IntrinsicFunction, v, args...)
return IntrinsicsWrappers.translate(Val(x)), v, args...
end
end
function lift_intrinsic(::typeof(Core._apply_iterate), args...)
return _apply_iterate_equivalent, args...
end

"""
lift_getfield_and_others(inst)
Expand Down
18 changes: 10 additions & 8 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -786,17 +786,20 @@ function build_rrule(
else
fwds_ir = forwards_pass_ir(primal_ir, ad_stmts_blocks, info, _typeof(shared_data))
pb_ir = pullback_ir(primal_ir, Treturn, ad_stmts_blocks, info, _typeof(shared_data))
# @show sig, safety_on

optimised_fwds_ir = optimise_ir!(IRCode(fwds_ir); do_inline=true)
optimised_pb_ir = optimise_ir!(IRCode(pb_ir); do_inline=true)
# @show sig
# @show Treturn
# @show safety_on
# display(ir)
# display(IRCode(fwds_ir))
# display(IRCode(pb_ir))
optimised_fwds_ir = optimise_ir!(IRCode(fwds_ir); do_inline=true)
optimised_pb_ir = optimise_ir!(IRCode(pb_ir); do_inline=true)
# display(optimised_fwds_ir)
# display(optimised_pb_ir)
# @show length(ir.stmts.inst)
# @show length(optimised_fwds_ir.stmts.inst)
# @show length(optimised_pb_ir.stmts.inst)
# display(optimised_fwds_ir)
# display(optimised_pb_ir)
fwds_oc = OpaqueClosure(optimised_fwds_ir, shared_data...; do_compile=true)
pb_oc = OpaqueClosure(optimised_pb_ir, shared_data...; do_compile=true)
interp.oc_cache[(sig, safety_on)] = (fwds_oc, pb_oc)
Expand Down Expand Up @@ -1109,7 +1112,7 @@ __switch_case(id::Int32, predecessor_id::Int32) = !(id === predecessor_id)
@inline __deref_arg_rev_data_refs(arg_rev_data_refs...) = map(getindex, arg_rev_data_refs)

#=
DynamicDerivedRule(interp::TapirInterpreter)
DynamicDerivedRule(interp::TapirInterpreter, safety_on::Bool)
For internal use only.
Expand All @@ -1129,8 +1132,7 @@ function DynamicDerivedRule(interp::TapirInterpreter, safety_on::Bool)
end

function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N}
sig = Tuple{tuple_map(_typeof, tuple_map(primal, args))...}
is_primitive(context_type(dynamic_rule.interp), sig) && return rrule!!(args...)
sig = Tuple{map(_typeof primal, args)...}
rule = get(dynamic_rule.cache, sig, nothing)
if rule === nothing
rule = build_rrule(dynamic_rule.interp, sig; safety_on=dynamic_rule.safety_on)
Expand Down
Loading

2 comments on commit 1ddee6a

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/107531

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.19 -m "<description of version>" 1ddee6a729a6d12ced0c938cb70864196e70d775
git push origin v0.2.19

Please sign in to comment.