Skip to content

Commit

Permalink
Merge pull request #13 from prbzrg/use-kwdef
Browse files Browse the repository at this point in the history
Use `Base.@kwdef`
  • Loading branch information
Vaibhavdixit02 authored Aug 12, 2023
2 parents d29dd9d + 0e9ace0 commit 7dd6cc3
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 41 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.8'
- '1.6'
- '1'
- 'nightly'
os:
- ubuntu-latest
Expand Down
64 changes: 29 additions & 35 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,65 +5,59 @@ Base type for AD choices.
"""
abstract type AbstractADType end

struct AutoFiniteDiff{T1, T2, T3} <: AbstractADType
fdtype::T1
fdjtype::T2
fdhtype::T3
Base.@kwdef struct AutoFiniteDiff{T1, T2, T3} <: AbstractADType
fdtype::T1 = Val(:forward)
fdjtype::T2 = fdtype
fdhtype::T3 = Val(:hcentral)
end

function AutoFiniteDiff(; fdtype = Val(:forward), fdjtype = fdtype,
fdhtype = Val(:hcentral))
AutoFiniteDiff(fdtype, fdjtype, fdhtype)
Base.@kwdef struct AutoFiniteDifferences{T} <: AbstractADType
fdm::T = nothing
end

struct AutoFiniteDifferences{T} <: AbstractADType
fdm::T
end

AutoFiniteDifferences(; fdm = nothing) = AutoFiniteDifferences(fdm)

struct AutoForwardDiff{chunksize,T} <: AbstractADType
struct AutoForwardDiff{chunksize, T} <: AbstractADType
tag::T
end

function AutoForwardDiff(; chunksize = nothing, tag = nothing)
AutoForwardDiff{chunksize,typeof(tag)}(tag)
AutoForwardDiff{chunksize, typeof(tag)}(tag)
end

struct AutoReverseDiff <: AbstractADType
compile::Bool
Base.@kwdef struct AutoReverseDiff <: AbstractADType
compile::Bool = false
end

AutoReverseDiff(; compile = false) = AutoReverseDiff(compile)

struct AutoZygote <: AbstractADType end

struct AutoEnzyme{M} <: AbstractADType
mode::M
Base.@kwdef struct AutoEnzyme{M} <: AbstractADType
mode::M = nothing
end

AutoEnzyme(; mode = nothing) = AutoEnzyme(mode)

struct AutoTracker <: AbstractADType end

struct AutoModelingToolkit <: AbstractADType
obj_sparse::Bool
cons_sparse::Bool
end

function AutoModelingToolkit(; obj_sparse = false, cons_sparse = false)
AutoModelingToolkit(obj_sparse, cons_sparse)
Base.@kwdef struct AutoModelingToolkit <: AbstractADType
obj_sparse::Bool = false
cons_sparse::Bool = false
end

struct AutoSparseFiniteDiff <: AbstractADType end

struct AutoSparseForwardDiff{chunksize,T} <: AbstractADType
struct AutoSparseForwardDiff{chunksize, T} <: AbstractADType
tag::T
end

function AutoSparseForwardDiff(; chunksize = nothing, tag = nothing)
AutoSparseForwardDiff{chunksize,typeof(tag)}(tag)
end

export AutoFiniteDiff, AutoFiniteDifferences, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoEnzyme, AutoTracker, AutoModelingToolkit, AutoSparseFiniteDiff, AutoSparseForwardDiff
AutoSparseForwardDiff{chunksize, typeof(tag)}(tag)
end

export AutoFiniteDiff,
AutoFiniteDifferences,
AutoForwardDiff,
AutoReverseDiff,
AutoZygote,
AutoEnzyme,
AutoTracker,
AutoModelingToolkit,
AutoSparseFiniteDiff,
AutoSparseForwardDiff
end
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ struct CustomTag end
adtype = AutoForwardDiff()
@test adtype isa ADTypes.AbstractADType

@test adtype isa AutoForwardDiff{nothing,Nothing}
@test adtype isa AutoForwardDiff{nothing, Nothing}

adtype = AutoForwardDiff(; chunksize = 10, tag = CustomTag())
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoForwardDiff{10,CustomTag}
@test adtype isa AutoForwardDiff{10, CustomTag}

adtype = AutoReverseDiff()
@test adtype isa ADTypes.AbstractADType
Expand All @@ -51,11 +51,11 @@ struct CustomTag end

adtype = AutoSparseForwardDiff()
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{nothing,Nothing}
@test adtype isa AutoSparseForwardDiff{nothing, Nothing}

adtype = AutoSparseForwardDiff(; chunksize = 10, tag = CustomTag())
@test adtype isa ADTypes.AbstractADType
@test adtype isa AutoSparseForwardDiff{10,CustomTag}
@test adtype isa AutoSparseForwardDiff{10, CustomTag}

adtype = AutoModelingToolkit()
@test adtype isa ADTypes.AbstractADType
Expand Down

0 comments on commit 7dd6cc3

Please sign in to comment.