Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Housekeeping #165

Merged
merged 5 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.17"
version = "0.2.18"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
17 changes: 8 additions & 9 deletions src/codual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@ Equivalent to `CoDual(x, zero_tangent(x))`.
"""
zero_codual(x) = CoDual(x, zero_tangent(x))

"""
uninit_codual(x)

See implementation for details, as this function is subject to change.
"""
@inline uninit_codual(x::P) where {P} = CoDual(x, uninit_tangent(x))

@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x))

"""
codual_type(P::Type)

Expand Down Expand Up @@ -75,6 +66,14 @@ to_fwds(x::CoDual{Type{P}}) where {P} = CoDual{Type{P}, NoFData}(primal(x), NoFD

zero_fcodual(p) = to_fwds(zero_codual(p))

"""
uninit_fcodual(x)

Like `zero_fcodual`, but doesn't guarantee that the value of the fdata is initialised.
See implementation for details, as this function is subject to change.
"""
@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x))

"""
fcodual_type(P::Type)

Expand Down
61 changes: 2 additions & 59 deletions src/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,56 +170,6 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true)
return ir
end

"""
replace_all_uses_with!(ir::IRCode, value::SSAValue, new_value::Any) -> IRCode

Wherever `value` appears in `ir`, replace it with `new_value`.

Note: this will *not* effect anything in the `new_nodes` field of `ir`.
"""
function replace_all_uses_with!(ir::IRCode, value::SSAValue, new_value::Any)
insts = ir.stmts.inst
for (n, inst) in enumerate(insts)
insts[n] = replace_uses_with(inst, value, new_value)
end
return ir
end

#=
replace_uses_with(x::Any, value::SSAValue, new_value::Any)

Replace all occurences of `value` in the IR node `x` with `new_value`. The semantics of this
are node-dependent.

If `value` appears by itself as a constant, it will not be replaced.
=#
replace_uses_with(x::Any, ::SSAValue, _) = x # constants
function replace_uses_with(x::Expr, v::SSAValue, new_v)
return Expr(x.head, [_replace(v, new_v, a) for a in x.args]...)
end
replace_uses_with(x::GotoNode, ::SSAValue, _) = x
function replace_uses_with(x::GotoIfNot, v::SSAValue, new_v)
return GotoIfNot(_replace(v, new_v, x.cond), x.dest)
end
function replace_uses_with(x::PhiNode, v::SSAValue, new_v)
values = x.values
new_values = Vector{Any}(undef, length(values))
for n in eachindex(new_values)
if isassigned(values, n)
new_values[n] = _replace(v, new_v, values[n])
end
end
return PhiNode(x.edges, new_values)
end
replace_uses_with(x::PiNode, v::SSAValue, new_v) = PiNode(_replace(v, new_v, x.val), x.typ)
replace_uses_with(x::QuoteNode, ::SSAValue, _) = x
function replace_uses_with(x::ReturnNode, v::SSAValue, new_v)
return isdefined(x, :val) ? ReturnNode(_replace(v, new_v, x.val)) : x
end

# Return new_value if val equals current_val.
_replace(val::SSAValue, new_val, current_val) = val == current_val ? new_val : current_val

"""
lookup_ir(interp::AbstractInterpreter, sig::Type{<:Tuple})::Tuple{IRCode, T}

Expand All @@ -239,7 +189,7 @@ function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple})
end

"""
is_reachable(x::ReturnNode)
is_reachable_return_node(x::ReturnNode)

Determine whether `x` is a `ReturnNode`, and if it is, if it is also reachable. This is
purely a function of whether or not its `val` field is defined or not.
Expand All @@ -248,21 +198,14 @@ is_reachable_return_node(x::ReturnNode) = isdefined(x, :val)
is_reachable_return_node(x) = false

"""
is_unreachable(x::ReturnNode)
is_unreachable_return_node(x::ReturnNode)

Determine whehter `x` is a `ReturnNode`, and if it is, if it is also unreachable. This is
purely a function of whether or not its `val` field is defined or not.
"""
is_unreachable_return_node(x::ReturnNode) = !isdefined(x, :val)
is_unreachable_return_node(x) = false

"""
globalref_type(x::GlobaRef)

Returns the static type of the value referred to by `x`.
"""
globalref_type(x::GlobalRef) = isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty

"""
UnhandledLanguageFeatureException(message::String)

Expand Down
41 changes: 1 addition & 40 deletions src/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,16 @@ mutable struct Stack{T}
Stack{T}() where {T} = new{T}(Vector{T}(undef, 0), 0)
end

function Stack{T}(x) where {T}
stack = Stack{T}()
push!(stack, x)
return stack
end

Stack(x::T) where {T} = Stack{T}(x)

@inline function Base.push!(x::Stack{T}, val::T) where {T}
position = x.position + 1
memory = x.memory
x.position = position
if position <= length(memory)
@inbounds memory[position] = val
return nothing
else
@noinline push!(memory, val)
return nothing
end
return nothing
end

@inline function Base.pop!(x::Stack)
Expand All @@ -39,37 +30,7 @@ end
return val
end

