Skip to content

Commit

Permalink
Merge pull request #21 from gdalle/chainrules
Browse files Browse the repository at this point in the history
Add AutoChainRules(ruleconfig)
  • Loading branch information
Vaibhavdixit02 authored Mar 8, 2024
2 parents f6fce4f + 2f6afad commit 8684021
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ abstract type AbstractSparseReverseMode <: AbstractReverseMode end
abstract type AbstractSparseForwardMode <: AbstractForwardMode end
abstract type AbstractSparseFiniteDifferences <: AbstractFiniteDifferencesMode end

Base.@kwdef struct AutoChainRules{RC} <: AbstractADType
ruleconfig::RC
end

Base.@kwdef struct AutoFiniteDiff{T1, T2, T3} <: AbstractFiniteDifferencesMode
fdtype::T1 = Val(:forward)
fdjtype::T2 = fdtype
Expand Down Expand Up @@ -78,7 +82,8 @@ Base.@kwdef struct AutoSparseReverseDiff <: AbstractSparseReverseMode
compile::Bool = false
end

export AutoFiniteDiff,
export AutoChainRules,
AutoFiniteDiff,
AutoFiniteDifferences,
AutoForwardDiff,
AutoReverseDiff,
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ using Test
struct CustomTag end

@testset "ADTypes.jl" begin
adtype = AutoChainRules(:ruleconfig_placeholder)
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoChainRules
@test adtype.ruleconfig == :ruleconfig_placeholder

adtype = AutoFiniteDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoFiniteDiff
Expand Down

0 comments on commit 8684021

Please sign in to comment.