diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 63d7039..30a5c1d 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -14,6 +14,10 @@ abstract type AbstractSparseReverseMode <: AbstractReverseMode end abstract type AbstractSparseForwardMode <: AbstractForwardMode end abstract type AbstractSparseFiniteDifferences <: AbstractFiniteDifferencesMode end +Base.@kwdef struct AutoChainRules{RC} <: AbstractADType + ruleconfig::RC +end + Base.@kwdef struct AutoFiniteDiff{T1, T2, T3} <: AbstractFiniteDifferencesMode fdtype::T1 = Val(:forward) fdjtype::T2 = fdtype @@ -78,7 +82,8 @@ Base.@kwdef struct AutoSparseReverseDiff <: AbstractSparseReverseMode compile::Bool = false end -export AutoFiniteDiff, +export AutoChainRules, + AutoFiniteDiff, AutoFiniteDifferences, AutoForwardDiff, AutoReverseDiff, diff --git a/test/runtests.jl b/test/runtests.jl index f7d6fe8..2d874fc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,11 @@ using Test struct CustomTag end @testset "ADTypes.jl" begin + adtype = AutoChainRules(:ruleconfig_placeholder) + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoChainRules + @test adtype.ruleconfig == :ruleconfig_placeholder + adtype = AutoFiniteDiff() @test adtype isa ADTypes.AbstractADType @test adtype isa AutoFiniteDiff