diff --git a/src/definitions.jl b/src/definitions.jl index b481a3a..016e622 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -272,11 +272,12 @@ summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p)) # Normalization for ifft, given unscaled bfft, is 1/prod(dimensions) # ensure that region is a subset of eachindex(sz). -_checkindex(sz, region::AbstractVector) = checkindex(Bool, sz, region) +_checkindex(szinds, region::AbstractVector) = checkindex(Bool, szinds, region) # this method handles the case where region is not an array, e.g. it is a Tuple -_checkindex(sz, region) = all(r -> checkindex(Bool, eachindex(sz), r), region) -function normalization(::Type{T}, sz, region) where T - @boundscheck !isempty(region) && _checkindex(eachindex(sz), region) +_checkindex(szinds, region) = all(r -> checkindex(Bool, szinds, r), region) +@inline function normalization(::Type{T}, sz, region) where T + @boundscheck (!isempty(region) && _checkindex(eachindex(sz), region)) || + throw(BoundsError(sz, region)) one(T) / mapreduce(r -> Int(@inbounds sz[r])::Int, *, region; init=oneunit(eltype(sz)))::Int end normalization(X, region) = normalization(real(eltype(X)), size(X), region) diff --git a/test/runtests.jl b/test/runtests.jl index 4d402c5..ced5927 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -211,6 +211,10 @@ end # p::TestPlan) f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p)) @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 + + @test_throws BoundsError AbstractFFTs.normalization(Float64, (2,), 1:3) + @test_throws BoundsError AbstractFFTs.normalization(Float64, (2,), Int[]) + @test_throws BoundsError AbstractFFTs.normalization(Float64, (2,), (1,3,)) end @testset "ChainRules" begin