Skip to content

Commit

Permalink
Recurse rewrite_generic into :if and :block (#303)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Sep 25, 2024
1 parent bc97807 commit bcd1514
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 1 deletion.
41 changes: 40 additions & 1 deletion src/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,53 @@ function _is_kwarg(expr, kwarg::Symbol)
return Meta.isexpr(expr, :kw) && expr.args[1] == kwarg
end

function _rewrite_elseif!(if_expr, expr::Any)
if expr isa Expr && Meta.isexpr(expr, :elseif)
new_ifelse_expr = Expr(:elseif, esc(expr.args[1]))
push!(if_expr.args, new_ifelse_expr)
@assert 2 <= length(expr.args) <= 3
return mapreduce(&, 2:length(expr.args)) do i
return _rewrite_elseif!(new_ifelse_expr, expr.args[i])
end
end
stack = quote end
root, is_mutable = _rewrite_generic(stack, expr)
push!(stack.args, root)
push!(if_expr.args, stack)
return is_mutable
end

_rewrite_generic(stack::Expr, expr::LineNumberNode) = expr, false

"""
_rewrite_generic(stack::Expr, expr::Expr)
This method is the heart of the rewrite logic. It converts `expr` into a mutable
equivalent.
"""
function _rewrite_generic(stack::Expr, expr::Expr)
if !Meta.isexpr(expr, :call)
if Meta.isexpr(expr, :block)
new_stack = quote end
for arg in expr.args
root, _ = _rewrite_generic(new_stack, arg)
push!(new_stack.args, root)
end
root = gensym()
push!(stack.args, :($root = $new_stack))
return root, false
elseif Meta.isexpr(expr, :if)
# `if` blocks are special, because we can't lift the computation inside
# them into the stack; the values might be defined only if the branch is
# true.
if_expr = Expr(:if, esc(expr.args[1]))
@assert 2 <= length(expr.args) <= 3
is_mutable = mapreduce(&, 2:length(expr.args)) do i
return _rewrite_elseif!(if_expr, expr.args[i])
end
root = gensym()
push!(stack.args, :($root = $if_expr))
return root, is_mutable
elseif !Meta.isexpr(expr, :call)
# In situations like `x[i]`, we do not attempt to rewrite. Return `expr`
# and don't let future callers mutate.
return esc(expr), false
Expand Down
73 changes: 73 additions & 0 deletions test/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,79 @@ function test_rewrite_generic_sum_dims()
return
end

function test_rewrite_block()
@test_rewrite begin
x = 1
y = x + 2
z = 3 * y
end
@test_rewrite begin
x = [1]
y = x + [2]
z = 3 * y
end
return
end

function test_rewrite_ifelse()
@test_rewrite begin
x = -1
y = [3.0]
if x < 1
y .+ x
else
2 * x
end
end
@test_rewrite begin
x = 2
y = [3.0]
if x < 1
y .+ x
else
2 * x
end
end
@test_rewrite begin
x = 2
y = [3.0, 4.0]
if x < 1
y .+ x
elseif length(y) == 2
0.0
else
2 * x
end
end
@test_rewrite begin
x = 2
y = Float64[]
if x < 1
y .+ x
elseif length(y) == 2
0.0
elseif isempty(y)
-1.0
else
2 * x
end
end
@test_rewrite begin
x = 2
y = Float64[1.0]
if x < 1
1.0
elseif length(y) == 2
2.0
elseif isempty(y)
3.0
else
4.0
end
end
return
end

end # module

TestRewriteGeneric.runtests()

0 comments on commit bcd1514

Please sign in to comment.