diff --git a/Project.toml b/Project.toml index a1575231..7618fce0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MutableArithmetics" uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" authors = ["Gilles Peiffer", "BenoƮt Legat", "Sascha Timme"] -version = "1.3.2" +version = "1.3.3" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/rewrite_generic.jl b/src/rewrite_generic.jl index 55dcf6f5..08012ad3 100644 --- a/src/rewrite_generic.jl +++ b/src/rewrite_generic.jl @@ -51,6 +51,10 @@ function _is_parameters(expr) return Meta.isexpr(expr, :call, 3) && Meta.isexpr(expr.args[2], :parameters) end +function _is_kwarg(expr, kwarg::Symbol) + return Meta.isexpr(expr, :kw) && expr.args[1] == kwarg +end + """ _rewrite_generic(stack::Expr, expr::Expr) @@ -78,14 +82,20 @@ function _rewrite_generic(stack::Expr, expr::Expr) # come in two forms: `sum(i for i=I, j=J)` or `sum(i for i=I for j=J)`. # The latter is a `:flatten` expression and needs additional handling, # but we delay this complexity for _rewrite_generic_generator. - if Meta.isexpr(expr.args[2], :parameters, 1) && - Meta.isexpr(expr.args[2].args[1], :kw, 2) && - expr.args[2].args[1].args[1] == :init - # sum(iter ; init) form! - root = gensym() - init, _ = _rewrite_generic(stack, expr.args[2].args[1].args[2]) - push!(stack.args, :($root = $init)) - return _rewrite_generic_generator(stack, :+, expr.args[3], root) + if Meta.isexpr(expr.args[2], :parameters) + # The summation has keyword arguments. We can deal with `init`, but + # not any of the others. + p = expr.args[2] + if length(p.args) == 1 && _is_kwarg(p.args[1], :init) + # sum(iter ; init) form! + root = gensym() + init, _ = _rewrite_generic(stack, p.args[1].args[2]) + push!(stack.args, :($root = $init)) + return _rewrite_generic_generator(stack, :+, expr.args[3], root) + else + # We don't know how to deal with this + return esc(expr), false + end else # Summations use :+ as the reduction operator. init_expr = expr.args[2].args[end] diff --git a/test/rewrite_generic.jl b/test/rewrite_generic.jl index d6d32ebb..7b858742 100644 --- a/test/rewrite_generic.jl +++ b/test/rewrite_generic.jl @@ -344,6 +344,59 @@ function test_rewrite_expression() return end +function test_rewrite_generic_sum_dims() + x = [1 2; 3 4] + @test ==( + MA.@rewrite(sum(x; dims = 1), move_factors_into_sums = false), + [4 6], + ) + @test ==( + MA.@rewrite(sum(x; dims = 2), move_factors_into_sums = false), + [3; 7;;], + ) + @test ==( + MA.@rewrite(sum(x; dims = 1, init = 0), move_factors_into_sums = false), + [4 6], + ) + @test ==( + MA.@rewrite(sum(x; dims = 2, init = 0), move_factors_into_sums = false), + [3; 7;;], + ) + @test ==( + MA.@rewrite(sum(x; init = 0, dims = 1), move_factors_into_sums = false), + [4 6], + ) + @test ==( + MA.@rewrite(sum(x; init = 0, dims = 2), move_factors_into_sums = false), + [3; 7;;], + ) + @test ==( + MA.@rewrite(sum(x, dims = 1), move_factors_into_sums = false), + [4 6], + ) + @test ==( + MA.@rewrite(sum(x, dims = 2), move_factors_into_sums = false), + [3; 7;;], + ) + @test ==( + MA.@rewrite(sum(x, dims = 1, init = 0), move_factors_into_sums = false), + [4 6], + ) + @test ==( + MA.@rewrite(sum(x, dims = 2, init = 0), move_factors_into_sums = false), + [3; 7;;], + ) + @test ==( + MA.@rewrite(sum(x, init = 0, dims = 1), move_factors_into_sums = false), + [4 6], + ) + @test ==( + MA.@rewrite(sum(x, init = 0, dims = 2), move_factors_into_sums = false), + [3; 7;;], + ) + return +end + end # module TestRewriteGeneric.runtests()