Skip to content

Commit

Permalink
Merge pull request #34 from SciML/gd/new_types
Browse files Browse the repository at this point in the history
Support FastDifferentiation.jl
  • Loading branch information
Vaibhavdixit02 authored Apr 3, 2024
2 parents 93b8c6c + 13857fb commit ec248df
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
20 changes: 19 additions & 1 deletion src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ abstract type AbstractSymbolicDifferentiationMode <: AbstractADType end
abstract type AbstractSparseReverseMode <: AbstractReverseMode end
abstract type AbstractSparseForwardMode <: AbstractForwardMode end
abstract type AbstractSparseFiniteDifferences <: AbstractFiniteDifferencesMode end
abstract type AbstractSparseSymbolicDifferentiationMode <:
AbstractSymbolicDifferentiationMode end

"""
AutoChainRules{RC}
Expand Down Expand Up @@ -233,6 +235,20 @@ Chooses [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl).
"""
struct AutoDiffractor <: AbstractADType end

"""
AutoFastDifferentiation
Chooses [FastDifferentiation.jl](https://github.com/brianguenter/FastDifferentiation.jl).
"""
struct AutoFastDifferentiation <: AbstractSymbolicDifferentiationMode end

"""
AutoSparseFastDifferentiation
Chooses [FastDifferentiation.jl](https://github.com/brianguenter/FastDifferentiation.jl) while exploiting sparsity.
"""
struct AutoSparseFastDifferentiation <: AbstractSparseSymbolicDifferentiationMode end

export AutoChainRules,
AutoDiffractor,
AutoFiniteDiff,
Expand All @@ -248,5 +264,7 @@ export AutoChainRules,
AutoSparseZygote,
AutoSparseReverseDiff,
AutoPolyesterForwardDiff,
AutoSparsePolyesterForwardDiff
AutoSparsePolyesterForwardDiff,
AutoFastDifferentiation,
AutoSparseFastDifferentiation
end
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,12 @@ struct CustomTag end
adtype = AutoDiffractor()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoDiffractor

adtype = AutoFastDifferentiation()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoFastDifferentiation

adtype = AutoSparseFastDifferentiation()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseFastDifferentiation
end

0 comments on commit ec248df

Please sign in to comment.