From a690869d75954e6eb14622295f96e556f2774935 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Tue, 22 Aug 2023 09:00:45 +1200 Subject: [PATCH] Fix broadcasting for non-DefaultArrayStyle containers (#222) --- src/broadcast.jl | 13 +++++++++++++ test/broadcast.jl | 8 ++++++++ 2 files changed, 21 insertions(+) diff --git a/src/broadcast.jl b/src/broadcast.jl index 7ebd8fa3..7356e809 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -23,6 +23,19 @@ function _broadcasted_type( return BitArray{N} end +""" +This method is a generic fallback for array types that are not +`DefaultArrayStyle`. Because we can't tell the container from a generic +broadcast style, we fallback to `Any`, which is always a valid super type (just +not a helpful one). + +In MutableArithmetics, `_broadcasted_type` appears only in `promote_broadcast`, +which itself appears only in `broadcast_mutability`, and so types hitting this +method will fallback to the `IsNotMutable()` branch, which is the expected +outcome. +""" +_broadcasted_type(::Broadcast.BroadcastStyle, ::Base.HasShape, ::Type) = Any + # Same as `Base.Broadcast._combine_styles` but with types as argument. _combine_styles() = Broadcast.DefaultArrayStyle{0}() diff --git a/test/broadcast.jl b/test/broadcast.jl index 2ce5335f..edad933a 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -44,3 +44,11 @@ end @test MA.@rewrite(y .* x) == y .* x @test MA.@rewrite(x .* y) == x .* y end + +struct Struct221 <: AbstractArray{Int,1} end +struct BroadcastStyle221 <: Broadcast.BroadcastStyle end +Base.BroadcastStyle(::Type{Struct221}) = BroadcastStyle221() + +@testset "promote_broadcast_for_new_style" begin + @test MA.promote_broadcast(MA.add_mul, Vector{Int}, Struct221) === Any +end