From 0e46680583510f2f1008d8e06ec41bd517af6d9b Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 21 May 2024 22:03:04 +0200 Subject: [PATCH] Housekeeping (#165) * Bump patch version * Remove redundant stack code * Remove uninit_codual * Remove eltype for PossiblyUninitTangent * Tidy up ir_utils --- Project.toml | 2 +- src/codual.jl | 17 +++++----- src/interpreter/ir_utils.jl | 61 ++---------------------------------- src/stack.jl | 41 +----------------------- src/tangents.jl | 2 -- test/codual.jl | 1 - test/interpreter/ir_utils.jl | 43 ------------------------- test/stack.jl | 10 ++---- 8 files changed, 15 insertions(+), 162 deletions(-) diff --git a/Project.toml b/Project.toml index 940f3bb8..afc97412 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.17" +version = "0.2.18" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/codual.jl b/src/codual.jl index a64fc8b5..de1fd514 100644 --- a/src/codual.jl +++ b/src/codual.jl @@ -23,15 +23,6 @@ Equivalent to `CoDual(x, zero_tangent(x))`. """ zero_codual(x) = CoDual(x, zero_tangent(x)) -""" - uninit_codual(x) - -See implementation for details, as this function is subject to change. -""" -@inline uninit_codual(x::P) where {P} = CoDual(x, uninit_tangent(x)) - -@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x)) - """ codual_type(P::Type) @@ -75,6 +66,14 @@ to_fwds(x::CoDual{Type{P}}) where {P} = CoDual{Type{P}, NoFData}(primal(x), NoFD zero_fcodual(p) = to_fwds(zero_codual(p)) +""" + uninit_fcodual(x) + +Like `zero_fcodual`, but doesn't guarantee that the value of the fdata is initialised. +See implementation for details, as this function is subject to change. +""" +@inline uninit_fcodual(x::P) where {P} = CoDual(x, uninit_fdata(x)) + """ fcodual_type(P::Type) diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index 466cdedd..c30a70f0 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -170,56 +170,6 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) return ir end -""" - replace_all_uses_with!(ir::IRCode, value::SSAValue, new_value::Any) -> IRCode - -Wherever `value` appears in `ir`, replace it with `new_value`. - -Note: this will *not* effect anything in the `new_nodes` field of `ir`. -""" -function replace_all_uses_with!(ir::IRCode, value::SSAValue, new_value::Any) - insts = ir.stmts.inst - for (n, inst) in enumerate(insts) - insts[n] = replace_uses_with(inst, value, new_value) - end - return ir -end - -#= - replace_uses_with(x::Any, value::SSAValue, new_value::Any) - -Replace all occurences of `value` in the IR node `x` with `new_value`. The semantics of this -are node-dependent. - -If `value` appears by itself as a constant, it will not be replaced. -=# -replace_uses_with(x::Any, ::SSAValue, _) = x # constants -function replace_uses_with(x::Expr, v::SSAValue, new_v) - return Expr(x.head, [_replace(v, new_v, a) for a in x.args]...) -end -replace_uses_with(x::GotoNode, ::SSAValue, _) = x -function replace_uses_with(x::GotoIfNot, v::SSAValue, new_v) - return GotoIfNot(_replace(v, new_v, x.cond), x.dest) -end -function replace_uses_with(x::PhiNode, v::SSAValue, new_v) - values = x.values - new_values = Vector{Any}(undef, length(values)) - for n in eachindex(new_values) - if isassigned(values, n) - new_values[n] = _replace(v, new_v, values[n]) - end - end - return PhiNode(x.edges, new_values) -end -replace_uses_with(x::PiNode, v::SSAValue, new_v) = PiNode(_replace(v, new_v, x.val), x.typ) -replace_uses_with(x::QuoteNode, ::SSAValue, _) = x -function replace_uses_with(x::ReturnNode, v::SSAValue, new_v) - return isdefined(x, :val) ? ReturnNode(_replace(v, new_v, x.val)) : x -end - -# Return new_value if val equals current_val. -_replace(val::SSAValue, new_val, current_val) = val == current_val ? new_val : current_val - """ lookup_ir(interp::AbstractInterpreter, sig::Type{<:Tuple})::Tuple{IRCode, T} @@ -239,7 +189,7 @@ function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple}) end """ - is_reachable(x::ReturnNode) + is_reachable_return_node(x::ReturnNode) Determine whether `x` is a `ReturnNode`, and if it is, if it is also reachable. This is purely a function of whether or not its `val` field is defined or not. @@ -248,7 +198,7 @@ is_reachable_return_node(x::ReturnNode) = isdefined(x, :val) is_reachable_return_node(x) = false """ - is_unreachable(x::ReturnNode) + is_unreachable_return_node(x::ReturnNode) Determine whehter `x` is a `ReturnNode`, and if it is, if it is also unreachable. This is purely a function of whether or not its `val` field is defined or not. @@ -256,13 +206,6 @@ purely a function of whether or not its `val` field is defined or not. is_unreachable_return_node(x::ReturnNode) = !isdefined(x, :val) is_unreachable_return_node(x) = false -""" - globalref_type(x::GlobaRef) - -Returns the static type of the value referred to by `x`. -""" -globalref_type(x::GlobalRef) = isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty - """ UnhandledLanguageFeatureException(message::String) diff --git a/src/stack.jl b/src/stack.jl index ec83ddaf..23dba0d6 100644 --- a/src/stack.jl +++ b/src/stack.jl @@ -11,25 +11,16 @@ mutable struct Stack{T} Stack{T}() where {T} = new{T}(Vector{T}(undef, 0), 0) end -function Stack{T}(x) where {T} - stack = Stack{T}() - push!(stack, x) - return stack -end - -Stack(x::T) where {T} = Stack{T}(x) - @inline function Base.push!(x::Stack{T}, val::T) where {T} position = x.position + 1 memory = x.memory x.position = position if position <= length(memory) @inbounds memory[position] = val - return nothing else @noinline push!(memory, val) - return nothing end + return nothing end @inline function Base.pop!(x::Stack) @@ -39,37 +30,7 @@ end return val end -Base.isempty(x::Stack) = x.position == 0 - -Base.length(x::Stack) = x.position - -""" - Base.getindex(x::Stack) - -Return the value at the top of `x` without popping it. -""" -Base.getindex(x::Stack) = x.memory[x.position] - -""" - Base.setindex!(x::Stack, v) - -Set the value of the element at the top of the `x` to `v`. -""" -function Base.setindex!(x::Stack, v) - x.memory[x.position] = v - return v -end - -Base.eltype(::Stack{T}) where {T} = T - - struct SingletonStack{T} end Base.push!(::SingletonStack, ::Any) = nothing @generated Base.pop!(::SingletonStack{T}) where {T} = T.instance - - -function reverse_data_ref_type(::Type{P}) where {P} - P === DataType && return Ref{Any} - return Base.RefValue{rdata_type(tangent_type(P))} -end diff --git a/src/tangents.jl b/src/tangents.jl index 21e966c6..66d50c90 100644 --- a/src/tangents.jl +++ b/src/tangents.jl @@ -47,8 +47,6 @@ _wrap_type(::Type{T}) where {T} = PossiblyUninitTangent{T} _wrap_field(::Type{Q}, x::T) where {Q, T} = PossiblyUninitTangent{Q}(x) _wrap_field(x::T) where {T} = _wrap_field(T, x) -Base.eltype(::Type{PossiblyUninitTangent{T}}) where {T} = T - struct Tangent{Tfields<:NamedTuple} fields::Tfields end diff --git a/test/codual.jl b/test/codual.jl index 21115f51..0f3508fe 100644 --- a/test/codual.jl +++ b/test/codual.jl @@ -2,7 +2,6 @@ @test CoDual(5.0, 4.0) isa CoDual{Float64, Float64} @test CoDual(Float64, NoTangent()) isa CoDual{Type{Float64}, NoTangent} @test zero_codual(5.0) == CoDual(5.0, 0.0) - @test Tapir.uninit_codual(5.0) == CoDual(5.0, 0.0) @test codual_type(Float64) == CoDual{Float64, Float64} @test codual_type(Int) == CoDual{Int, NoTangent} @test codual_type(Real) == CoDual diff --git a/test/interpreter/ir_utils.jl b/test/interpreter/ir_utils.jl index 1d8495d3..321021cd 100644 --- a/test/interpreter/ir_utils.jl +++ b/test/interpreter/ir_utils.jl @@ -42,49 +42,6 @@ end # Check that the ir is runable. @test Core.OpaqueClosure(ir)(5.0) == cos(sin(5.0)) end - @testset "replace_all_uses_with!" begin - - # `replace_all_uses_with!` is just a lightweight wrapper around `replace_uses_with`, - # so we just test that carefully. - @testset "replace_uses_with $val" for (val, target) in Any[ - (5.0, 5.0), - (5, 5), - (Expr(:call, sin, SSAValue(1)), Expr(:call, sin, SSAValue(2))), - (Expr(:call, sin, SSAValue(3)), Expr(:call, sin, SSAValue(3))), - (GotoNode(1), GotoNode(1)), - (GotoIfNot(false, 5), GotoIfNot(false, 5)), - (GotoIfNot(SSAValue(1), 3), GotoIfNot(SSAValue(2), 3)), - (GotoIfNot(SSAValue(3), 3), GotoIfNot(SSAValue(3), 3)), - ( - PhiNode(Int32[1, 2, 3], Any[5, SSAValue(1), SSAValue(3)]), - PhiNode(Int32[1, 2, 3], Any[5, SSAValue(2), SSAValue(3)]), - ), - (PiNode(SSAValue(1), Float64), PiNode(SSAValue(2), Float64)), - (PiNode(SSAValue(3), Float64), PiNode(SSAValue(3), Float64)), - (PiNode(Argument(1), Float64), PiNode(Argument(1), Float64)), - (QuoteNode(:a_quote), QuoteNode(:a_quote)), - (ReturnNode(5), ReturnNode(5)), - (ReturnNode(SSAValue(1)), ReturnNode(SSAValue(2))), - (ReturnNode(SSAValue(3)), ReturnNode(SSAValue(3))), - (ReturnNode(), ReturnNode()), - ] - @test Tapir.replace_uses_with(val, SSAValue(1), SSAValue(2)) == target - end - @testset "PhiNode with undefined" begin - vals_with_undef_1 = Vector{Any}(undef, 2) - vals_with_undef_1[2] = SSAValue(1) - val = PhiNode(Int32[1, 2], vals_with_undef_1) - result = Tapir.replace_uses_with(val, SSAValue(1), SSAValue(2)) - @test result.values[2] == SSAValue(2) - @test !isassigned(result.values, 1) - end - end - @testset "globalref_type" begin - @test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_1)) == Any - @test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_2)) == Float64 - @test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_3)) == Float64 - @test Tapir.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_4)) == Float64 - end @testset "unhandled_feature" begin @test_throws Tapir.UnhandledLanguageFeatureException Tapir.unhandled_feature("foo") end diff --git a/test/stack.jl b/test/stack.jl index 1c760929..325551fe 100644 --- a/test/stack.jl +++ b/test/stack.jl @@ -4,14 +4,10 @@ push!(s, 5.0) @test s.position == 1 @test s.memory[1] == 5.0 - @test length(s) == 1 - @test !isempty(s) + @test length(s.memory) == 1 - s[] = 6.0 - @test s[] == 6.0 - @test pop!(s) == 6.0 + @test pop!(s) == 5.0 @test s.position == 0 - @test length(s) == 0 - @test isempty(s) + @test length(s.memory) == 1 end end