Base.isempty(x::Stack) = x.position == 0

Base.length(x::Stack) = x.position

"""
Base.getindex(x::Stack)

Return the value at the top of `x` without popping it.
"""
Base.getindex(x::Stack) = x.memory[x.position]

"""
Base.setindex!(x::Stack, v)

Set the value of the element at the top of the `x` to `v`.
"""
function Base.setindex!(x::Stack, v)
x.memory[x.position] = v
return v
end

Base.eltype(::Stack{T}) where {T} = T


struct SingletonStack{T} end

Base.push!(::SingletonStack, ::Any) = nothing
@generated Base.pop!(::SingletonStack{T}) where {T} = T.instance


function reverse_data_ref_type(::Type{P}) where {P}
P === DataType && return Ref{Any}
return Base.RefValue{rdata_type(tangent_type(P))}
end
2 changes: 0 additions & 2 deletions src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ _wrap_type(::Type{T}) where {T} = PossiblyUninitTangent{T}
_wrap_field(::Type{Q}, x::T) where {Q, T} = PossiblyUninitTangent{Q}(x)
_wrap_field(x::T) where {T} = _wrap_field(T, x)

Base.eltype(::Type{PossiblyUninitTangent{T}}) where {T} = T

struct Tangent{Tfields<:NamedTuple}
fields::Tfields
end
Expand Down
1 change: 0 additions & 1 deletion test/codual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
@test CoDual(5.0, 4.0) isa CoDual{Float64, Float64}
@test CoDual(Float64, NoTangent()) isa CoDual{Type{Float64}, NoTangent}
@test zero_codual(5.0) == CoDual(5.0, 0.0)
@test Tapir.uninit_codual(5.0) == CoDual(5.0, 0.0)
@test codual_type(Float64) == CoDual{Float64, Float64}
@test codual_type(Int) == CoDual{Int, NoTangent}
@test codual_type(Real) == CoDual
Expand Down
43 changes: 0 additions & 43 deletions test/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,49 +42,6 @@ end
# Check that the ir is runable.
@test Core.OpaqueClosure(ir)(5.0) == cos(sin(5.0))
end
@testset "replace_all_uses_with!" begin

# `replace_all_uses_with!` is just a lightweight wrapper around `replace_uses_with`,
# so we just test that carefully.
@testset "replace_uses_with $val" for (val, target) in Any[
(5.0, 5.0),
(5, 5),
(Expr(:call, sin, SSAValue(1)), Expr(:call, sin, SSAValue(2))),
(Expr(:call, sin, SSAValue(3)), Expr(:call, sin, SSAValue(3))),
(GotoNode(1), GotoNode(1)),
(GotoIfNot(false, 5), GotoIfNot(false, 5)),
(GotoIfNot(SSAValue(1), 3), GotoIfNot(SSAValue(2), 3)),
(GotoIfNot(SSAValue(3), 3), GotoIfNot(SSAValue(3), 3)),
(
PhiNode(Int32[1, 2, 3], Any[5, SSAValue(1), SSAValue(3)]),
PhiNode(Int32[1, 2, 3], Any[5, SSAValue(2), SSAValue(3)]),
),
(PiNode(SSAValue(1), Float64), PiNode(SSAValue(2), Float64)),
(PiNode(SSAValue(3), Float64), PiNode(SSAValue(3), Float64)),
(PiNode(Argument(1), Float64), PiNode(Argument(1), Float64)),
(QuoteNode(:a_quote), QuoteNode(:a_quote)),
(ReturnNode(5), ReturnNode(5)),
(ReturnNode(SSAValue(1)), ReturnNode(SSAValue(2))),
(ReturnNode(SSAValue(3)), ReturnNode(SSAValue(3))),
(ReturnNode(), ReturnNode()),
]
@test Tapir.replace_uses_with(val, SSAValue(1), SSAValue(2)) == target
end
@testset "PhiNode with undefined" begin
vals_with_undef_1 = Vector{Any}(undef, 2)
vals_with_undef_1[2] = SSAValue(1)
val = PhiNode(Int32[1, 2], vals_with_undef_1)
result = Tapir.replace_uses_with(val, SSAValue(1), SSAValue(2))
@test result.values[2] == SSAValue(2)
@test !isassigned(result.values, 1)
end
end
@testset "globalref_type" begin
@test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_1)) == Any
@test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_2)) == Float64
@test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_3)) == Float64
@test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_4)) == Float64
end
@testset "unhandled_feature" begin
@test_throws Tapir.UnhandledLanguageFeatureException Tapir.unhandled_feature("foo")
end
Expand Down
10 changes: 3 additions & 7 deletions test/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@
push!(s, 5.0)
@test s.position == 1
@test s.memory[1] == 5.0
@test length(s) == 1
@test !isempty(s)
@test length(s.memory) == 1

s[] = 6.0
@test s[] == 6.0
@test pop!(s) == 6.0
@test pop!(s) == 5.0
@test s.position == 0
@test length(s) == 0
@test isempty(s)
@test length(s.memory) == 1
end
end
Loading