Reuse more code
avik-pal committed May 13, 2024
1 parent f692da2 commit 1f1de4c
Showing 3 changed files with 18 additions and 24 deletions.
ext/LuxEnzymeExt.jl (7 additions & 16 deletions)
@@ -16,8 +16,7 @@ end
 # Case I: We have CachedEnzymeExtras and objective_function is unchanged.
 function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
         ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras, F}) where {F}
-    dps = ts.cache.dparameters
-    Lux.__recursive_make_zero!(dps)
+    dps = Lux.__recursive_make_zero!!(ts.cache.dparameters)
 
     _, loss = Enzyme.autodiff(
         Enzyme.ReverseWithPrimal, ts.cache.objective_function, Active, Const(ts.model),
@@ -33,8 +32,7 @@ end
 # Case II: We have CachedEnzymeExtras and objective_function is changed.
 function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
         ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras}) where {F}
-    dps = ts.cache.dparameters
-    Lux.__recursive_make_zero!(dps)
+    dps = Lux.__recursive_make_zero!!(ts.cache.dparameters)
 
     obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function(
         objective_function, ts.states)
@@ -49,20 +47,13 @@ function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F,
 end
 
 # Case III: Nothing is cached. First call to `compute_gradients`
-function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
+function Lux.Experimental.compute_gradients(ad::AutoEnzyme, objective_function::F, data,
         ts::Lux.Experimental.TrainState) where {F}
     dps = Lux.__recursive_make_zero(ts.parameters)
-
-    obj_fn, st_wrap, stats_wrap = Lux.Experimental.__wrap_objective_function(
-        objective_function, ts.states)
-
-    _, loss = Enzyme.autodiff(Enzyme.ReverseWithPrimal, obj_fn, Active, Const(ts.model),
-        Duplicated(ts.parameters, dps), Const(ts.states), Const(data))
-
-    ts_new = __construct_new_trainstate(
-        st_wrap, ts.states, ts, objective_function, dps, obj_fn, st_wrap, stats_wrap)
-
-    return dps, loss, stats_wrap, ts_new
+    cache = CachedEnzymeExtras(dps, nothing, nothing, nothing)
+    ts_new = Lux.Experimental.TrainState(
+        cache, nothing, ts.model, ts.parameters, ts.states, ts.optimizer_state, ts.step)
+    return Lux.Experimental.compute_gradients(ad, objective_function, data, ts_new)
 end
 
 # If `st_new` is of a new type, we will have to recompute the cache anyway. Force it by not
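
Note: the Case III rewrite is the "reuse" in the commit title. Instead of duplicating the wrap/autodiff/rebuild sequence, the first (uncached) call now only allocates zeroed gradient storage, packs it into a fresh TrainState, and re-dispatches, so Cases I/II stay the single code path that actually runs Enzyme. A minimal sketch of that dispatch pattern, using hypothetical stand-in types rather than Lux's actual structs:

struct Cache{D}
    dparameters::D
end

struct State{C, P}
    cache::C
    parameters::P
end

# "Case III": no cache yet -- allocate one, then recurse into the cached method.
function compute_gradients(st::State{Nothing})
    dps = zero.(st.parameters)             # fresh zeroed gradient storage
    return compute_gradients(State(Cache(dps), st.parameters))
end

# "Case I/II": a cache exists -- re-zero and reuse the cached gradients.
function compute_gradients(st::State{<:Cache})
    dps = fill!(st.cache.dparameters, 0)   # reuse the cached buffer
    # ... run AD here with `dps` as the duplicated/shadow storage ...
    return dps
end

compute_gradients(State(nothing, [1.0, 2.0, 3.0]))  # first call builds the cache, then recurses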
src/utils.jl (9 additions & 7 deletions)
@@ -288,6 +288,7 @@ end
 @inline __size(x::AbstractArray) = size(x)
 @inline __size(x::T) where {T} = hasmethod(size, Tuple{T}) ? size(x) : nothing
 
+@inline __recursive_make_zero(x::Number) = zero(x)
 @inline __recursive_make_zero(x::AbstractArray{<:Number}) = zero(x)
 @inline __recursive_make_zero(x::AbstractArray) = map(__recursive_make_zero, x)
 @inline __recursive_make_zero(x::Tuple) = map(__recursive_make_zero, x)
@@ -297,10 +298,11 @@ end
 @inline __recursive_make_zero(v::Val) = v
 @inline __recursive_make_zero(x) = fmap(__recursive_make_zero, x)
 
-@inline __recursive_make_zero!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
-@inline __recursive_make_zero!(x::AbstractArray) = map(__recursive_make_zero!, x)
-@inline __recursive_make_zero!(x::Tuple) = map(__recursive_make_zero!, x)
-@inline __recursive_make_zero!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
-    __recursive_make_zero!, values(x)))
-@inline __recursive_make_zero!(::Nothing) = nothing
-@inline __recursive_make_zero!(x) = fmap(__recursive_make_zero!, x)
+@inline __recursive_make_zero!!(x::Number) = zero(x)
+@inline __recursive_make_zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
+@inline __recursive_make_zero!!(x::AbstractArray) = map(__recursive_make_zero!!, x)
+@inline __recursive_make_zero!!(x::Tuple) = map(__recursive_make_zero!!, x)
+@inline __recursive_make_zero!!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
+    __recursive_make_zero!!, values(x)))
+@inline __recursive_make_zero!!(::Nothing) = nothing
+@inline __recursive_make_zero!!(x) = fmap(__recursive_make_zero!!, x)
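
Note: the rename from `__recursive_make_zero!` to `__recursive_make_zero!!` follows the BangBang-style "maybe-mutating" convention: mutable leaves such as arrays are zeroed in place via `fill!`, while immutable leaves such as plain numbers get a fresh `zero(x)` back (hence the new `x::Number` method). This is also why the Enzyme extension above reassigns `dps = Lux.__recursive_make_zero!!(...)` rather than mutating through an existing binding. A simplified standalone re-implementation for illustration, not the internal helper itself:

recursive_make_zero!!(x::Number) = zero(x)                                     # immutable leaf: return a new value
recursive_make_zero!!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))  # mutable leaf: zero in place
recursive_make_zero!!(x::Union{Tuple, NamedTuple}) = map(recursive_make_zero!!, x)
recursive_make_zero!!(::Nothing) = nothing

ps = (weight = [1.0 2.0; 3.0 4.0], bias = [5.0, 6.0], scale = 0.5)
zeroed = recursive_make_zero!!(ps)
# ps.weight and ps.bias are now zeroed in place; zeroed.scale == 0.0 is a
# new value, since a Float64 cannot be mutated.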
test/qa_tests.jl (2 additions & 1 deletion)
@@ -8,7 +8,8 @@ end
 
 @testitem "Explicit Imports: Quality Assurance" tags=[:others] begin
     # Load all trigger packages
-    import Lux, ComponentArrays, ReverseDiff, Flux, LuxAMDGPU, SimpleChains, Tracker, Zygote
+    import Lux, ComponentArrays, ReverseDiff, Flux, LuxAMDGPU, SimpleChains, Tracker,
+           Zygote, Enzyme
 
     using ExplicitImports
 