diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 965afa5..8212fa8 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -18,6 +18,8 @@ abstract type AbstractSymbolicDifferentiationMode <: AbstractADType end abstract type AbstractSparseReverseMode <: AbstractReverseMode end abstract type AbstractSparseForwardMode <: AbstractForwardMode end abstract type AbstractSparseFiniteDifferences <: AbstractFiniteDifferencesMode end +abstract type AbstractSparseSymbolicDifferentiationMode <: + AbstractSymbolicDifferentiationMode end """ AutoChainRules{RC} @@ -229,6 +231,20 @@ Chooses [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl). """ struct AutoDiffractor <: AbstractADType end +""" + AutoFastDifferentiation + +Chooses [FastDifferentiation.jl](https://github.com/brianguenter/FastDifferentiation.jl). +""" +struct AutoFastDifferentiation <: AbstractSymbolicDifferentiationMode end + +""" + AutoSparseFastDifferentiation + +Chooses [FastDifferentiation.jl](https://github.com/brianguenter/FastDifferentiation.jl) while exploiting sparsity. +""" +struct AutoSparseFastDifferentiation <: AbstractSparseSymbolicDifferentiationMode end + export AutoChainRules, AutoDiffractor, AutoFiniteDiff, @@ -244,5 +260,7 @@ export AutoChainRules, AutoSparseZygote, AutoSparseReverseDiff, AutoPolyesterForwardDiff, - AutoSparsePolyesterForwardDiff + AutoSparsePolyesterForwardDiff, + AutoFastDifferentiation, + AutoSparseFastDifferentiation end diff --git a/test/runtests.jl b/test/runtests.jl index 1f18a10..42b1fe0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -95,4 +95,12 @@ struct CustomTag end adtype = AutoDiffractor() @test adtype isa ADTypes.AbstractADType @test adtype isa AutoDiffractor + + adtype = AutoFastDifferentiation() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoFastDifferentiation + + adtype = AutoSparseFastDifferentiation() + @test adtype isa ADTypes.AbstractADType + @test adtype isa AutoSparseFastDifferentiation end