From 72dbd20e6ad26c859b6864a47dfe7e3e840eeda4 Mon Sep 17 00:00:00 2001 From: odow Date: Fri, 1 Sep 2023 12:06:08 +1200 Subject: [PATCH 1/3] Fix rewrite of +(x, *(y...)) into add_mul --- src/rewrite_generic.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/rewrite_generic.jl b/src/rewrite_generic.jl index 2c1a06d0..535be6dc 100644 --- a/src/rewrite_generic.jl +++ b/src/rewrite_generic.jl @@ -121,6 +121,17 @@ function _rewrite_generic(stack::Expr, expr::Expr) @assert length(expr.args) > 1 if length(expr.args) == 2 # +(arg) return _rewrite_generic(stack, expr.args[2]) + elseif length(expr.args) == 3 && _is_call(expr.args[3], :*) + # +(x, *(y...)) => add_mul(x, y...) + x, is_mutable = _rewrite_generic(stack, expr.args[2]) + rhs = Expr(:call, operate!!, add_mul, x) + for i in 2:length(expr.args[3].args) + yi, _ = _rewrite_generic(stack, expr.args[3].args[i]) + push!(rhs.args, yi) + end + root = gensym() + push!(stack.args, :($root = $rhs)) + return root, is_mutable end return _rewrite_generic_to_nested_op(stack, expr, add_mul) elseif expr.args[1] == :- From 115d97cf3f1c2ea619985d3c4af0323b7631a914 Mon Sep 17 00:00:00 2001 From: odow Date: Fri, 1 Sep 2023 12:18:31 +1200 Subject: [PATCH 2/3] Add test --- src/rewrite_generic.jl | 8 ++++++-- test/rewrite_generic.jl | 12 ++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/rewrite_generic.jl b/src/rewrite_generic.jl index 535be6dc..55dcf6f5 100644 --- a/src/rewrite_generic.jl +++ b/src/rewrite_generic.jl @@ -124,14 +124,18 @@ function _rewrite_generic(stack::Expr, expr::Expr) elseif length(expr.args) == 3 && _is_call(expr.args[3], :*) # +(x, *(y...)) => add_mul(x, y...) x, is_mutable = _rewrite_generic(stack, expr.args[2]) - rhs = Expr(:call, operate!!, add_mul, x) + rhs = if is_mutable + Expr(:call, operate!!, add_mul, x) + else + Expr(:call, operate, add_mul, x) + end for i in 2:length(expr.args[3].args) yi, _ = _rewrite_generic(stack, expr.args[3].args[i]) push!(rhs.args, yi) end root = gensym() push!(stack.args, :($root = $rhs)) - return root, is_mutable + return root, true end return _rewrite_generic_to_nested_op(stack, expr, add_mul) elseif expr.args[1] == :- diff --git a/test/rewrite_generic.jl b/test/rewrite_generic.jl index 078d1b4d..de6e355a 100644 --- a/test/rewrite_generic.jl +++ b/test/rewrite_generic.jl @@ -324,6 +324,18 @@ function test_rewrite_kw_in_ref() return end +function test_rewrite_expression() + x = [1.2] + @test MA.@rewrite(x + 2 * x, move_factors_into_sums = false) == 3x + @test MA.@rewrite(x + *(2, x, 3), move_factors_into_sums = false) == + x + *(2, x, 3) + y = 1.2 + @test MA.@rewrite(y + 2 * y, move_factors_into_sums = false) == 3y + @test MA.@rewrite(y + *(2, y, 3), move_factors_into_sums = false) == + y + *(2, y, 3) + return +end + end # module TestRewriteGeneric.runtests() From 6cc2f52913c1678e546b964cba553229731260ca Mon Sep 17 00:00:00 2001 From: odow Date: Wed, 6 Sep 2023 15:32:41 +1200 Subject: [PATCH 3/3] Add more tests --- test/rewrite_generic.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/rewrite_generic.jl b/test/rewrite_generic.jl index de6e355a..d6d32ebb 100644 --- a/test/rewrite_generic.jl +++ b/test/rewrite_generic.jl @@ -330,6 +330,14 @@ function test_rewrite_expression() @test MA.@rewrite(x + *(2, x, 3), move_factors_into_sums = false) == x + *(2, x, 3) y = 1.2 + @test MA.@rewrite( + sum(y for i in 1:2) + 2y, + move_factors_into_sums = false + ) == 4y + @test MA.@rewrite( + sum(y for i in 1:2) + y * y, + move_factors_into_sums = false + ) == 2y + y^2 @test MA.@rewrite(y + 2 * y, move_factors_into_sums = false) == 3y @test MA.@rewrite(y + *(2, y, 3), move_factors_into_sums = false) == y + *(2, y, 3)