Commit 16d8229

Merge branch 'main' into wct/simplevarinfo-fixes

willtebbutt committed May 22, 2024
2 parents ccdf6ec + 0e46680
Showing 13 changed files with 124 additions and 167 deletions.
15 changes: 15 additions & 0 deletions .buildkite/pipeline.yml
@@ -0,0 +1,15 @@
steps:
- label: "Julia v1"
plugins:
- JuliaCI/julia#v1:
version: "1"
- JuliaCI/julia-test#v1: ~
- JuliaCI/julia-coverage#v1:
codecov: true
agents:
queue: "juliagpu"
cuda: "*"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
env:
TEST_GROUP: "gpu"
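The pipeline's only job-specific setting is `TEST_GROUP: "gpu"`. A hypothetical sketch of how a test runner might consume it; the variable name comes from the pipeline above, but the gating code is an illustration, not the repository's actual `test/runtests.jl`:

```julia
# Hypothetical: run the CUDA integration tests only when Buildkite sets TEST_GROUP.
test_group = get(ENV, "TEST_GROUP", "all")
if test_group in ("gpu", "all")
    include("integration_testing/cuda.jl")
end
```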
12 changes: 8 additions & 4 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.15"
version = "0.2.18"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -18,16 +18,19 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
TapirCUDAExt = "CUDA"
TapirLogDensityProblemsADExt = "LogDensityProblemsAD"
TapirSpecialFunctionsExt = "SpecialFunctions"

[compat]
ADTypes = "1.2"
BenchmarkTools = "1"
CUDA = "5"
ChainRulesCore = "1"
DiffRules = "1"
DiffTests = "0.1"
@@ -36,19 +39,20 @@ Documenter = "1"
ExprTools = "0.1"
FillArrays = "1"
Graphs = "1"
JET = "0.8"
JET = "0.9"
LogDensityProblemsAD = "1"
PDMats = "0.11"
Setfield = "1"
SpecialFunctions = "2"
StableRNGs = "1"
TemporalGPs = "0.6"
Turing = "0.31.3"
Turing = "0.32"
julia = "1"

[extras]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
@@ -64,4 +68,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[targets]
test = ["AbstractGPs", "BenchmarkTools", "DiffTests", "Distributions", "Documenter", "FillArrays", "KernelFunctions", "LogDensityProblemsAD", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"]
test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "FillArrays", "KernelFunctions", "LogDensityProblemsAD", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"]
65 changes: 65 additions & 0 deletions ext/TapirCUDAExt.jl
@@ -0,0 +1,65 @@
module TapirCUDAExt

using LinearAlgebra, Random, Tapir

using Base: IEEEFloat
using CUDA: CuArray, cu

import Tapir:
MinimalCtx,
rrule!!,
@is_primitive,
tangent_type,
zero_tangent,
randn_tangent,
increment!!,
set_to_zero!!,
_add_to_primal,
_diff,
_dot,
_scale,
TestUtils,
CoDual,
NoPullback

import Tapir.TestUtils: populate_address_map!, AddressMap, __increment_should_allocate

# Tell Tapir.jl how to handle CuArrays.

tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P
zero_tangent(x::CuArray{<:IEEEFloat}) = zero(x)
function randn_tangent(rng::AbstractRNG, x::CuArray{Float32})
return cu(randn(rng, Float32, size(x)...))
end
TestUtils.has_equal_data(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x == y
increment!!(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x .+= y
__increment_should_allocate(::Type{<:CuArray{<:IEEEFloat}}) = true
set_to_zero!!(x::CuArray{<:IEEEFloat}) = x .= 0
_add_to_primal(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x + y
_diff(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = x - y
_dot(x::P, y::P) where {P<:CuArray{<:IEEEFloat}} = Float64(dot(x, y))
_scale(x::Float64, y::P) where {T<:IEEEFloat, P<:CuArray{T}} = T(x) * y
function populate_address_map!(m::AddressMap, p::CuArray, t::CuArray)
k = pointer_from_objref(p)
v = pointer_from_objref(t)
haskey(m, k) && (@assert m[k] == v)
m[k] = v
return m
end

# Basic rules for operating on CuArrays.

@is_primitive(
MinimalCtx,
Tuple{Type{<:CuArray}, UndefInitializer, Vararg{Int, N}} where {N},
)
function rrule!!(
p::CoDual{Type{P}},
init::CoDual{UndefInitializer},
dims::CoDual{Int}...
) where {P<:CuArray{<:Base.IEEEFloat}}
_dims = map(primal, dims)
y = CoDual(P(undef, _dims), P(undef, _dims))
return y, NoPullback(p, init, dims...)
end
end
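A sketch of what the new extension enables, assuming a CUDA-capable device is available; `tangent_type` and `zero_tangent` are the Tapir internals imported above:

```julia
using CUDA, Tapir

x = CUDA.rand(Float32, 4)
Tapir.tangent_type(typeof(x))  # the tangent type of a CuArray is the CuArray type itself
t = Tapir.zero_tangent(x)      # a zero CuArray with the same shape as x
```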
17 changes: 8 additions & 9 deletions src/codual.jl
@@ -23,15 +23,6 @@ Equivalent to `CoDual(x, zero_tangent(x))`.
"""
zero_codual(x) = CoDual(x, zero_tangent(x))

"""
uninit_codual(x)
See implementation for details, as this function is subject to change.
"""
@inline uninit_codual(x::P) where {P} = CoDual(x, uninit_tangent(x))

@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x))

"""
codual_type(P::Type)
@@ -75,6 +66,14 @@ to_fwds(x::CoDual{Type{P}}) where {P} = CoDual{Type{P}, NoFData}(primal(x), NoFData())

zero_fcodual(p) = to_fwds(zero_codual(p))

"""
uninit_fcodual(x)
Like `zero_fcodual`, but doesn't guarantee that the value of the fdata is initialised.
See implementation for details, as this function is subject to change.
"""
@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x))

"""
fcodual_type(P::Type)
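The relocated `uninit_fcodual` now sits beside `zero_fcodual`, which it mirrors. A sketch of the difference, assuming Tapir's convention that a mutable array carries its tangent as its fdata (these are internal helpers and, per the docstring, subject to change):

```julia
c = Tapir.zero_fcodual([1.0, 2.0])    # fdata is a zeroed Vector{Float64}
u = Tapir.uninit_fcodual([1.0, 2.0])  # fdata is allocated but not guaranteed to be zeroed
Tapir.primal(c) == Tapir.primal(u)    # true: both carry the same primal value
```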
61 changes: 2 additions & 59 deletions src/interpreter/ir_utils.jl
@@ -170,56 +170,6 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true)
return ir
end

"""
replace_all_uses_with!(ir::IRCode, value::SSAValue, new_value::Any) -> IRCode
Wherever `value` appears in `ir`, replace it with `new_value`.
Note: this will *not* affect anything in the `new_nodes` field of `ir`.
"""
function replace_all_uses_with!(ir::IRCode, value::SSAValue, new_value::Any)
insts = ir.stmts.inst
for (n, inst) in enumerate(insts)
insts[n] = replace_uses_with(inst, value, new_value)
end
return ir
end

#=
replace_uses_with(x::Any, value::SSAValue, new_value::Any)
Replace all occurrences of `value` in the IR node `x` with `new_value`. The semantics of this
are node-dependent.
If `value` appears by itself as a constant, it will not be replaced.
=#
replace_uses_with(x::Any, ::SSAValue, _) = x # constants
function replace_uses_with(x::Expr, v::SSAValue, new_v)
return Expr(x.head, [_replace(v, new_v, a) for a in x.args]...)
end
replace_uses_with(x::GotoNode, ::SSAValue, _) = x
function replace_uses_with(x::GotoIfNot, v::SSAValue, new_v)
return GotoIfNot(_replace(v, new_v, x.cond), x.dest)
end
function replace_uses_with(x::PhiNode, v::SSAValue, new_v)
values = x.values
new_values = Vector{Any}(undef, length(values))
for n in eachindex(new_values)
if isassigned(values, n)
new_values[n] = _replace(v, new_v, values[n])
end
end
return PhiNode(x.edges, new_values)
end
replace_uses_with(x::PiNode, v::SSAValue, new_v) = PiNode(_replace(v, new_v, x.val), x.typ)
replace_uses_with(x::QuoteNode, ::SSAValue, _) = x
function replace_uses_with(x::ReturnNode, v::SSAValue, new_v)
return isdefined(x, :val) ? ReturnNode(_replace(v, new_v, x.val)) : x
end

# Return new_value if val equals current_val.
_replace(val::SSAValue, new_val, current_val) = val == current_val ? new_val : current_val

"""
lookup_ir(interp::AbstractInterpreter, sig::Type{<:Tuple})::Tuple{IRCode, T}
@@ -239,7 +189,7 @@ function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple})
end

"""
is_reachable(x::ReturnNode)
is_reachable_return_node(x::ReturnNode)
Determine whether `x` is a `ReturnNode` and, if so, whether it is reachable. This is
purely a function of whether its `val` field is defined.
@@ -248,21 +198,14 @@ is_reachable_return_node(x::ReturnNode) = isdefined(x, :val)
is_reachable_return_node(x) = false

"""
is_unreachable(x::ReturnNode)
is_unreachable_return_node(x::ReturnNode)
Determine whether `x` is a `ReturnNode` and, if so, whether it is unreachable. This is
purely a function of whether its `val` field is defined.
"""
is_unreachable_return_node(x::ReturnNode) = !isdefined(x, :val)
is_unreachable_return_node(x) = false

"""
globalref_type(x::GlobalRef)
Returns the static type of the value referred to by `x`.
"""
globalref_type(x::GlobalRef) = isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty

"""
UnhandledLanguageFeatureException(message::String)
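A small illustration of the reachability helpers above (a sketch; these functions are internal to Tapir):

```julia
using Core: ReturnNode

Tapir.is_reachable_return_node(ReturnNode(42))  # true: returns the constant 42
Tapir.is_unreachable_return_node(ReturnNode())  # true: `val` is left undefined
Tapir.is_reachable_return_node(:(x + y))        # false: not a ReturnNode at all
```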
41 changes: 1 addition & 40 deletions src/stack.jl
@@ -11,25 +11,16 @@ mutable struct Stack{T}
Stack{T}() where {T} = new{T}(Vector{T}(undef, 0), 0)
end

function Stack{T}(x) where {T}
stack = Stack{T}()
push!(stack, x)
return stack
end

Stack(x::T) where {T} = Stack{T}(x)

@inline function Base.push!(x::Stack{T}, val::T) where {T}
position = x.position + 1
memory = x.memory
x.position = position
if position <= length(memory)
@inbounds memory[position] = val
return nothing
else
@noinline push!(memory, val)
return nothing
end
return nothing
end

@inline function Base.pop!(x::Stack)
@@ -39,37 +30,7 @@ end
return val
end

Base.isempty(x::Stack) = x.position == 0

Base.length(x::Stack) = x.position

"""
Base.getindex(x::Stack)
Return the value at the top of `x` without popping it.
"""
Base.getindex(x::Stack) = x.memory[x.position]

"""
Base.setindex!(x::Stack, v)
Set the value of the element at the top of the `x` to `v`.
"""
function Base.setindex!(x::Stack, v)
x.memory[x.position] = v
return v
end

Base.eltype(::Stack{T}) where {T} = T


struct SingletonStack{T} end

Base.push!(::SingletonStack, ::Any) = nothing
@generated Base.pop!(::SingletonStack{T}) where {T} = T.instance


function reverse_data_ref_type(::Type{P}) where {P}
P === DataType && return Ref{Any}
return Base.RefValue{rdata_type(tangent_type(P))}
end
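A sketch of the LIFO behaviour provided by the retained `push!`/`pop!` methods (`Stack` is internal; it is backed by a growable `memory` vector and a `position` cursor):

```julia
s = Tapir.Stack{Int}()
push!(s, 1)
push!(s, 2)
pop!(s)  # 2 (last in, first out)
pop!(s)  # 1; `position` returns to 0, but `memory` keeps its capacity
```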
2 changes: 0 additions & 2 deletions src/tangents.jl
@@ -47,8 +47,6 @@ _wrap_type(::Type{T}) where {T} = PossiblyUninitTangent{T}
_wrap_field(::Type{Q}, x::T) where {Q, T} = PossiblyUninitTangent{Q}(x)
_wrap_field(x::T) where {T} = _wrap_field(T, x)

Base.eltype(::Type{PossiblyUninitTangent{T}}) where {T} = T

struct Tangent{Tfields<:NamedTuple}
fields::Tfields
end
4 changes: 2 additions & 2 deletions src/test_utils.jl
@@ -815,9 +815,9 @@ function test_tangent(
perf && test_tangent_performance(rng, p)
end

function test_tangent(rng::AbstractRNG, p::P; interface_only=false) where {P}
function test_tangent(rng::AbstractRNG, p::P; interface_only=false, perf=true) where {P}
test_tangent_consistency(rng, p; interface_only)
test_tangent_performance(rng, p)
perf && test_tangent_performance(rng, p)
end

function test_equality_comparison(x)
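The new `perf` keyword lets callers skip the performance half of `test_tangent`, which is useful for tangent types, such as the CuArrays above, whose `increment!!` legitimately allocates. A sketch:

```julia
using Random

# Run only the consistency checks, skipping test_tangent_performance.
Tapir.TestUtils.test_tangent(Xoshiro(123456), randn(4); interface_only=false, perf=false)
```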
1 change: 0 additions & 1 deletion test/codual.jl
@@ -2,7 +2,6 @@
@test CoDual(5.0, 4.0) isa CoDual{Float64, Float64}
@test CoDual(Float64, NoTangent()) isa CoDual{Type{Float64}, NoTangent}
@test zero_codual(5.0) == CoDual(5.0, 0.0)
@test Tapir.uninit_codual(5.0) == CoDual(5.0, 0.0)
@test codual_type(Float64) == CoDual{Float64, Float64}
@test codual_type(Int) == CoDual{Int, NoTangent}
@test codual_type(Real) == CoDual
18 changes: 18 additions & 0 deletions test/integration_testing/cuda.jl
@@ -0,0 +1,18 @@
using CUDA

@testset "cuda" begin

# Check we can operate on CuArrays.
test_tangent(
Xoshiro(123456),
CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}(undef, 8, 8);
interface_only=false,
)

# Check we can instantiate a CuArray.
interp = Tapir.TapirInterpreter()
TestUtils.test_derived_rule(
sr(123456), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, undef, 256;
interp, perf_flag=:none, interface_only=true, is_primitive=true,
)
end
