Skip to content

Commit

Permalink
Housekeeping (#165)
Browse files Browse the repository at this point in the history
* Bump patch version

* Remove redundant stack code

* Remove uninit_codual

* Remove eltype for PossiblyUninitTangent

* Tidy up ir_utils
  • Loading branch information
willtebbutt authored May 21, 2024
1 parent cda63a7 commit 0e46680
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 162 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.17"
version = "0.2.18"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
17 changes: 8 additions & 9 deletions src/codual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@ Equivalent to `CoDual(x, zero_tangent(x))`.
"""
zero_codual(x) = CoDual(x, zero_tangent(x))

"""
uninit_codual(x)
See implementation for details, as this function is subject to change.
"""
@inline uninit_codual(x::P) where {P} = CoDual(x, uninit_tangent(x))

@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x))

"""
codual_type(P::Type)
Expand Down Expand Up @@ -75,6 +66,14 @@ to_fwds(x::CoDual{Type{P}}) where {P} = CoDual{Type{P}, NoFData}(primal(x), NoFD

zero_fcodual(p) = to_fwds(zero_codual(p))

"""
uninit_fcodual(x)
Like `zero_fcodual`, but doesn't guarantee that the value of the fdata is initialised.
See implementation for details, as this function is subject to change.
"""
@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x))

"""
fcodual_type(P::Type)
Expand Down
61 changes: 2 additions & 59 deletions src/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,56 +170,6 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true)
return ir
end

"""
replace_all_uses_with!(ir::IRCode, value::SSAValue, new_value::Any) -> IRCode
Wherever `value` appears in `ir`, replace it with `new_value`.
Note: this will *not* effect anything in the `new_nodes` field of `ir`.
"""
function replace_all_uses_with!(ir::IRCode, value::SSAValue, new_value::Any)
insts = ir.stmts.inst
for (n, inst) in enumerate(insts)
insts[n] = replace_uses_with(inst, value, new_value)
end
return ir
end

#=
replace_uses_with(x::Any, value::SSAValue, new_value::Any)
Replace all occurences of `value` in the IR node `x` with `new_value`. The semantics of this
are node-dependent.
If `value` appears by itself as a constant, it will not be replaced.
=#
replace_uses_with(x::Any, ::SSAValue, _) = x # constants
function replace_uses_with(x::Expr, v::SSAValue, new_v)
return Expr(x.head, [_replace(v, new_v, a) for a in x.args]...)
end
replace_uses_with(x::GotoNode, ::SSAValue, _) = x
function replace_uses_with(x::GotoIfNot, v::SSAValue, new_v)
return GotoIfNot(_replace(v, new_v, x.cond), x.dest)
end
function replace_uses_with(x::PhiNode, v::SSAValue, new_v)
values = x.values
new_values = Vector{Any}(undef, length(values))
for n in eachindex(new_values)
if isassigned(values, n)
new_values[n] = _replace(v, new_v, values[n])
end
end
return PhiNode(x.edges, new_values)
end
replace_uses_with(x::PiNode, v::SSAValue, new_v) = PiNode(_replace(v, new_v, x.val), x.typ)
replace_uses_with(x::QuoteNode, ::SSAValue, _) = x
function replace_uses_with(x::ReturnNode, v::SSAValue, new_v)
return isdefined(x, :val) ? ReturnNode(_replace(v, new_v, x.val)) : x
end

# Return new_value if val equals current_val.
_replace(val::SSAValue, new_val, current_val) = val == current_val ? new_val : current_val

"""
lookup_ir(interp::AbstractInterpreter, sig::Type{<:Tuple})::Tuple{IRCode, T}
Expand All @@ -239,7 +189,7 @@ function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple})
end

"""
is_reachable(x::ReturnNode)
is_reachable_return_node(x::ReturnNode)
Determine whether `x` is a `ReturnNode`, and if it is, if it is also reachable. This is
purely a function of whether or not its `val` field is defined or not.
Expand All @@ -248,21 +198,14 @@ is_reachable_return_node(x::ReturnNode) = isdefined(x, :val)
is_reachable_return_node(x) = false

"""
is_unreachable(x::ReturnNode)
is_unreachable_return_node(x::ReturnNode)
Determine whehter `x` is a `ReturnNode`, and if it is, if it is also unreachable. This is
purely a function of whether or not its `val` field is defined or not.
"""
is_unreachable_return_node(x::ReturnNode) = !isdefined(x, :val)
is_unreachable_return_node(x) = false

"""
globalref_type(x::GlobaRef)
Returns the static type of the value referred to by `x`.
"""
globalref_type(x::GlobalRef) = isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty

"""
UnhandledLanguageFeatureException(message::String)
Expand Down
41 changes: 1 addition & 40 deletions src/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,16 @@ mutable struct Stack{T}
Stack{T}() where {T} = new{T}(Vector{T}(undef, 0), 0)
end

function Stack{T}(x) where {T}
stack = Stack{T}()
push!(stack, x)
return stack
end

Stack(x::T) where {T} = Stack{T}(x)

@inline function Base.push!(x::Stack{T}, val::T) where {T}
position = x.position + 1
memory = x.memory
x.position = position
if position <= length(memory)
@inbounds memory[position] = val
return nothing
else
@noinline push!(memory, val)
return nothing
end
return nothing
end

