diff --git a/Project.toml b/Project.toml index f135556c..65294a51 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.50" +version = "0.2.51" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/tangents.jl b/src/tangents.jl index c058f373..aaec52fe 100644 --- a/src/tangents.jl +++ b/src/tangents.jl @@ -428,42 +428,16 @@ handles both circular references and aliasing correctly. """ zero_tangent(x) function zero_tangent(x::P) where {P} - return isbitstype(P) ? zero_tangent_internal(x) : zero_tangent_internal(x, IdDict()) -end - -@inline zero_tangent_internal(::Union{Int8, Int16, Int32, Int64, Int128}) = NoTangent() -@inline zero_tangent_internal(x::IEEEFloat) = zero(x) -@inline function zero_tangent_internal(x::P) where {P<:Union{Tuple, NamedTuple}} - return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(zero_tangent_internal, x) -end -@generated function zero_tangent_internal(x::P) where P - - tangent_type(P) == NoTangent && return NoTangent() - - # This method can only handle struct types. Tell user to implement tangent type - # directly for primitive types. - isprimitivetype(P) && throw(error( - "$P is a primitive type. Implement a method of `zero_tangent` for it." - )) - - # Derive zero tangent. Tangent types of fields, and types of zeros need only agree - # if field types are concrete. - tangent_field_zeros_exprs = ntuple(fieldcount(P)) do n - if tangent_field_type(P, n) <: PossiblyUninitTangent - V = PossiblyUninitTangent{tangent_type(fieldtype(P, n))} - return :(isdefined(x, $n) ? $V(zero_tangent_internal(getfield(x, $n))) : $V()) - else - return :(zero_tangent_internal(getfield(x, $n))) - end - end - backing_data_expr = Expr(:call, :tuple, tangent_field_zeros_exprs...) - backing_expr = :($(backing_type(P))($backing_data_expr)) - return :($(tangent_type(P))($backing_expr)) + return zero_tangent_internal(x, isbitstype(P) ? nothing : IdDict()) end # the `stackdict` naming following convention of Julia's `deepcopy` and `deepcopy_internal` # https://github.com/JuliaLang/julia/blob/48d4fd48430af58502699fdf3504b90589df3852/base/deepcopy.jl#L35 -@inline zero_tangent_internal(x::Union{Int8,Int16,Int32,Int64,Int128,IEEEFloat}, stackdict::IdDict) = zero_tangent_internal(x) +@inline zero_tangent_internal(::Union{Int8, Int16, Int32, Int64, Int128}, ::Any) = NoTangent() +@inline zero_tangent_internal(x::IEEEFloat, ::Any) = zero(x) +@inline function zero_tangent_internal(x::P, stackdict::Any) where {P<:Union{Tuple, NamedTuple}} + return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(Base.Fix2(zero_tangent_internal, stackdict), x) +end @inline function zero_tangent_internal(x::SimpleVector, stackdict::IdDict) return map!(n -> zero_tangent_internal(x[n], stackdict), Vector{Any}(undef, length(x)), eachindex(x)) end @@ -474,14 +448,17 @@ end stackdict[x] = zt return _map_if_assigned!(Base.Fix2(zero_tangent_internal, stackdict), zt, x)::Array{tangent_type(P), N} end -@inline function zero_tangent_internal(x::P, stackdict::IdDict) where {P<:Union{Tuple, NamedTuple}} - return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(Base.Fix2(zero_tangent_internal, stackdict), x) -end -function zero_tangent_internal(x::P, stackdict::IdDict) where {P} - +function zero_tangent_internal(x::P, stackdict) where {P} tangent_type(P) == NoTangent && return NoTangent() if tangent_type(P) <: MutableTangent + if !(stackdict isa IdDict) + throw( + ArgumentError( + "Internal error: stackdict must be an IdDict for mutable structs, not $(typeof(stackdict)). Please report this issue." + ) + ) + end if haskey(stackdict, x) return stackdict[x]::tangent_type(P) end @@ -489,30 +466,24 @@ function zero_tangent_internal(x::P, stackdict::IdDict) where {P} # if circular reference exists, then the recursive call will first look up the stackdict # and return the uninitialised MutableTangent # after the recursive call returns, the stackdict will be initialised - stackdict[x].fields = backing_type(P)(zero_tangent_struct_field(x, stackdict)) + stackdict[x].fields = zero_tangent_struct_field(x, stackdict) return stackdict[x]::tangent_type(P) else - if isbitstype(P) - return zero_tangent_internal(x) - else - return tangent_type(P)(backing_type(P)(zero_tangent_struct_field(x, stackdict))) - end + return tangent_type(P)(zero_tangent_struct_field(x, stackdict)) end end -@inline function zero_tangent_struct_field(x::P, stackdict::IdDict) where {P} - return ntuple(fieldcount(P)) do n +@generated function zero_tangent_struct_field(x::P, stackdict) where {P} + tangent_field_zeros_exprs = ntuple(fieldcount(P)) do n if tangent_field_type(P, n) <: PossiblyUninitTangent V = PossiblyUninitTangent{tangent_type(fieldtype(P, n))} - if isdefined(x, n) - return V(zero_tangent_internal(getfield(x, n), stackdict)) - else - return V() - end + return :(isdefined(x, $n) ? $V(zero_tangent_internal(getfield(x, $n), stackdict)) : $V()) else - return zero_tangent_internal(getfield(x, n), stackdict) + return :(zero_tangent_internal(getfield(x, $n), stackdict)) end end + tangent_fields_expr = Expr(:call, :tuple, tangent_field_zeros_exprs...) + return :($(backing_type(P))($tangent_fields_expr)) end """ @@ -529,47 +500,62 @@ details -- this docstring is intentionally non-specific in order to avoid becomi Required for testing. Generate a randomly-chosen tangent to `x`. +The design is closely modelled after `zero_tangent`. """ -randn_tangent(::AbstractRNG, ::NoTangent) = NoTangent() -randn_tangent(rng::AbstractRNG, ::T) where {T<:IEEEFloat} = randn(rng, T) -function randn_tangent(rng::AbstractRNG, x::Array{T, N}) where {T, N} - dx = Array{tangent_type(T), N}(undef, size(x)...) - return _map_if_assigned!(Base.Fix1(randn_tangent, rng), dx, x) +function randn_tangent(rng::AbstractRNG, x::T) where {T} + return randn_tangent_internal(rng, x, isbitstype(T) ? nothing : IdDict()) +end + +randn_tangent_internal(::AbstractRNG, ::NoTangent, ::Any) = NoTangent() +randn_tangent_internal(rng::AbstractRNG, ::T, ::Any) where {T<:IEEEFloat} = randn(rng, T) +function randn_tangent_internal(rng::AbstractRNG, x::P, stackdict::Any) where {P<:Union{Tuple, NamedTuple}} + return tangent_type(P) == NoTangent ? NoTangent() : tuple_map(x -> randn_tangent_internal(rng, x, stackdict), x) end -function randn_tangent(rng::AbstractRNG, x::SimpleVector) +function randn_tangent_internal(rng::AbstractRNG, x::SimpleVector, stackdict::IdDict) return map!(Vector{Any}(undef, length(x)), eachindex(x)) do n - return randn_tangent(rng, x[n]) + return randn_tangent_internal(rng, x[n], stackdict) end end -function randn_tangent(rng::AbstractRNG, x::P) where {P <: Union{Tuple, NamedTuple}} - tangent_type(P) == NoTangent && return NoTangent() - return tuple_map(x -> randn_tangent(rng, x), x) -end -function randn_tangent(rng::AbstractRNG, x::T) where {T<:Union{Tangent, MutableTangent}} - return T(randn_tangent(rng, x.fields)) +function randn_tangent_internal(rng::AbstractRNG, x::Array{T, N}, stackdict::IdDict) where {T, N} + haskey(stackdict, x) && return stackdict[x]::tangent_type(typeof(x)) + + dx = Array{tangent_type(T), N}(undef, size(x)...) + stackdict[x] = dx + return _map_if_assigned!(x -> randn_tangent_internal(rng, x, stackdict), dx, x) end -@generated function randn_tangent(rng::AbstractRNG, x::P) where {P} - - # If `P` doesn't have a tangent space, always return `NoTangent()`. - tangent_type(P) === NoTangent && return NoTangent() +function randn_tangent_internal(rng::AbstractRNG, x::P, stackdict) where {P} + tangent_type(P) == NoTangent && return NoTangent() - # This method can only handle struct types. Tell user to implement tangent type - # directly for primitive types. - isprimitivetype(P) && throw(error( - "$P is a primitive type. Implement a method of `randn_tangent` for it." - )) + if tangent_type(P) <: MutableTangent + if !(stackdict isa IdDict) + throw( + ArgumentError( + "Internal error: stackdict must be an IdDict for mutable structs, not $(typeof(stackdict)). Please report this issue." + ) + ) + end + if haskey(stackdict, x) + return stackdict[x]::tangent_type(P) + end + stackdict[x] = tangent_type(P)() + stackdict[x].fields = randn_tangent_struct_field(rng, x, stackdict) + return stackdict[x]::tangent_type(P) + else + return tangent_type(P)(randn_tangent_struct_field(rng, x, stackdict)) + end +end - # Assume `P` is a generic struct type, and derive the tangent recursively. +@generated function randn_tangent_struct_field(rng::AbstractRNG, x::P, stackdict) where {P} tangent_field_exprs = map(1:fieldcount(P)) do n if tangent_field_type(P, n) <: PossiblyUninitTangent V = PossiblyUninitTangent{tangent_type(fieldtype(P, n))} - return :(isdefined(x, $n) ? $V(randn_tangent(rng, getfield(x, $n))) : $V()) + return :(isdefined(x, $n) ? $V(randn_tangent_internal(rng, getfield(x, $n), stackdict)) : $V()) else - return :(randn_tangent(rng, getfield(x, $n))) + return :(randn_tangent_internal(rng, getfield(x, $n), stackdict)) end end tangent_fields_expr = Expr(:call, :tuple, tangent_field_exprs...) - return :($(tangent_type(P))($(backing_type(P))($tangent_fields_expr))) + return :($(backing_type(P))($tangent_fields_expr)) end """ @@ -793,7 +779,7 @@ for T in [Symbol, Int, Val] @eval increment_field!!(::NoTangent, ::NoTangent, f::Union{$T}) = NoTangent() end -#= +""" tangent_test_cases() Constructs a `Vector` of `Tuple`s containing test cases for the tangent infrastructure. @@ -809,13 +795,9 @@ If the returned tuple has 5 elements, then the elements are interpreted as follo 2 - primal value 3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>). -Generally speaking, it's very straightforward to produce test cases in the first format, -while the second requires more work. Consequently, at the time of writing there are many -more instances of the first format than the second. - Test cases in the first format make use of `zero_tangent` / `randn_tangent` etc to generate tangents, but they're unable to check that `increment!!` is correct in an absolute sense. -=# +""" function tangent_test_cases() N_large = 33 diff --git a/src/test_utils.jl b/src/test_utils.jl index 5f1d8e3c..7687b5db 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -106,49 +106,98 @@ function report_opt(::Any, tt) throw(error("Load JET to use this function.")) end -has_equal_data(x::T, y::T; equal_undefs=true) where {T<:String} = x == y -has_equal_data(x::Type, y::Type; equal_undefs=true) = x == y -has_equal_data(x::Core.TypeName, y::Core.TypeName; equal_undefs=true) = x == y -has_equal_data(x::Module, y::Module; equal_undefs=true) = x == y -function has_equal_data(x::T, y::T; equal_undefs=true) where {T<:Array} +""" + has_equal_data(x, y; equal_undefs=true) + +Determine if two objects `x` and `y` have equivalent data. If `equal_undefs` +is `true`, undefined elements in arrays or unassigned fields in structs are +considered equal. + +The main logic is implemented in `has_equal_data_internal`, which is a recursive function +that takes an additional `visited` dictionary to track visited objects and avoid infinite +recursion in cases of circular references. +""" +function has_equal_data(x, y; equal_undefs=true) + return has_equal_data_internal(x, y, equal_undefs, Dict{Tuple{UInt, UInt}, Bool}()) +end + +has_equal_data_internal(x::Type, y::Type, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) = x == y +has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) where {T<:String} = x == y +has_equal_data_internal(x::Core.TypeName, y::Core.TypeName, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) = x == y +function has_equal_data_internal(x::Float64, y::Float64, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) + return (isapprox(x, y) && !isnan(x)) || (isnan(x) && isnan(y)) +end +has_equal_data_internal(x::Module, y::Module, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) = x == y +function has_equal_data_internal(x::GlobalRef, y::GlobalRef; equal_undefs=true, d::Dict{Tuple{UInt, UInt}, Bool}) + return x.mod == y.mod && x.name == y.name +end +function has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) where {T<:Array} size(x) != size(y) && return false + + # The dictionary is used to detect circular references in the data structures. + # For example, if x.a.a === x and y.a.a === y, we want to consider them to have equal data. + # + # When we first encounter a pair of objects: + # 1. We add them to the dictionary, marking that we've seen them. + # 2. This doesn't guarantee they're equal, just that we've encountered them. + # + # As we recursively compare x and y: + # - If we see a pair we've seen before, it indicates circular references. + # - We consider "circular references to itself" as equal data for this subcomponent. + # - However, other parts of x and y may still differ, so we continue checking. + + id_pair = (objectid(x), objectid(y)) + if haskey(d, id_pair) + return d[id_pair] + end + + d[id_pair] = true equality = map(1:length(x)) do n - (isassigned(x, n) != isassigned(y, n)) && return !equal_undefs - return (!isassigned(x, n) || has_equal_data(x[n], y[n])) + if isassigned(x, n) != isassigned(y, n) + return !equal_undefs + elseif !isassigned(x, n) + return true + else + return has_equal_data_internal(x[n], y[n], equal_undefs, d) + end end return all(equality) end -function has_equal_data(x::Float64, y::Float64; equal_undefs=true) - return (isapprox(x, y) && !isnan(x)) || (isnan(x) && isnan(y)) -end -function has_equal_data(x::T, y::T; equal_undefs=true) where {T<:Core.SimpleVector} - return all(map((a, b) -> has_equal_data(a, b; equal_undefs), x, y)) +function has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) where {T<:Core.SimpleVector} + return all(map((a, b) -> has_equal_data_internal(a, b, equal_undefs, d), x, y)) end -function has_equal_data(x::T, y::T; equal_undefs=true) where {T} +function has_equal_data_internal(x::T, y::T, equal_undefs::Bool, d::Dict{Tuple{UInt, UInt}, Bool}) where {T} isprimitivetype(T) && return isequal(x, y) + + id_pair = (objectid(x), objectid(y)) + if haskey(d, id_pair) + return d[id_pair] + end + + d[id_pair] = true + if ismutabletype(x) - return all(map( - n -> isdefined(x, n) ? has_equal_data(getfield(x, n), getfield(y, n)) : true, - fieldnames(T), - )) + return all(map(fieldnames(T)) do n + isdefined(x, n) ? has_equal_data_internal(getfield(x, n), getfield(y, n), equal_undefs, d) : true + end) else for n in fieldnames(T) - if isdefined(x, n) - if isdefined(y, n) && has_equal_data(getfield(x, n), getfield(y, n)) + if !isdefined(x, n) && !isdefined(y, n) + continue # consider undefined fields as equal + elseif isdefined(x, n) && isdefined(y, n) + if has_equal_data_internal(getfield(x, n), getfield(y, n), equal_undefs, d) continue else return false end - else - return isdefined(y, n) ? false : true + else # one is defined and the other is not + return false end end return true end end -function has_equal_data(x::GlobalRef, y::GlobalRef; equal_undefs=true) - return x.mod == y.mod && x.name == y.name -end +has_equal_data_internal(x::T, y::P, equal_undefs::Bool, d::Dict{Tuple{UInt,UInt},Bool}) where {T,P} = false has_equal_data_up_to_undefs(x::T, y::T) where {T} = has_equal_data(x, y; equal_undefs=false) @@ -817,6 +866,8 @@ Verify that primal `p` with tangents `z_target`, `x`, and `y`, satisfies the tan interface. If these tests pass, then it should be possible to write rules for primals of type `P`, and to test them using [`test_rule`](@ref). +It should be the case that `z_target` == `increment!!(x, y)`. + As always, there are limits to the errors that these tests can identify -- they form necessary but not sufficient conditions for the correctness of your code. """ @@ -1035,10 +1086,13 @@ end mutable struct TypeUnstableMutableStruct a::Float64 b + TypeUnstableMutableStruct(a::Float64) = new(a) + TypeUnstableMutableStruct(a::Float64, b) = new(a, b) end -function Base.:(==)(a::TypeUnstableMutableStruct, b::TypeUnstableMutableStruct) - return equal_field(a, b, :a) && equal_field(a, b, :b) +mutable struct TypeUnstableMutableStruct2 + a + b end struct TypeStableStruct{T} @@ -1048,13 +1102,16 @@ struct TypeStableStruct{T} TypeStableStruct{T}(a::Float64, b::T) where {T} = new{T}(a, b) end -function Base.:(==)(a::TypeStableStruct, b::TypeStableStruct) - return equal_field(a, b, :a) && equal_field(a, b, :b) +struct TypeUnstableStruct2 + a + b end struct TypeUnstableStruct a::Float64 b + TypeUnstableStruct(a::Float64) = new(a) + TypeUnstableStruct(a::Float64, b) = new(a, b) end function Base.:(==)(a::TypeUnstableStruct, b::TypeUnstableStruct) @@ -1078,6 +1135,36 @@ struct StructNoRvs x::Vector{Float64} end +# +# generate test cases for circular references +# + +function make_circular_reference_struct() + c = TypeUnstableMutableStruct(1.0, nothing) + c.b = c + return c +end + +function make_indirect_circular_reference_struct() + c = TypeUnstableMutableStruct(1.0) + _c = TypeUnstableMutableStruct(1.0, c) + c.b = _c + return c +end + +function make_circular_reference_array() + a = Any[1.0, 2.0, 3.0] + a[1] = a + return a +end + +function make_indirect_circular_reference_array() + a = Any[1.0, 2.0, 3.0] + b = Any[a, 4.0] + a[1] = b + return a +end + # # Tests for AD. There are not rules defined directly on these functions, and they require # that most language primitives have rules defined. diff --git a/test/front_matter.jl b/test/front_matter.jl index b42f6693..82582859 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -74,7 +74,15 @@ using .TestUtils: using .TestResources: TypeStableMutableStruct, StructFoo, - MutableFoo + MutableFoo, + TypeUnstableStruct, + TypeUnstableStruct2, + TypeUnstableMutableStruct, + TypeUnstableMutableStruct2, + make_circular_reference_struct, + make_indirect_circular_reference_struct, + make_circular_reference_array, + make_indirect_circular_reference_array # The integration tests take ages to run, so we split them up. CI sets up two jobs -- the # "basic" group runs test that, when passed, _ought_ to imply correctness of the entire diff --git a/test/tangents.jl b/test/tangents.jl index 2f3ed36b..172b644b 100644 --- a/test/tangents.jl +++ b/test/tangents.jl @@ -163,51 +163,49 @@ end end -# TODO: ideally we want to add the following test to the above testset (defined src/tangent.jl) -# but we have to delay this until `randn_tangent` is implemented and working. -@testset "zero_tangent" begin +# TODO: add the following test to `tangent_test_cases` +@testset "zero_tangent and randn_tangent" begin @testset "circular reference" begin - foo = Tapir.TestResources.TypeUnstableMutableStruct(5.0, nothing) - foo.b = foo + foo = make_circular_reference_struct() zt = Tapir.zero_tangent(foo) - @test zt.fields.b === zt + @test zt.fields.b.tangent === zt + rt = Tapir.randn_tangent(Xoshiro(123456), foo) + @test rt.fields.b.tangent === rt end @testset "struct with non-concrete fields" begin bar = Tapir.TestResources.TypeUnstableStruct(5.0, 1.0) - @test Tapir.zero_tangent(bar) == Tangent{@NamedTuple{a::Float64, b}}(@NamedTuple{a::Float64, b}((0.0, 0.0))) + @test ==( + Tapir.zero_tangent(bar), + Tangent{@NamedTuple{a::Float64, b::PossiblyUninitTangent{Any}}}((a=0.0, b=PossiblyUninitTangent{Any}(0.0))) + ) end - + @testset "duplicate reference" begin @testset "subarray" begin - mutable struct MutDupRefSubArray - x - y - end - x = [1.0, 2.0, 3.0] - mut_struct = MutDupRefSubArray(view(x, 1:2), view(x, 1:2)) - mt = Tapir.zero_tangent(mut_struct) + immutable_struct = TypeUnstableStruct2(view(x, 1:2), view(x, 1:2)) + mt = Tapir.zero_tangent(immutable_struct) + @test mt isa Tapir.Tangent + @test mt.fields.a === mt.fields.b + rt = Tapir.randn_tangent(Xoshiro(123456), immutable_struct) + @test rt.fields.a === rt.fields.b + + mutable_struct = TypeUnstableMutableStruct2(view(x, 1:2), view(x, 1:2)) + mt = Tapir.zero_tangent(mutable_struct) @test mt isa Tapir.MutableTangent - @test mt.fields.x === mt.fields.y - - struct ImmutableDupRefSubArray - x - y - end - - immutable_struct = ImmutableDupRefSubArray(view(x, 1:2), view(x, 1:1)) - it = Tapir.zero_tangent(immutable_struct) - @test it isa Tapir.Tangent - @test it.fields.x.fields.parent === it.fields.y.fields.parent + @test mt.fields.a.fields.parent === mt.fields.b.fields.parent + rt = Tapir.randn_tangent(Xoshiro(123456), mutable_struct) + @test rt.fields.a.fields.parent === rt.fields.b.fields.parent end end @testset "indirect circular reference" begin - m = [Any[1], 2.0] - m[1][1] = m + m = make_indirect_circular_reference_array() zt = Tapir.zero_tangent(m) @test zt[1][1] === zt + rt = Tapir.randn_tangent(Xoshiro(123456), m) + @test rt[1][1] === rt end end diff --git a/test/test_utils.jl b/test/test_utils.jl index c7a309b6..0099bec1 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,11 +1,4 @@ @testset "test_utils" begin - struct TesterStruct - x - y::Float64 - TesterStruct() = new() - TesterStruct(x) = new(x) - TesterStruct(x, y) = new(x, y) - end @testset "has_equal_data" begin @test !has_equal_data(5.0, 4.0) @test has_equal_data(5.0, 5.0) @@ -18,15 +11,29 @@ @test !has_equal_data(ones(1), ones(2)) @test !has_equal_data(randn(5), randn(5)) @test has_equal_data(ones(5), ones(5)) - @test has_equal_data(Complex(5.0, 4.0), Complex(5.0, 4.0)) + @test has_equal_data(Base, Base) + @test !has_equal_data(Base, Core) + @test has_equal_data(GlobalRef(Base, :sin), GlobalRef(Base, :sin)) + @test !has_equal_data(GlobalRef(Base, :sin), GlobalRef(Base, :cos)) + @test !has_equal_data(GlobalRef(Base, :sin), GlobalRef(Core, :sin)) + @test has_equal_data(Complex(5.0, 4.0), Complex(5.0, 4.0))[] @test !has_equal_data(Complex(5.0, 4.0), Complex(5.0, 5.0)) @test !has_equal_data(Diagonal(randn(5)), Diagonal(randn(5))) @test has_equal_data(Diagonal(ones(5)), Diagonal(ones(5))) @test has_equal_data("hello", "hello") @test !has_equal_data("hello", "goodbye") - @test has_equal_data(TesterStruct(), TesterStruct()) - @test has_equal_data(TesterStruct(5, 4.0), TesterStruct(5, 4.0)) - @test !has_equal_data(TesterStruct(), TesterStruct(5)) + @test has_equal_data(TypeUnstableMutableStruct(4.0, 5), TypeUnstableMutableStruct(4.0, 5)) + @test !has_equal_data(TypeUnstableMutableStruct(4.0, 5), TypeUnstableMutableStruct(4.0, 6)) + @test has_equal_data(TypeUnstableStruct(4.0, 5), TypeUnstableStruct(4.0, 5)) + @test !has_equal_data(TypeUnstableStruct(0.0), TypeUnstableStruct(4.0)) + @test has_equal_data(make_circular_reference_struct(), make_circular_reference_struct()) + @test has_equal_data(make_indirect_circular_reference_struct(), make_indirect_circular_reference_struct()) + @test has_equal_data(make_circular_reference_array(), make_circular_reference_array()) + @test has_equal_data(make_indirect_circular_reference_array(), make_indirect_circular_reference_array()) + @test !has_equal_data((s = make_circular_reference_struct(); s.a = 1.0; s), (t = make_circular_reference_struct(); t.a = 2.0; t)) + @test !has_equal_data((a = make_indirect_circular_reference_array(); a[1][1] = 1.0; a), (b = make_indirect_circular_reference_array(); b[1][1] = 2.0; b)) + @test !has_equal_data((s = make_indirect_circular_reference_struct(); s.b.a = 1.0; s), (t = make_indirect_circular_reference_struct(); t.b.a = 2.0; t)) + @test !has_equal_data((a = make_indirect_circular_reference_array(); a[1][1] = 1.0; a), (b = make_indirect_circular_reference_array(); b[1][1] = 2.0; b)) end @testset "populate_address_map" begin @testset "primitive types" begin