Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constrain type in to_vec(::AbstractArray/Vector) #156

Merged
merged 13 commits into from
Apr 28, 2021
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDifferences"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.2"
version = "0.13.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
9 changes: 4 additions & 5 deletions src/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ function to_vec(x::T) where {T}
v, vals_from_vec = to_vec(vals)
function structtype_from_vec(v::Vector{<:Real})
val_vecs = vals_from_vec(v)
vals = map((b, v) -> b(v), backs, val_vecs)
return T(vals...)
values = map((b, v) -> b(v), backs, val_vecs)
return T(values...)
end
return v, structtype_from_vec
end

function to_vec(x::AbstractVector)
function to_vec(x::DenseVector)
x_vecs_and_backs = map(to_vec, x)
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
function Vector_from_vec(x_vec)
Expand All @@ -53,7 +53,7 @@ function to_vec(x::AbstractVector)
return x_vec, Vector_from_vec
end

function to_vec(x::AbstractArray)
function to_vec(x::DenseArray)
x_vec, from_vec = to_vec(vec(x))

function Array_from_vec(x_vec)
Expand All @@ -63,7 +63,6 @@ function to_vec(x::AbstractArray)
return x_vec, Array_from_vec
end


# Some specific subtypes of AbstractArray.
function to_vec(x::Base.ReshapedArray{<:Any, 1})
x_vec, from_vec = to_vec(parent(x))
Expand Down
49 changes: 30 additions & 19 deletions test/to_vec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ end
Base.:(==)(x::DummyType, y::DummyType) = x.X == y.X
Base.length(x::DummyType) = size(x.X, 1)

# A dummy FillVector. This is a type for which the fallback implementation of
# `to_vec` should fail loudly.
# A dummy FillVector
struct FillVector <: AbstractVector{Float64}
x::Float64
len::Int
end

Base.size(x::FillVector) = (x.len,)
Base.getindex(x::FillVector, n::Int) = x.x

# For testing Composite{ThreeFields}
struct ThreeFields
a
Expand All @@ -32,10 +34,7 @@ struct Nested
y::Singleton
end

Base.size(x::FillVector) = (x.len,)
Base.getindex(x::FillVector, n::Int) = x.x

function test_to_vec(x::T; check_inferred = true) where {T}
function test_to_vec(x::T; check_inferred=true) where {T}
check_inferred && @inferred to_vec(x)
x_vec, back = to_vec(x)
@test x_vec isa Vector
Expand All @@ -61,14 +60,14 @@ end
test_to_vec(randn(T, 5, 11))
test_to_vec(randn(T, 13, 17, 19))
test_to_vec(randn(T, 13, 0, 19))
test_to_vec([1.0, randn(T, 2), randn(T, 1), 2.0]; check_inferred = false)
test_to_vec([randn(T, 5, 4, 3), (5, 4, 3), 2.0]; check_inferred = false)
test_to_vec(reshape([1.0, randn(T, 5, 4, 3), randn(T, 4, 3), 2.0], 2, 2); check_inferred = false)
test_to_vec([1.0, randn(T, 2), randn(T, 1), 2.0]; check_inferred=false)
test_to_vec([randn(T, 5, 4, 3), (5, 4, 3), 2.0]; check_inferred=false)
test_to_vec(reshape([1.0, randn(T, 5, 4, 3), randn(T, 4, 3), 2.0], 2, 2); check_inferred=false)
test_to_vec(UpperTriangular(randn(T, 13, 13)))
test_to_vec(Diagonal(randn(T, 7)))
test_to_vec(DummyType(randn(T, 2, 9)))
test_to_vec(SVector{2, T}(1.0, 2.0))
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0))
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred=false)
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred=false)
test_to_vec(@view randn(T, 10)[1:4]) # SubArray -- Vector
test_to_vec(@view randn(T, 10, 2)[1:4, :]) # SubArray -- Matrix
test_to_vec(Base.ReshapedArray(rand(T, 3, 3), (9,), ()))
Expand Down Expand Up @@ -111,10 +110,10 @@ end
test_to_vec((5, 4))
# TODO remove "< 1.6" once https://github.com/JuliaLang/julia/issues/40277
test_to_vec((5, randn(T, 5)); check_inferred = VERSION ≥ v"1.2" && VERSION < v"1.6")
test_to_vec((randn(T, 4), randn(T, 4, 3, 2), 1); check_inferred = false)
test_to_vec((randn(T, 4), randn(T, 4, 3, 2), 1); check_inferred=false)
# TODO remove "< 1.6" once https://github.com/JuliaLang/julia/issues/40277
test_to_vec((5, randn(T, 4, 3, 2), UpperTriangular(randn(T, 4, 4)), 2.5); check_inferred = VERSION ≥ v"1.2" && VERSION < v"1.6")
test_to_vec(((6, 5), 3, randn(T, 3, 2, 0, 1)); check_inferred = false)
test_to_vec(((6, 5), 3, randn(T, 3, 2, 0, 1)); check_inferred=false)
test_to_vec((DummyType(randn(T, 2, 7)), DummyType(randn(T, 3, 9))))
test_to_vec((DummyType(randn(T, 3, 2)), randn(T, 11, 8)))
end
Expand All @@ -127,9 +126,9 @@ end
end
@testset "Dictionary" begin
if T == Float64
test_to_vec(Dict(:a=>5, :b=>randn(10, 11), :c=>(5, 4, 3)); check_inferred = false)
test_to_vec(Dict(:a=>5, :b=>randn(10, 11), :c=>(5, 4, 3)); check_inferred=false)
else
test_to_vec(Dict(:a=>3 + 2im, :b=>randn(T, 10, 11), :c=>(5+im, 2-im, 1+im)); check_inferred = false)
test_to_vec(Dict(:a=>3 + 2im, :b=>randn(T, 10, 11), :c=>(5+im, 2-im, 1+im)); check_inferred=false)
end
end
end
Expand All @@ -146,7 +145,7 @@ end
x_inner = (2, 3)
x_outer = (1, x_inner)
x_comp = Composite{typeof(x_outer)}(1, Composite{typeof(x_inner)}(2, 3))
test_to_vec(x_comp; check_inferred = false)
test_to_vec(x_comp; check_inferred=false)
end
end

Expand All @@ -173,13 +172,25 @@ end
end

@testset "FillVector" begin
x = FillVector(5.0, 10)
x_vec, from_vec = to_vec(x)
@test_throws MethodError from_vec(randn(10))
test_to_vec(FillVector(5.0, 10); check_inferred=false)
end

@testset "fallback" begin
nested = Nested(ThreeFields(1.0, 2.0, "Three"), Singleton())
test_to_vec(nested; check_inferred=false) # map
end

@testset "WrapperArray" begin
struct WrapperArray{T, N, A<:AbstractArray{T, N}} <: AbstractArray{T, N}
data::A
end
function WrapperArray(a::AbstractArray{T, N}) where {T, N}
return WrapperArray{T, N, AbstractArray{T, N}}(a)
end
Base.size(a::WrapperArray) = size(a.data)
Base.getindex(a::WrapperArray, inds...) = getindex(a.data, inds...)

wa = WrapperArray(rand(4, 5))
test_to_vec(wa; check_inferred=false)
end
end