@inline function Base.pop!(x::Stack)
Expand All @@ -39,37 +30,7 @@ end
return val
end

Base.isempty(x::Stack) = x.position == 0

Base.length(x::Stack) = x.position

"""
Base.getindex(x::Stack)
Return the value at the top of `x` without popping it.
"""
Base.getindex(x::Stack) = x.memory[x.position]

"""
Base.setindex!(x::Stack, v)
Set the value of the element at the top of the `x` to `v`.
"""
function Base.setindex!(x::Stack, v)
x.memory[x.position] = v
return v
end

Base.eltype(::Stack{T}) where {T} = T


struct SingletonStack{T} end

Base.push!(::SingletonStack, ::Any) = nothing
@generated Base.pop!(::SingletonStack{T}) where {T} = T.instance


function reverse_data_ref_type(::Type{P}) where {P}
P === DataType && return Ref{Any}
return Base.RefValue{rdata_type(tangent_type(P))}
end
2 changes: 0 additions & 2 deletions src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ _wrap_type(::Type{T}) where {T} = PossiblyUninitTangent{T}
_wrap_field(::Type{Q}, x::T) where {Q, T} = PossiblyUninitTangent{Q}(x)
_wrap_field(x::T) where {T} = _wrap_field(T, x)

Base.eltype(::Type{PossiblyUninitTangent{T}}) where {T} = T

struct Tangent{Tfields<:NamedTuple}
fields::Tfields
end
Expand Down
1 change: 0 additions & 1 deletion test/codual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
@test CoDual(5.0, 4.0) isa CoDual{Float64, Float64}
@test CoDual(Float64, NoTangent()) isa CoDual{Type{Float64}, NoTangent}
@test zero_codual(5.0) == CoDual(5.0, 0.0)
@test Tapir.uninit_codual(5.0) == CoDual(5.0, 0.0)
@test codual_type(Float64) == CoDual{Float64, Float64}
@test codual_type(Int) == CoDual{Int, NoTangent}
@test codual_type(Real) == CoDual
Expand Down
43 changes: 0 additions & 43 deletions test/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,49 +42,6 @@ end
# Check that the ir is runable.
@test Core.OpaqueClosure(ir)(5.0) == cos(sin(5.0))
end
@testset "replace_all_uses_with!" begin

# `replace_all_uses_with!` is just a lightweight wrapper around `replace_uses_with`,
# so we just test that carefully.
@testset "replace_uses_with $val" for (val, target) in Any[
(5.0, 5.0),
(5, 5),
(Expr(:call, sin, SSAValue(1)), Expr(:call, sin, SSAValue(2))),
(Expr(:call, sin, SSAValue(3)), Expr(:call, sin, SSAValue(3))),
(GotoNode(1), GotoNode(1)),
(GotoIfNot(false, 5), GotoIfNot(false, 5)),
(GotoIfNot(SSAValue(1), 3), GotoIfNot(SSAValue(2), 3)),
(GotoIfNot(SSAValue(3), 3), GotoIfNot(SSAValue(3), 3)),
(
PhiNode(Int32[1, 2, 3], Any[5, SSAValue(1), SSAValue(3)]),
PhiNode(Int32[1, 2, 3], Any[5, SSAValue(2), SSAValue(3)]),
),
(PiNode(SSAValue(1), Float64), PiNode(SSAValue(2), Float64)),
(PiNode(SSAValue(3), Float64), PiNode(SSAValue(3), Float64)),
(PiNode(Argument(1), Float64), PiNode(Argument(1), Float64)),
(QuoteNode(:a_quote), QuoteNode(:a_quote)),
(ReturnNode(5), ReturnNode(5)),
(ReturnNode(SSAValue(1)), ReturnNode(SSAValue(2))),
(ReturnNode(SSAValue(3)), ReturnNode(SSAValue(3))),
(ReturnNode(), ReturnNode()),
]
@test Tapir.replace_uses_with(val, SSAValue(1), SSAValue(2)) == target
end
@testset "PhiNode with undefined" begin
vals_with_undef_1 = Vector{Any}(undef, 2)
vals_with_undef_1[2] = SSAValue(1)
val = PhiNode(Int32[1, 2], vals_with_undef_1)
result = Tapir.replace_uses_with(val, SSAValue(1), SSAValue(2))
@test result.values[2] == SSAValue(2)
@test !isassigned(result.values, 1)
end
end
@testset "globalref_type" begin
@test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_1)) == Any
@test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_2)) == Float64
@test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_3)) == Float64
@test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_4)) == Float64
end
@testset "unhandled_feature" begin
@test_throws Tapir.UnhandledLanguageFeatureException Tapir.unhandled_feature("foo")
end
Expand Down
10 changes: 3 additions & 7 deletions test/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@
push!(s, 5.0)
@test s.position == 1
@test s.memory[1] == 5.0
@test length(s) == 1
@test !isempty(s)
@test length(s.memory) == 1

s[] = 6.0
@test s[] == 6.0
@test pop!(s) == 6.0
@test pop!(s) == 5.0
@test s.position == 0
@test length(s) == 0
@test isempty(s)
@test length(s.memory) == 1
end
end

2 comments on commit 0e46680

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/107352

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.18 -m "<description of version>" 0e46680583510f2f1008d8e06ec41bd517af6d9b
git push origin v0.2.18

Please sign in to comment.