Wct/refactor api and ad implementation (#121)
* Separate forwards- and reverse-data

* Move low level maths rules over to new system

* Extend tuple_map to handle named tuples

* Get most builtins working

* Fix low_level_maths rules

* Excise register-related code

* Excise more register-related code

* Include interpreter code for now

* Single block code works + remove redundant stacks

* Add function to get phi nodes from bbcode

* Add online type checking functionality

* Improve safety error messages

* Remove redundant stack code

* Fix phi_nodes bbcode function

* Safety checks in test utils

* Add reflection to utils

* Some work

* Tweaks

* Improve cos and sin rules

* Incorporate test from slack

* Move generic functionality from tangents to utils and document

* Add tests for code moved to utils from tangents

* Remove redundant function

* Don't run performance tests in s2s tests

* Remove redundant code

* Remove redundant code

* Remove commented out code

* Remove redundant code

* Reorganise tangents file

* Move generic functionality to utils from tangents

* Move generic functionality to utils from tangents

* Start tidying up tangent test cases

* Simplify tangent testing

* Test perf for all tangents and fix perf bug

* Unify all type testing

* Remove redundant test info

* Remove redundant alias

* Fix comment

* Enable more correctness tests

* Tidy up test_utils further

* Add more tuple tests

* Add more tuple test cases

* Run fwds_rvs_data tests on all types

* Improve fwds and rvs implementation and get all related tests passing

* Simplify function names

* Rename zero_reverse_data_from_type to zero_rdata_from_type

* Rename zero_reverse_data to zero_rdata

* Remove redundant code

* Get NamedTuples working

* Add tests for structs in new

* Remove arbitrary limit on number of arguments to _new_

* Formatting

* Formatting

* Get all existing _new_ tests passing

* Newline at end of file

* Add remainder of standard tangent test cases to _new_

* Add triangular test cases to new

* Remove unused code

* Enable more integration tests

* Check number of cotangents in safe mode

* Tweak error message

* Rename safety to safe_mode

* Improve comment

* Rename safety to safe_mode in runtests

* Fix bug in new for partially initialised structs

* Require same lengths of args to tuple_map

* Improve docstring for tuple_map

* Improve docstring

* Fix vararg bug

* Fix typo in cglobal rule

* Split out test types construction

* Fix edge case for _new_ rrule

* Improve some builtin rules slightly

* Use alias for block stack

* Activate more integration tests

* All existing lgetfield tests pass

* Support order keyword

* Fix up lsetfield

* Activate more integration tests

* Add more test cases to lgetfield

* Tidy up lgetfield rrule implementation

* Reenable getfield and setfield rules

* Fix avoiding_non_diff_code rules

* Enable foreigncall stuff

* Enable more tests and fix increment perf

* Improve arrayref and set -- fix pointerref and set

* Sort out more things

* Enable all foreigncall tests

* Fix safe_mode recursion

* Fix safe mode compilation times

* Refactor combine_data to binary tangent and tangent_type

* Fix increment inference bug

* Fix deprecations in benchmarks script

* combine_data to tangent and tangent_type

* Rename zero_rdata_from_type to zero_like_rdata_from_type

* Fix dynamic rrule safe mode

* Move zero_like_rdata to its own files

* Fix zerordata and friends

* Include IdDict tests

* Move tuple_fill to utils

* Fix gemv

* Update chainrules macro

* Update trmv

* Update remainder of blas rules

* Update lapack rules

* Add additional builtins test

* Add NoPullback rule to misc

* Add additional options to test utils

* Make DimensionMismatch tangent NoTangent

* Turn safety off for misc tests

* Minor codegen performance tweaks

* Reactivate remainder of tests

* Fix abstract types

* Bump patch

* Tweak contributor list

* Try running the GC before benchmarking

* Update performance bounds

* Tweak plotting range

* Rename functions

* Tidy up safe mode implementation

* Improve safe mode

* Improve safe mode error messages

* Fix rrule for arrayset

* Fix pointer fdata type bug

* Fix safe_mode tests

* Add increment_rdata functionality

* Simplify arrayref and pointerref implementations

* Simplify lgetfield rrule implementation

* Use lazy rdata in getfield

* Simplify lsetfield implementation

* Remove some specialisation to reduce compile times

* Improve codual

* Fix tangent_type for Tuple

* Fix up iddict tests

* Remove GC calls

* Loosen performance bounds on some type unstable tests

* Hopefully fix allocations in CI

* Fix typo

* Remove intrinsic test case

* Remove redundant code

* Fix performance of rdata creation

* Remove redundant code and tidy up variable naming

* Fix performance of zero_rdata_from_type

* Add regression tests

* Fix performance bug

* Remove allocation tests which don't make sense

* Fix randn_tangent perf and improve test reliability

* Add pprof code to turing integration tests

* Force inline rules for getfield and setfield

* Increased stability for non-literal getfield and setfield

* Add signature to type of LazyDerivedRule

* Add unique predecessor computation

* Reduce block usage

* Revert flawed performance fix

* Enable all tests

* Fix performance regression

* Improve arrayref and arrayset performance

* Add erfcx rule to special functions ext

* Specialised getfield rules for nondifferentiable stuff

* Make eltypes non-differentiable

* Test type with LazyRZero

* Improve pullback codegen for unreachable blocks

* Fix test cases

* Tangent type for Method is NoTangent

* Revert Turing performance test again

* Add more lsetfield tests

* More tests for lgetfield and bug fix

* Add comment noting choice of code layout

* Optimise safe mode compile times

* Import more things during testing

* Make NoPullback use lazy rdata

* Support NoRData in instantiate

* Preserve NoPullback in RRuleZeroWrapper

* Update NoPullback uses to use new version

* Reduce problem sizes to make CI happier

* Bump Turing compat

* Add fast increment_field for homogeneous Tuple types

* Optimise for homogeneous tuples

* Optimise homogeneous named tuples increment_field

* Optimise getfield for homogeneously-typed NamedTuples

* Reduce test problem sizes further to reduce CI burden

* Add TemporalGPs integration test

* Force-inline more things

* Ensure to_benchmark is compiled before running the profiler

* Check that getfield works with tuple of types

* Fix test utils for tuple of types

* Add functionality to determine if a node is used

* Do not AD getfield when not used

* Optimise safe mode

* Add getfield regression integration test

* Force-inline forwards-pass IR for calls and invokes

* Add additional small-union test

* Optimise for un-used ssa nodes and things with NoPullback pullbacks

* Improve documentation for ADInfo

* Document BlockStack const

* Improve documentation for ADInfo outer constructors

* Improve documentation of misc functions in reverse mode ad transformations

* Tidy up RRuleZeroWrapper and ReturnNode

* Tidy up gotoifnot implementation

* Remove commented-out line of code

* Tidy up reverse-mode code further

* Explain special handling for unused getfield calls carefully

* Add directions to bbcode file

* More documentation for reverse-mode

* Improve transformation documentation further

* Improve fwds_rvs documentation

* More informative name than __convert

* Improve comment

* Rename fwds_codual_type to fcodual_type

* Improve unique pred characterisation documentation
willtebbutt authored Apr 29, 2024
1 parent 992a862 commit 965b98b
Showing 47 changed files with 3,689 additions and 3,720 deletions.
10 changes: 6 additions & 4 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt and contributors"]
version = "0.1.2"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -36,7 +36,8 @@ PDMats = "0.11"
Setfield = "1"
SpecialFunctions = "2"
StableRNGs = "1"
Turing = "0.29"
TemporalGPs = "0.6"
Turing = "0.31"
julia = "1"

[extras]
@@ -50,8 +51,9 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[targets]
test = ["AbstractGPs", "BenchmarkTools", "DiffTests", "Distributions", "FillArrays", "KernelFunctions", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing"]
test = ["AbstractGPs", "BenchmarkTools", "DiffTests", "Distributions", "FillArrays", "KernelFunctions", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"]
7 changes: 3 additions & 4 deletions bench/run_benchmarks.jl
@@ -22,7 +22,6 @@ using Tapir:
CoDual,
generate_hand_written_rrule!!_test_cases,
generate_derived_rrule!!_test_cases,
InterpretedFunction,
TestUtils,
PInterp,
_typeof
@@ -100,8 +99,8 @@ function _generate_gp_inputs()
end

@model broadcast_demo(x) = begin
μ ~ TruncatedNormal(1, 2, 0.1, 10)
σ ~ TruncatedNormal(1, 2, 0.1, 10)
μ ~ truncated(Normal(1, 2), 0.1, 10)
σ ~ truncated(Normal(1, 2), 0.1, 10)
x .~ LogNormal(μ, σ)
end

@@ -295,7 +294,7 @@ Constructs a histogram of the `tapir_ratio` field of `df`, with formatting that
well-suited to the numbers typically found in this field.
"""
function plot_ratio_histogram!(df::DataFrame)
bin = 10.0 .^ (0.0:0.05:6.0)
bin = 10.0 .^ (-1.0:0.05:4.0)
xlim = extrema(bin)
histogram(df.Tapir; xscale=:log10, xlim, bin, title="log", label="")
end
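
For reference, the updated `bin` expression yields logarithmically spaced edges, so each bin has equal width on the `:log10` x-axis. The following is a minimal sketch in plain Julia that only inspects those edges; it does not call any plotting package, and the names `lo` and `hi` are introduced here for illustration.

```julia
# Log-spaced histogram bin edges, matching the updated expression in the diff.
bin = 10.0 .^ (-1.0:0.05:4.0)

# The exponent range has 101 points, so there are 101 edges (100 bins),
# covering ratios from roughly 0.1 up to roughly 10_000.
lo, hi = extrema(bin)
println(length(bin), " edges spanning [", lo, ", ", hi, "]")
```

Compared with the previous range `0.0:0.05:6.0`, the lower limit now admits ratios below 1 and the upper limit drops from about 10^6 to about 10^4.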
1 change: 1 addition & 0 deletions ext/TapirSpecialFunctionsExt.jl
@@ -7,4 +7,5 @@ module TapirSpecialFunctionsExt
@from_rrule DefaultCtx Tuple{typeof(airyai), Float64}
@from_rrule DefaultCtx Tuple{typeof(airyaix), Float64}
@from_rrule DefaultCtx Tuple{typeof(erfc), Float64}
@from_rrule DefaultCtx Tuple{typeof(erfcx), Float64}
end
7 changes: 4 additions & 3 deletions src/Tapir.jl
@@ -28,20 +28,21 @@ using LinearAlgebra.LAPACK: getrf!, getrs!, getri!, trtrs!, potrf!, potrs!

# Needs to be defined before various other things.
function _foreigncall_ end
function rrule!! end

include("utils.jl")
include("tangents.jl")
include("fwds_rvs_data.jl")
include("codual.jl")
include("safe_mode.jl")
include("stack.jl")

include(joinpath("interpreter", "contexts.jl"))
include(joinpath("interpreter", "abstract_interpretation.jl"))
include(joinpath("interpreter", "bbcode.jl"))
include(joinpath("interpreter", "ir_utils.jl"))
include(joinpath("interpreter", "ir_normalisation.jl"))
include(joinpath("interpreter", "registers.jl"))
include(joinpath("interpreter", "interpreted_function.jl"))
include(joinpath("interpreter", "reverse_mode_ad.jl"))
include(joinpath("interpreter", "zero_like_rdata.jl"))
include(joinpath("interpreter", "s2s_reverse_mode_ad.jl"))

include("test_utils.jl")
16 changes: 5 additions & 11 deletions src/chain_rules_macro.jl
@@ -1,5 +1,5 @@
__increment_shim!!(::NoTangent, ::ChainRulesCore.NoTangent) = NoTangent()
__increment_shim!!(x, y) = increment!!(x, y)
_to_rdata(::ChainRulesCore.NoTangent) = NoRData()
_to_rdata(dx::Float64) = dx

"""
@from_rrule ctx sig
@@ -38,21 +38,15 @@ macro from_rrule(ctx, sig)
map(n -> :(Tapir.primal($n)), arg_names)...,
)

pb_arg_names = map(n -> Symbol("dx_$(n)"), eachindex(arg_names))
pb_output_names = map(n -> Symbol("dx_$(n)_inc"), eachindex(arg_names))

call_pb = Expr(:(=), Expr(:tuple, pb_output_names...), :(pb(dy)))
incrementers = Expr(
:tuple,
map(pb_arg_names, pb_output_names) do a, b
:(Tapir.__increment_shim!!($a, $b))
end...,
)
incrementers = Expr(:tuple, map(b -> :(Tapir._to_rdata($b)), pb_output_names)...)

pb = ExprTools.combinedef(Dict(
:head => :function,
:name => :pb!!,
:args => [:dy, pb_arg_names...],
:args => [:dy],
:body => quote
$call_pb
return $incrementers
@@ -67,7 +61,7 @@ macro from_rrule(ctx, sig)
:body => quote
y, pb = $call_rrule
$pb
return Tapir.zero_codual(y), pb!!
return Tapir.zero_fcodual(y), pb!!
end,
)
)
56 changes: 51 additions & 5 deletions src/codual.jl
@@ -3,7 +3,11 @@ struct CoDual{Tx, Tdx}
dx::Tdx
end

# Always sharpen the first thing if it's a type, in order to preserve dispatch possibility.
# Always sharpen the first thing if it's a type so static dispatch remains possible.
function CoDual(x::Type{P}, dx::NoFData) where {P}
return CoDual{@isdefined(P) ? Type{P} : typeof(x), NoFData}(P, dx)
end

function CoDual(x::Type{P}, dx::NoTangent) where {P}
return CoDual{@isdefined(P) ? Type{P} : typeof(x), NoTangent}(P, dx)
end
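
The comment about sharpening can be illustrated without Tapir. In the sketch below, `SketchPair` and `sharpened` are hypothetical stand-ins (they are not part of Tapir), and the `@isdefined(P)` guard from the diff is deliberately omitted. It shows that the default constructor records a type-valued field as `DataType`, whereas an explicit `Type{P}` parameter keeps the specific type visible to dispatch.

```julia
# Hypothetical stand-in for a CoDual-like pair; NOT Tapir's definition.
struct SketchPair{Tx, Tdx}
    x::Tx
    dx::Tdx
end

# Sharpening constructor: when the primal is itself a type, store it as
# `Type{P}` rather than `typeof(P)` (which would be `DataType`).
sharpened(x::Type{P}, dx) where {P} = SketchPair{Type{P}, typeof(dx)}(P, dx)

# Dispatch on the sharpened parameter can recover the wrapped type statically.
wrapped_type(::SketchPair{Type{P}}) where {P} = P

plain = SketchPair(Float64, nothing)   # first parameter is DataType
sharp = sharpened(Float64, nothing)    # first parameter is Type{Float64}
println(typeof(plain))
println(typeof(sharp))
```

With the unsharpened pair, a method such as `wrapped_type` has no way to know which type is stored from the pair's type alone, which is presumably the dispatch possibility the comment refers to.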
@@ -26,21 +30,63 @@ See implementation for details, as this function is subject to change.
"""
@inline uninit_codual(x::P) where {P} = CoDual(x, uninit_tangent(x))

@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x))

"""
codual_type(P::Type)
Shorthand for `CoDual{P, tangent_type(P}}` when `P` is concrete, equal to `CoDual` if not.
The type of the `CoDual` which contains instances of `P` and associated tangents.
"""
function codual_type(::Type{P}) where {P}
P == DataType && return CoDual
P isa Union && return Union{codual_type(P.a), codual_type(P.b)}
P <: UnionAll && return CoDual
return isconcretetype(P) ? CoDual{P, tangent_type(P)} : CoDual
end

codual_type(::Type{Type{P}}) where {P} = CoDual{Type{P}, NoTangent}

struct NoPullback end
struct NoPullback{R<:Tuple}
r::R
end

"""
NoPullback(args::CoDual...)
Construct a `NoPullback` from the arguments passed to an `rrule!!`. For each argument,
extracts the primal value, and constructs a `LazyZeroRData`. These are stored in a
`NoPullback` which, in the reverse-pass of AD, instantiates these `LazyZeroRData`s and
returns them in order to perform the reverse-pass of AD.
The advantage of this approach is that if it is possible to construct the zero rdata element
for each of the arguments lazily, the `NoPullback` generated will be a singleton type. This
means that AD can avoid generating a stack to store this pullback, which can result in
significant performance improvements.
"""
function NoPullback(args::Vararg{CoDual, N}) where {N}
return NoPullback(tuple_map(LazyZeroRData ∘ primal, args))
end

@inline (pb::NoPullback)(_) = tuple_map(instantiate, pb.r)

to_fwds(x::CoDual) = CoDual(primal(x), fdata(tangent(x)))

to_fwds(x::CoDual{Type{P}}) where {P} = CoDual{Type{P}, NoFData}(primal(x), NoFData())

zero_fcodual(p) = to_fwds(zero_codual(p))

"""
fcodual_type(P::Type)
The type of the `CoDual` which contains instances of `P` and its fdata.
"""
function fcodual_type(::Type{P}) where {P}
P == DataType && return CoDual
P isa Union && return Union{fcodual_type(P.a), fcodual_type(P.b)}
P <: UnionAll && return CoDual
return isconcretetype(P) ? CoDual{P, fdata_type(tangent_type(P))} : CoDual
end

@inline (::NoPullback)(dy, dx...) = dx
fcodual_type(::Type{Type{P}}) where {P} = CoDual{Type{P}, NoFData}

might_be_active(args) = any(might_be_active ∘ _typeof, args)
zero_rdata(x::CoDual) = zero_rdata(primal(x))
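
The `NoPullback` docstring above argues that storing lazily constructed zeros lets the pullback be a singleton type. The following is a rough, self-contained sketch of that idea in plain Julia. All names (`SketchLazyZero`, `sketch_instantiate`, `SketchNoPullback`, `sketch_no_pullback`) are hypothetical stand-ins and do not reproduce Tapir's `LazyZeroRData`, `instantiate`, or `NoPullback`. It also assumes the zero can be recovered from the primal type alone via `zero(P)`, which will not hold for every type.

```julia
# Hypothetical lazy zero: no fields, so the zero is recoverable from `P` alone.
struct SketchLazyZero{P} end
SketchLazyZero(::P) where {P} = SketchLazyZero{P}()
sketch_instantiate(::SketchLazyZero{P}) where {P} = zero(P)

# Hypothetical pullback holding one lazy zero per argument.
struct SketchNoPullback{R<:Tuple}
    r::R
end

# Build the pullback from primal values (Tapir's version takes CoDuals instead).
sketch_no_pullback(args...) = SketchNoPullback(map(SketchLazyZero, args))

# Reverse pass: ignore the incoming cotangent and materialise the zeros.
(pb::SketchNoPullback)(_) = map(sketch_instantiate, pb.r)

pb = sketch_no_pullback(1.0, 2.0f0)
println(Base.issingletontype(typeof(pb)))  # no data is stored in `pb`
println(pb(1.0))
```

Because every field is itself a zero-size singleton, the pullback carries no runtime data, which is the property the docstring says lets AD avoid allocating a stack for it.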

2 comments on commit 965b98b

@willtebbutt

@JuliaRegistrator register()

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/105808

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 965b98bdafbbc82de94398fb5f66d7ff51996679
git push origin v0.2.0
