Skip to content

Commit

Permalink
Constrain type in to_vec(::AbstractArray/Vector) to DenseArray/Vect…
Browse files Browse the repository at this point in the history
…or (#156)
  • Loading branch information
mzgubic authored Apr 28, 2021
1 parent c319525 commit 5935f4a
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDifferences"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.2"
version = "0.12.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
9 changes: 4 additions & 5 deletions src/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ function to_vec(x::T) where {T}
v, vals_from_vec = to_vec(vals)
function structtype_from_vec(v::Vector{<:Real})
val_vecs = vals_from_vec(v)
vals = map((b, v) -> b(v), backs, val_vecs)
return T(vals...)
values = map((b, v) -> b(v), backs, val_vecs)
return T(values...)
end
return v, structtype_from_vec
end

function to_vec(x::AbstractVector)
function to_vec(x::DenseVector)
x_vecs_and_backs = map(to_vec, x)
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
function Vector_from_vec(x_vec)
Expand All @@ -53,7 +53,7 @@ function to_vec(x::AbstractVector)
return x_vec, Vector_from_vec
end

function to_vec(x::AbstractArray)
function to_vec(x::DenseArray)
x_vec, from_vec = to_vec(vec(x))

function Array_from_vec(x_vec)
Expand All @@ -63,7 +63,6 @@ function to_vec(x::AbstractArray)
return x_vec, Array_from_vec
end


# Some specific subtypes of AbstractArray.
function to_vec(x::Base.ReshapedArray{<:Any, 1})
x_vec, from_vec = to_vec(parent(x))
Expand Down
48 changes: 30 additions & 18 deletions test/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ end
Base.:(==)(x::DummyType, y::DummyType) = x.X == y.X
Base.length(x::DummyType) = size(x.X, 1)

# A dummy FillVector. This is a type for which the fallback implementation of
# `to_vec` should fail loudly.
# A dummy FillVector
struct FillVector <: AbstractVector{Float64}
x::Float64
len::Int
end

Base.size(x::FillVector) = (x.len,)
Base.getindex(x::FillVector, n::Int) = x.x

# For testing Composite{ThreeFields}
struct ThreeFields
a
Expand All @@ -32,10 +34,17 @@ struct Nested
y::Singleton
end

Base.size(x::FillVector) = (x.len,)
Base.getindex(x::FillVector, n::Int) = x.x
# For testing generic subtypes of AbstractArray
struct WrapperArray{T, N, A<:AbstractArray{T, N}} <: AbstractArray{T, N}
data::A
end
function WrapperArray(a::AbstractArray{T, N}) where {T, N}
return WrapperArray{T, N, AbstractArray{T, N}}(a)
end
Base.size(a::WrapperArray) = size(a.data)
Base.getindex(a::WrapperArray, inds...) = getindex(a.data, inds...)

function test_to_vec(x::T; check_inferred = true) where {T}
function test_to_vec(x::T; check_inferred=true) where {T}
check_inferred && @inferred to_vec(x)
x_vec, back = to_vec(x)
@test x_vec isa Vector
Expand All @@ -61,14 +70,14 @@ end
test_to_vec(randn(T, 5, 11))
test_to_vec(randn(T, 13, 17, 19))
test_to_vec(randn(T, 13, 0, 19))
test_to_vec([1.0, randn(T, 2), randn(T, 1), 2.0]; check_inferred = false)
test_to_vec([randn(T, 5, 4, 3), (5, 4, 3), 2.0]; check_inferred = false)
test_to_vec(reshape([1.0, randn(T, 5, 4, 3), randn(T, 4, 3), 2.0], 2, 2); check_inferred = false)
test_to_vec([1.0, randn(T, 2), randn(T, 1), 2.0]; check_inferred=false)
test_to_vec([randn(T, 5, 4, 3), (5, 4, 3), 2.0]; check_inferred=false)
test_to_vec(reshape([1.0, randn(T, 5, 4, 3), randn(T, 4, 3), 2.0], 2, 2); check_inferred=false)
test_to_vec(UpperTriangular(randn(T, 13, 13)))
test_to_vec(Diagonal(randn(T, 7)))
test_to_vec(DummyType(randn(T, 2, 9)))
test_to_vec(SVector{2, T}(1.0, 2.0))
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0))
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred=false)
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred=false)
test_to_vec(@view randn(T, 10)[1:4]) # SubArray -- Vector
test_to_vec(@view randn(T, 10, 2)[1:4, :]) # SubArray -- Matrix
test_to_vec(Base.ReshapedArray(rand(T, 3, 3), (9,), ()))
Expand Down Expand Up @@ -111,10 +120,10 @@ end
test_to_vec((5, 4))
# TODO remove "< 1.6" once https://github.com/JuliaLang/julia/issues/40277
test_to_vec((5, randn(T, 5)); check_inferred = VERSION v"1.2" && VERSION < v"1.6")
test_to_vec((randn(T, 4), randn(T, 4, 3, 2), 1); check_inferred = false)
test_to_vec((randn(T, 4), randn(T, 4, 3, 2), 1); check_inferred=false)
# TODO remove "< 1.6" once https://github.com/JuliaLang/julia/issues/40277
test_to_vec((5, randn(T, 4, 3, 2), UpperTriangular(randn(T, 4, 4)), 2.5); check_inferred = VERSION v"1.2" && VERSION < v"1.6")
test_to_vec(((6, 5), 3, randn(T, 3, 2, 0, 1)); check_inferred = false)
test_to_vec(((6, 5), 3, randn(T, 3, 2, 0, 1)); check_inferred=false)
test_to_vec((DummyType(randn(T, 2, 7)), DummyType(randn(T, 3, 9))))
test_to_vec((DummyType(randn(T, 3, 2)), randn(T, 11, 8)))
end
Expand All @@ -127,9 +136,9 @@ end
end
@testset "Dictionary" begin
if T == Float64
test_to_vec(Dict(:a=>5, :b=>randn(10, 11), :c=>(5, 4, 3)); check_inferred = false)
test_to_vec(Dict(:a=>5, :b=>randn(10, 11), :c=>(5, 4, 3)); check_inferred=false)
else
test_to_vec(Dict(:a=>3 + 2im, :b=>randn(T, 10, 11), :c=>(5+im, 2-im, 1+im)); check_inferred = false)
test_to_vec(Dict(:a=>3 + 2im, :b=>randn(T, 10, 11), :c=>(5+im, 2-im, 1+im)); check_inferred=false)
end
end
end
Expand All @@ -146,7 +155,7 @@ end
x_inner = (2, 3)
x_outer = (1, x_inner)
x_comp = Composite{typeof(x_outer)}(1, Composite{typeof(x_inner)}(2, 3))
test_to_vec(x_comp; check_inferred = false)
test_to_vec(x_comp; check_inferred=false)
end
end

Expand All @@ -173,13 +182,16 @@ end
end

@testset "FillVector" begin
x = FillVector(5.0, 10)
x_vec, from_vec = to_vec(x)
@test_throws MethodError from_vec(randn(10))
test_to_vec(FillVector(5.0, 10); check_inferred=false)
end

@testset "fallback" begin
nested = Nested(ThreeFields(1.0, 2.0, "Three"), Singleton())
test_to_vec(nested; check_inferred=false) # map
end

@testset "WrapperArray" begin
wa = WrapperArray(rand(4, 5))
test_to_vec(wa; check_inferred=false)
end
end

0 comments on commit 5935f4a

Please sign in to comment.