Skip to content

Commit

Permalink
refactor: make LossFunctions an optional dep (#976)
Browse files Browse the repository at this point in the history
* refactor: make LossFunctions an optional dep

* feat: add custom derivative fast paths

* test: more tests got fixed
  • Loading branch information
avik-pal authored Oct 9, 2024
1 parent 04deedf commit 77eb5fb
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 57 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.1.0"
version = "1.2.0-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -18,7 +18,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Expand All @@ -43,6 +42,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Expand All @@ -55,6 +55,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
LuxComponentArraysExt = "ComponentArrays"
LuxEnzymeExt = "Enzyme"
LuxFluxExt = "Flux"
LuxLossFunctionsExt = "LossFunctions"
LuxMLUtilsExt = "MLUtils"
LuxMPIExt = "MPI"
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
Expand Down
71 changes: 71 additions & 0 deletions ext/LuxLossFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Extension module that adds `LossFunctionImpl.fused_agg` methods, plus reverse-mode
# differentiation rules (ChainRulesCore and Enzyme), for loss objects of type
# `LossFunctions.Traits.Loss`.
module LuxLossFunctionsExt

using ArrayInterface: fast_scalar_indexing
using ChainRulesCore: ChainRulesCore, NoTangent, @thunk
using EnzymeCore: EnzymeCore, EnzymeRules
using FastClosures: @closure
using LossFunctions: LossFunctions
using Statistics: mean

using Lux: Lux, LossFunctionImpl

const CRC = ChainRulesCore

# Mean aggregation over arrays: reuse the `sum` method below and divide by the number of
# elements in `x`.
function LossFunctionImpl.fused_agg(
::typeof(mean), lfn::LossFunctions.Traits.Loss, x::AbstractArray, y::AbstractArray)
return LossFunctionImpl.fused_agg(sum, lfn, x, y) / length(x)
end

# Sum aggregation for a single scalar pair: just evaluate the loss.
function LossFunctionImpl.fused_agg(
::typeof(sum), lfn::LossFunctions.Traits.Loss, x::Number, y::Number)
return lfn(x, y)
end
# Sum aggregation over arrays. When both arrays report fast scalar indexing the loss is
# reduced directly with `sum(lfn, x, y)`; otherwise the loss is broadcast first and the
# materialized result is summed (presumably to avoid scalar indexing on device arrays —
# TODO confirm).
function LossFunctionImpl.fused_agg(
::typeof(sum), lfn::LossFunctions.Traits.Loss, x::AbstractArray, y::AbstractArray)
fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y)
return sum(lfn.(x, y))
end

# ChainRules reverse rule for `fused_agg(sum, lfn, x, y)`.
# Returns the primal value together with a pullback. The pullback yields `NoTangent()` for
# the function, the aggregator, the loss object and `y`; only `x` receives a tangent,
# computed lazily as `LossFunctions.deriv.(Ref(lfn), x, y) .* Δ`.
function CRC.rrule(
::CRC.RuleConfig{>:CRC.HasReverseMode},
::typeof(LossFunctionImpl.fused_agg), ::typeof(sum),
lfn::LossFunctions.Traits.Loss, x, y)
∇fused_agg = @closure Δ -> begin
∂x = @thunk LossFunctions.deriv.(Ref(lfn), x, y) .* Δ
return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent()
end
return LossFunctionImpl.fused_agg(sum, lfn, x, y), ∇fused_agg
end

# Enzyme forward (augmented) pass for `fused_agg(sum, lfn, x, y)` with an `Active` return,
# a `Const` aggregator/loss/target and an array-valued `x`.
# Computes the primal only when the config requests it, and stashes copies of `x` / `y` on
# the tape when the config marks them (positions 4 and 5) as overwritten before the
# reverse pass runs.
function EnzymeRules.augmented_primal(
cfg::EnzymeRules.RevConfigWidth{1},
func::EnzymeCore.Const{typeof(LossFunctionImpl.fused_agg)},
::Type{<:EnzymeCore.Active}, agg_f::EnzymeCore.Const{typeof(sum)},
lfn::EnzymeCore.Const{<:LossFunctions.Traits.Loss},
x::EnzymeCore.Annotation{<:AbstractArray}, y::EnzymeCore.Const)
primal = EnzymeRules.needs_primal(cfg) ? func.val(agg_f.val, lfn.val, x.val, y.val) :
nothing

cache_x = EnzymeRules.overwritten(cfg)[4] ? copy(x.val) : nothing
cache_y = EnzymeRules.overwritten(cfg)[5] ? copy(y.val) : nothing

# No shadow is returned (second argument is `nothing`); the tape carries the caches.
return EnzymeRules.AugmentedReturn(primal, nothing, (cache_x, cache_y))
end

# Enzyme reverse pass matching the augmented primal above.
# Uses the cached copies of `x` / `y` when they were saved, otherwise the live values, and
# writes the loss derivative scaled by the incoming adjoint `dret.val` into `x.dval`
# unless `x` is annotated `Const`. Returns four `nothing`s, one per argument following the
# tape (agg_f, lfn, x, y).
function EnzymeRules.reverse(
cfg::EnzymeRules.RevConfigWidth{1},
::EnzymeCore.Const{typeof(LossFunctionImpl.fused_agg)},
dret::EnzymeCore.Active, (cache_x, cache_y), agg_f::EnzymeCore.Const{typeof(sum)},
lfn::EnzymeCore.Const{<:LossFunctions.Traits.Loss},
x::EnzymeCore.Annotation{<:AbstractArray}, y::EnzymeCore.Const)
# Nothing was cached when the inputs were not overwritten, so read them directly.
EnzymeRules.overwritten(cfg)[4] || (cache_x = x.val)
EnzymeRules.overwritten(cfg)[5] || (cache_y = y.val)

if !(typeof(x) <: EnzymeCore.Const)
# NOTE(review): this assigns into `x.dval` rather than accumulating (`+=`). Enzyme's
# custom-rule documentation describes adding into shadow buffers — confirm that
# overwriting any existing contribution is intended here.
@. x.dval = LossFunctions.deriv(lfn.val, cache_x, cache_y) * dret.val
end

return ntuple(Returns(nothing), 4)
end

end
1 change: 0 additions & 1 deletion src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using ConcreteStructs: @concrete
using FastClosures: @closure
using Functors: Functors, fmap
using GPUArraysCore: @allowscalar
using LossFunctions: LossFunctions
using Markdown: @doc_str
using NNlib: NNlib
using Optimisers: Optimisers
Expand Down
150 changes: 103 additions & 47 deletions src/helpers/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ module LossFunctionImpl

using ArrayInterface: fast_scalar_indexing
using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable, @thunk
using EnzymeCore: EnzymeCore, EnzymeRules
using FastClosures: @closure
using LossFunctions: LossFunctions
using ForwardDiff: ForwardDiff, Dual, Partials
using Statistics: mean

using ..Utils: Utils
using ..LuxOps: xlogy

const CRC = ChainRulesCore
Expand All @@ -30,59 +30,66 @@ check_sizes(_, __) = nothing

# Aggregation. We are able to define custom aggregation fast paths
fused_agg(::typeof(mean), op::OP, x) where {OP} = fused_agg(sum, op, x) / length(x)
function fused_agg(::typeof(mean), lfn::LossFunctions.Traits.Loss, x, y)
return fused_agg(sum, lfn, x, y) / length(x)
end

fused_agg(::typeof(sum), op::OP, x::Number) where {OP} = op(x)
fused_agg(::typeof(sum), op::OP, x) where {OP} = sum(op, x)

fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x::Number, y::Number) = lfn(x, y)
function fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y)
fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y)
# mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) leads to slowdowns, better to
# allocate a new array
return sum(lfn.(x, y))
fused_agg(::typeof(mean), op::OP, x::Number, y::Number) where {OP} = op(x, y)
function fused_agg(::typeof(mean), op::OP, x::AbstractArray, y::AbstractArray) where {OP}
return fused_agg(sum, op, x, y) / length(x)
end

fused_agg(::Nothing, op::OP, args...) where {OP} = op.(args...)
fused_agg(f::F, op::OP, args...) where {F, OP} = f(op.(args...))

function CRC.rrule(::typeof(fused_agg), ::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y)
∇fused_agg = @closure Δ -> begin
∂x = @thunk LossFunctions.deriv.(Ref(lfn), x, y) .* Δ
return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent()
fused_agg(::typeof(sum), op::OP, x::Number, y::Number) where {OP} = op(x, y)
function fused_agg(::typeof(sum), op::OP, x::AbstractArray, y::AbstractArray) where {OP}
if fast_scalar_indexing(x) && fast_scalar_indexing(y)
res = Core.Compiler._return_type(op, Tuple{eltype(x), eltype(y)})(0)
@simd ivdep for i in eachindex(x, y)
@inbounds res += op(x[i], y[i])
end
return res
end
return fused_agg(sum, lfn, x, y), ∇fused_agg
return fallback_fused_agg(sum, op, x, y)
end

function EnzymeRules.augmented_primal(
cfg::EnzymeRules.RevConfigWidth{1}, func::EnzymeCore.Const{typeof(fused_agg)},
::Type{<:EnzymeCore.Active}, agg_f::EnzymeCore.Const{typeof(sum)},
lfn::EnzymeCore.Const{<:LossFunctions.Traits.Loss},
x::EnzymeCore.Annotation{<:AbstractArray}, y::EnzymeCore.Const)
primal = EnzymeRules.needs_primal(cfg) ? func.val(agg_f.val, lfn.val, x.val, y.val) :
nothing

cache_x = EnzymeRules.overwritten(cfg)[4] ? copy(x.val) : nothing
cache_y = EnzymeRules.overwritten(cfg)[5] ? copy(y.val) : nothing
fused_agg(::Nothing, op::OP, args...) where {OP} = op.(args...)
fused_agg(f::F, op::OP, args...) where {F, OP} = fallback_fused_agg(f, op, args...)

return EnzymeRules.AugmentedReturn(primal, nothing, (cache_x, cache_y))
end
@inline fallback_fused_agg(f::F, op::OP, args...) where {F, OP} = f(op.(args...))

function EnzymeRules.reverse(
cfg::EnzymeRules.RevConfigWidth{1}, ::EnzymeCore.Const{typeof(fused_agg)},
dret::EnzymeCore.Active, (cache_x, cache_y), agg_f::EnzymeCore.Const{typeof(sum)},
lfn::EnzymeCore.Const{<:LossFunctions.Traits.Loss},
x::EnzymeCore.Annotation{<:AbstractArray}, y::EnzymeCore.Const)
EnzymeRules.overwritten(cfg)[4] || (cache_x = x.val)
EnzymeRules.overwritten(cfg)[5] || (cache_y = y.val)
function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
::typeof(fused_agg), ::typeof(sum), op::OP, x, y) where {OP}
if has_custom_derivative(op)
res = fused_agg(sum, op, x, y)
∇fused_agg_custom_derivative = Δ -> begin
∂x = @thunk derivative.(Ref(op), x, y) .* Δ
return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent()
end
return res, ∇fused_agg_custom_derivative
end

if !(typeof(x) <: EnzymeCore.Const)
@. x.dval = LossFunctions.deriv(lfn.val, cache_x, cache_y) * dret.val
# Without custom derivatives use ForwardDiff for the looped implementation
if fast_scalar_indexing(x) && fast_scalar_indexing(y)
x_dual = Dual{
Nothing, eltype(x), 1}.(x, (Partials{1, eltype(x)}((one(eltype(x)),)),))
x_partials = similar(x)
T = eltype(x)
res = Core.Compiler._return_type(op, Tuple{T, eltype(y)})(0)
@inbounds @simd for i in eachindex(x_partials, x, y)
x_dual = Dual{Nothing, T, 1}(x[i], Partials{1, T}((one(T),)))
tmp = op(x_dual, y[i])
x_partials[i] = ForwardDiff.partials(tmp, 1)
res += ForwardDiff.value(tmp)
end
∇fused_agg_loop = Δ -> begin
@simd ivdep for i in eachindex(x_partials)
@inbounds x_partials[i] *= Δ
end
return NoTangent(), NoTangent(), NoTangent(), x_partials, NoTangent()
end
return res, ∇fused_agg_loop
end

return ntuple(Returns(nothing), 4)
return CRC.rrule_via_ad(cfg, fallback_fused_agg, sum, op, x, y)
end

get_ϵ(::Type{T}, ϵ::Real) where {T} = T(ϵ)
Expand All @@ -91,9 +98,57 @@ get_ϵ(::Type{T}, ::Nothing) where {T} = eps(float(T))
get_loss_dims(::AbstractVector) = Colon()
get_loss_dims(::AbstractArray{T, N}) where {T, N} = 1:(N - 1)

# Trait fallback: unless a more specific method is defined for a function's type, it is
# reported as having no hand-written `derivative`.
function has_custom_derivative(::F) where {F}
    return false
end

# A `Utils.Fix3` wrapper has a custom derivative exactly when the function it wraps
# (`f.f`) does.
function has_custom_derivative(f::Utils.Fix3)
    return has_custom_derivative(f.f)
end
# Forward the derivative to the wrapped function, passing the stored value `f.x` as the
# trailing argument.
function derivative(f::Utils.Fix3, x, y)
    return derivative(f.f, x, y, f.x)
end

# Functional forms of losses
# Absolute difference `|x - y|` of a scalar pair.
function l1_distance_loss(x::T1, y::T2) where {T1, T2}
    return abs(x - y)
end
has_custom_derivative(::typeof(l1_distance_loss)) = true
# Derivative of `l1_distance_loss` with respect to `x`: the sign of `x - y`, converted to
# the type of `x`.
function derivative(::typeof(l1_distance_loss), x::T1, y::T2) where {T1, T2}
    residual = x - y
    return convert(T1, sign(residual))
end

# Squared difference `(x - y)^2` of a scalar pair.
function l2_distance_loss(x::T1, y::T2) where {T1, T2}
    return abs2(x - y)
end
has_custom_derivative(::typeof(l2_distance_loss)) = true
# Derivative of `l2_distance_loss` with respect to `x`: twice the residual `x - y`,
# converted to the type of `x`.
function derivative(::typeof(l2_distance_loss), x::T1, y::T2) where {T1, T2}
    residual = x - y
    return convert(T1, 2 * residual)
end

# Huber loss of a scalar pair with threshold `δ`:
#   0.5 * (x - y)^2             when |x - y| ≤ δ   (quadratic region)
#   δ * (|x - y| - 0.5 * δ)     otherwise          (linear region)
# The constant 0.5 is created in the promoted type of all three arguments.
function huber_loss(x::T1, y::T2, δ::T3) where {T1, T2, T3}
    T = promote_type(T1, T2, T3)
    diff = x - y
    abs_diff = abs(diff)
    # The comparison operator was missing here, which made the expression unparsable.
    return ifelse(abs_diff ≤ δ, T(0.5) * abs2(diff), δ * (abs_diff - T(0.5) * δ))
end
has_custom_derivative(::typeof(huber_loss)) = true
# Derivative of `huber_loss` with respect to `x`, returned in the type of `x`:
# the residual itself inside the quadratic region, `δ * sign(x - y)` outside of it.
function derivative(::typeof(huber_loss), x::T, y::T2, δ::T3) where {T, T2, T3}
    diff = x - y
    return ifelse(abs(diff) ≤ δ, T(diff), T(δ) * convert(T, sign(diff)))
end

# Hinge loss `max(0, 1 - x * y)` of a scalar pair; the zero is created in the type of the
# product `x * y`.
function l1_hinge_loss(x::T1, y::T2) where {T1, T2}
    agreement = x * y
    return max(oftype(agreement, false), true - agreement)
end
has_custom_derivative(::typeof(l1_hinge_loss)) = true
# Derivative of `l1_hinge_loss` with respect to `x`, returned in the type of `x`.
# For `x * y ≥ 1` the loss is flat (zero); otherwise the loss is `1 - x * y`, whose
# derivative in `x` is `-y`. The previous version returned `+1` on the active side, which
# dropped both the sign and the chain-rule factor `y`, and it was missing its comparison
# operator.
function derivative(::typeof(l1_hinge_loss), x::T1, y::T2) where {T1, T2}
    return ifelse(x * y ≥ 1, zero(T1), convert(T1, -y))
end

# Squared hinge loss of a scalar pair: zero when `x * y ≥ 1`, otherwise `(1 - x * y)^2`.
function l2_hinge_loss(x::T1, y::T2) where {T1, T2}
    agreement = x * y
    # The comparison operator was missing here, which made the expression unparsable.
    return ifelse(agreement ≥ 1, oftype(agreement, false), abs2(true - agreement))
end
has_custom_derivative(::typeof(l2_hinge_loss)) = true
# Derivative of `l2_hinge_loss` with respect to `x`, returned in the type of `x`.
# On the active side the loss is `(1 - x * y)^2`, so the derivative in `x` is
# `2 * y * (x * y - 1)`. The previous version omitted the chain-rule factor `y` (it was
# the derivative with respect to the product, not `x`) and was missing its comparison
# operator.
function derivative(::typeof(l2_hinge_loss), x::T1, y::T2) where {T1, T2}
    agreement = x * y
    return T1(ifelse(agreement ≥ 1, zero(agreement), 2 * y * (agreement - true)))
end

function siamese_contrastive_loss(x::T1, y::T2, margin=true) where {T1, T2}
return (1 - y) * x^2 + y * max(promote_type(T1, T2)(0), margin - x)^2
return (true - y) * x^2 + y * max(promote_type(T1, T2)(false), margin - x)^2
end

poisson_loss(x::T1, y::T2, ϵ) where {T1, T2} = x - xlogy(y, x + get_ϵ(T1, ϵ))
Expand Down Expand Up @@ -462,7 +517,7 @@ julia> loss(y_pred, y_true) ≈ 0.55
true
```
"""
HingeLoss(; agg=mean) = GenericLossFunction(LossFunctions.L1HingeLoss(); agg)
HingeLoss(; agg=mean) = GenericLossFunction(LossFunctionImpl.l1_hinge_loss; agg)

@doc doc"""
HuberLoss(; delta = 1, agg = mean)
Expand Down Expand Up @@ -490,7 +545,8 @@ true
"""
function HuberLoss(; delta::Union{Nothing, AbstractFloat}=nothing, agg=mean)
return GenericLossFunction(
LossFunctions.HuberLoss(ifelse(delta === nothing, Float16(1), delta)); agg)
Utils.Fix3(LossFunctionImpl.huber_loss, ifelse(delta === nothing, true, delta));
agg)
end

@doc doc"""
Expand Down Expand Up @@ -566,7 +622,7 @@ julia> loss(y_model, 1:3) ≈ 0.1
true
```
"""
MAELoss(; agg=mean) = GenericLossFunction(LossFunctions.L1DistLoss(); agg)
MAELoss(; agg=mean) = GenericLossFunction(LossFunctionImpl.l1_distance_loss; agg)

const L1Loss = MAELoss

Expand All @@ -588,7 +644,7 @@ julia> loss(y_model, 1:3) ≈ 0.01
true
```
"""
MSELoss(; agg=mean) = GenericLossFunction(LossFunctions.L2DistLoss(); agg)
MSELoss(; agg=mean) = GenericLossFunction(LossFunctionImpl.l2_distance_loss; agg)

const L2Loss = MSELoss

Expand Down Expand Up @@ -696,7 +752,7 @@ julia> loss(y_pred, y_true) ≈ 0.625
true
```
"""
SquaredHingeLoss(; agg=mean) = GenericLossFunction(LossFunctions.L2HingeLoss(); agg)
SquaredHingeLoss(; agg=mean) = GenericLossFunction(LossFunctionImpl.l2_hinge_loss; agg)

@doc doc"""
GenericLossFunction(loss_fn; agg = mean)
Expand Down
10 changes: 3 additions & 7 deletions test/helpers/loss_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,7 @@ end

@jet MSLELoss()(ŷ, y)

if VERSION ≥ v"1.11-"
@test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any
else
@test_broken @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any
end
@test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any broken=ongpu

__f = Base.Fix2(MSLELoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
Expand Down Expand Up @@ -343,7 +339,7 @@ end
@test Lux.PoissonLoss()(y, y) ≈ 0.5044459776946685

@jet Lux.PoissonLoss()(ŷ, y)
@test_broken @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y)
@test @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) isa Any

__f = Base.Fix2(Lux.PoissonLoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
Expand All @@ -357,7 +353,7 @@ end
@test DiceCoeffLoss()(y, y) ≈ 0.0

@jet DiceCoeffLoss()(ŷ, y)
@test_broken @inferred Zygote.gradient(DiceCoeffLoss(), ŷ, y)
@test @inferred(Zygote.gradient(DiceCoeffLoss(), ŷ, y)) isa Any broken=true

__f = Base.Fix2(DiceCoeffLoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3,
Expand Down

1 comment on commit 77eb5fb

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lux Benchmarks

Benchmark suite Current: 77eb5fb Previous: 04deedf Ratio
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s) 414958 ns 411750 ns 1.01
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s) 322541 ns 322271 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s) 323167 ns 323042 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s) 739562.5 ns 749375 ns 0.99
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA 44543 ns 43905 ns 1.01
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s) 1335729 ns 1306583 ns 1.02
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s) 485000 ns 465625 ns 1.04
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s) 14073833 ns 13617333 ns 1.03
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s) 2211312.5 ns 2245750 ns 0.98
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA 194175 ns 192831 ns 1.01
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s) 1374959 ns 1394875 ns 0.99
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s) 596188 ns 634729.5 ns 0.94
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s) 13290875.5 ns 14050875 ns 0.95
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s) 2199270.5 ns 2238000 ns 0.98
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1665666 ns 1661542 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1186833.5 ns 1196103.5 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1536854.5 ns 1534187.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 2912062.5 ns 3005667 ns 0.97
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA 213313 ns 209529 ns 1.02
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12145187.5 ns 12111521 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 9486083 ns 9554687 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9213083 ns 9247000 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18563708 ns 18626583 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1921274.5 ns 1910271 ns 1.01
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17291000 ns 17307250 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 14310062.5 ns 14377958 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14535333 ns 14526875 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 21812208 ns 21836458.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 250754270.5 ns 250439041.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 174424541 ns 174592521 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 115532521 ns 115955208.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 446573667 ns 447243084 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5489738 ns 5470843 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1222307709 ns 1228722500 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 543403209 ns 543561875 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 832977124.5 ns 830623396.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1653507000 ns 1628878000 ns 1.02
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 34972271 ns 38000637 ns 0.92
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1142743917 ns 1136994583 ns 1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 686139667 ns 679379084 ns 1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1325824667 ns 1328113771 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1748793708.5 ns 1733752146 ns 1.01
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s) 1120083 ns 1103375 ns 1.02
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s) 820062 ns 823209 ns 1.00
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s) 3738667 ns 3578479 ns 1.04
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s) 785458.5 ns 786500 ns 1.00
lenet(28, 28, 1, 32)/forward/GPU/CUDA 280004 ns 266091.5 ns 1.05
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s) 2992209 ns 2986021 ns 1.00
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s) 2457292 ns 2426000 ns 1.01
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s) 10152708 ns 10461250 ns 0.97
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s) 3200812.5 ns 3150042 ns 1.02
lenet(28, 28, 1, 32)/zygote/GPU/CUDA 1093024 ns 1055864 ns 1.04
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 2350792 ns 2335042 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1546500 ns 1537708 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1703875 ns 1740000 ns 0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 4314041 ns 4348437.5 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 214178 ns 212286 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 20293542 ns 20266645.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 17667520.5 ns 17701209 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 18215687.5 ns 17495416 ns 1.04
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 26742770.5 ns 26797000 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1989179 ns 1973706 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 44338833 ns 44317750 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 29803125 ns 42027646 ns 0.71
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 41253750.5 ns 41325000 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 49627062.5 ns 47734917 ns 1.04
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 4666292 ns 4664854 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2867729.5 ns 2868521.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 3027042 ns 3015958 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 8637375 ns 8658937.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 512379 ns 516555 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 40744375 ns 40579000.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 34861000 ns 34830104 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 34130125 ns 34148292 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 53656458 ns 53661812 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 3039460 ns 2969951 ns 1.02
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 109854709 ns 109640958 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 60211500 ns 84133666 ns 0.72
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 244742208.5 ns 255828791 ns 0.96
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 100222291.5 ns 96388416 ns 1.04
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 270538750 ns 270215792 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 187102791.5 ns 186630271 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 128131083.5 ns 128172709 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 496544584 ns 489605542 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 7095920 ns 7104246 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1493043979 ns 1502664042 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 820794750 ns 821183792 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 1089880791.5 ns 1092397958.5 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 2057983854 ns 2032173187.5 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 33958491 ns 33798333 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 2031772896 ns 2027767896 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1169902125 ns 1563910958 ns 0.75
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 2031263062.5 ns 2210346833.5 ns 0.92
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 2621950583 ns 2560629834 ns 1.02
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s) 2080791 ns 2006833 ns 1.04
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s) 1266458.5 ns 1257333 ns 1.01
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s) 7120666.5 ns 7451041.5 ns 0.96
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s) 2469583 ns 2470458 ns 1.00
lenet(28, 28, 1, 128)/forward/GPU/CUDA 271494.5 ns 275531 ns 0.99
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s) 9645292 ns 9463416 ns 1.02
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s) 6578875 ns 6552500 ns 1.00
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s) 23723729.5 ns 25529541 ns 0.93
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s) 11743000 ns 11734125 ns 1.00
lenet(28, 28, 1, 128)/zygote/GPU/CUDA 1108810 ns 1130415 ns 0.98
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s) 378620500 ns 380676854.5 ns 0.99
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s) 148222625 ns 145328000 ns 1.02
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s) 232625416.5 ns 243564083 ns 0.96
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s) 452981958.5 ns 452336354.5 ns 1.00
vgg16(32, 32, 3, 32)/forward/GPU/CUDA 4877366.5 ns 4879283 ns 1.00
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s) 1151789959 ns 1156932333 ns 1.00
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s) 608267042 ns 487570458 ns 1.25
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s) 958784875 ns 973572458 ns 0.98
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s) 1398102250 ns 1399439834 ns 1.00
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA 17573518 ns 16976929 ns 1.04
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s) 1044000 ns 1062687.5 ns 0.98
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s) 966666.5 ns 971124.5 ns 1.00
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s) 5483500 ns 6269583 ns 0.87
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s) 1369000 ns 1393375 ns 0.98
lenet(28, 28, 1, 64)/forward/GPU/CUDA 277684.5 ns 277704.5 ns 1.00
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s) 6395583 ns 6494541.5 ns 0.98
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s) 4649000 ns 4635437.5 ns 1.00
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s) 18457937.5 ns 19450479 ns 0.95
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s) 6087375 ns 6080229 ns 1.00
lenet(28, 28, 1, 64)/zygote/GPU/CUDA 1153126 ns 1148981 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70616187.5 ns 70442208 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 34338895.5 ns 35305229 ns 0.97
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39546146 ns 39532604 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 132480333 ns 132574604 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1837859.5 ns 1848251 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 354650895.5 ns 356785937.5 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 158687854 ns 159371854 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 253835396.5 ns 254893688 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 535065979.5 ns 535009020.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 16493787 ns 16489529.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 394789208 ns 395707667 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 245506292 ns 245564417 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 682534167 ns 652089584 ns 1.05
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 711436834 ns 712574333 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s) 1186860667 ns 1191762375 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s) 435266750 ns 434009729.5 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s) 628427791 ns 631038834 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s) 1780484229 ns 1771033395.5 ns 1.01
vgg16(32, 32, 3, 128)/forward/GPU/CUDA 12477417 ns 12471861 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s) 3652621271 ns 3670803208.5 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s) 1639329875 ns 1633483458 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s) 2709465041 ns 2737701958 ns 0.99
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s) 5075123916 ns 5038709417 ns 1.01
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA 49797376 ns 49641386 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3423958 ns 3412146 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2097249.5 ns 2094750 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2534854 ns 2533833.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6018625 ns 6034292 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 580639 ns 586721 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 25949208 ns 26096750.5 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 20274833 ns 20315791.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 19561271 ns 19312917 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 39224583 ns 39366625 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2980196 ns 2989473.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 55399562.5 ns 54095229 ns 1.02
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 28378292 ns 28393083 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 172128937.5 ns 177757792 ns 0.97
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 45636375 ns 45278750 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1783542 ns 1778208 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1197959 ns 1204708 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1577021.5 ns 1564000 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3027083.5 ns 3038771 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 218302 ns 217944 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12541854.5 ns 12531437.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 9966458 ns 9964292 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9641875 ns 9707042 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18982208 ns 18974500 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1943759 ns 1963028.5 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17642208 ns 17644270.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 14745479 ns 14745500 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14622083 ns 14639333 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 22196749.5 ns 22173792 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70503791.5 ns 70409562 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 34154375 ns 34786542 ns 0.98
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39724625.5 ns 39571499.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 133426312.5 ns 132610521 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1855504 ns 1837717 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 357508542 ns 360588187.5 ns 0.99
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 236762959 ns 237608334 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 305563667 ns 299913354 ns 1.02
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 731068500 ns 725805833 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 13898567 ns 13956738 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 418493479 ns 418949812.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 253429500 ns 251360792 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 696829083.5 ns 712732021 ns 0.98
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 717012750 ns 717284542 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s) 1657812.5 ns 1912041.5 ns 0.87
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s) 1559500 ns 1579125 ns 0.99
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s) 1547979 ns 1549791.5 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s) 2615500 ns 2657625 ns 0.98
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA 579410 ns 573525 ns 1.01
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s) 8948521 ns 9220000 ns 0.97
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s) 5918125 ns 5936166 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s) 30404791 ns 31895937.5 ns 0.95
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s) 10061562 ns 10214937.5 ns 0.98
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA 1389319 ns 1399984.5 ns 0.99
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s) 22293791 ns 22182333.5 ns 1.01
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s) 19118125 ns 19138291.5 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s) 50278833 ns 52527562.5 ns 0.96
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s) 19441208.5 ns 18888042 ns 1.03
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s) 687479 ns 791291.5 ns 0.87
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s) 71083 ns 69958.5 ns 1.02
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s) 1021000 ns 997167 ns 1.02
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s) 725458.5 ns 724499.5 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA 48336 ns 48324 ns 1.00
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s) 1568500 ns 1508042 ns 1.04
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s) 284021 ns 320291 ns 0.89
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s) 1426229 ns 1445145.5 ns 0.99
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s) 2289417 ns 2258458.5 ns 1.01
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA 213525.5 ns 216350 ns 0.99
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s) 1518000 ns 1537083 ns 0.99
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s) 446709 ns 428792 ns 1.04
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s) 1398833.5 ns 1444584 ns 0.97
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s) 2227979 ns 2250333 ns 0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3424312.5 ns 3421750 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2076145.5 ns 2084312.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2517375 ns 2519375.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6002625 ns 6015021 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA 580319 ns 584297 ns 0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 24064958.5 ns 24071521.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 18099937.5 ns 18050833 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 17179812.5 ns 17227375 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 37498749.5 ns 37583145.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2895115 ns 2895440 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 53787500 ns 52599188 ns 1.02
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 27724333.5 ns 27644250 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 165600125 ns 170611917 ns 0.97
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 44506604 ns 44514250 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 250628729 ns 250102292 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 174546375 ns 174510104 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 115593979.5 ns 115645729 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 447286479 ns 448140124.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5483866.5 ns 5446378 ns 1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1100854958 ns 1105120833 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 467966979 ns 467780729.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 825353979.5 ns 825455520.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1759896083 ns 1753431125 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 32267946 ns 35149612 ns 0.92
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1018635708.5 ns 1021983312.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 665400833 ns 662517187.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1204699750 ns 1286071167 ns 0.94
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1733389813 ns 1721665437.5 ns 1.01
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s) 1226521 ns 1312041 ns 0.93
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s) 961167 ns 928625 ns 1.04
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s) 918083 ns 903208 ns 1.02
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s) 2051083.5 ns 2032416 ns 1.01
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA 578415 ns 575428 ns 1.01
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s) 5660417 ns 5922771 ns 0.96
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s) 2618917 ns 2615500 ns 1.00
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s) 23019583 ns 24427083.5 ns 0.94
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s) 7086333 ns 7104916.5 ns 1.00
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA 1349133 ns 1363516 ns 0.99
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s) 9707250 ns 9705958.5 ns 1.00
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s) 6502604.5 ns 6499000 ns 1.00
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s) 30901167 ns 31929750 ns 0.97
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s) 7612917 ns 7614042 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s) 383916 ns 483291 ns 0.79
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s) 31791 ns 31750 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s) 2087375 ns 1795375 ns 1.16
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s) 91083 ns 91542 ns 0.99
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA 28712 ns 28996 ns 0.99
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s) 406208 ns 392958 ns 1.03
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s) 175875 ns 175542 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s) 4346812.5 ns 4708417 ns 0.92
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s) 272959 ns 273000 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA 216512 ns 224707.5 ns 0.96
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s) 678958 ns 666333 ns 1.02
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s) 442584 ns 442250 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s) 4679166 ns 4499167 ns 1.04
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s) 543542 ns 510979.5 ns 1.06
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s) 329792 ns 430437.5 ns 0.77
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s) 13125 ns 13583 ns 0.97
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s) 603208 ns 709208 ns 0.85
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s) 54708 ns 52584 ns 1.04
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA 27935 ns 29296 ns 0.95
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s) 354458 ns 337250 ns 1.05
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s) 25792 ns 26375 ns 0.98
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s) 719333 ns 484812.5 ns 1.48
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s) 151792 ns 151333 ns 1.00
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA 206354.5 ns 213308.5 ns 0.97
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s) 370041 ns 352521 ns 1.05
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s) 45958 ns 45792 ns 1.00
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s) 469708 ns 487125 ns 0.96
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s) 151167 ns 151000 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s) 600144917 ns 603223875 ns 0.99
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s) 239512791.5 ns 239241354 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s) 368675770.5 ns 377713896 ns 0.98
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s) 878611917 ns 872019458 ns 1.01
vgg16(32, 32, 3, 64)/forward/GPU/CUDA 7673480 ns 7676104.5 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s) 2001190937.5 ns 2005520125 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s) 950953500.5 ns 947653916.5 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s) 1611271729.5 ns 1551514604.5 ns 1.04
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s) 2652863625 ns 2653038416 ns 1.00
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA 27099744 ns 27180094 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s) 532625 ns 525604 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s) 175791 ns 168333 ns 1.04
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s) 1765937 ns 1740625 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s) 873854 ns 875541 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA 48010 ns 47837 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s) 1862104.5 ns 1943750 ns 0.96
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s) 1105875 ns 1100208 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s) 14941916 ns 14661875 ns 1.02
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s) 2753625 ns 2836709 ns 0.97
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA 222255 ns 232330 ns 0.96
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s) 2909834 ns 2974229 ns 0.98
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s) 2231271 ns 2208583.5 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s) 15268500 ns 15024229.5 ns 1.02
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s) 3879541.5 ns 3751750 ns 1.03
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s) 1471375 ns 1602291.5 ns 0.92
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s) 1256917 ns 1221084 ns 1.03
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s) 1257021 ns 1264750 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s) 2211792 ns 2362750 ns 0.94
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA 567447 ns 576709 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s) 5894166 ns 5931125 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s) 2856229.5 ns 2866334 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s) 24506146 ns 25035834 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s) 7294791 ns 6650208 ns 1.10
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA 1317261 ns 1379411 ns 0.95
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s) 11665375 ns 11605146 ns 1.01
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s) 8766624.5 ns 8767458 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s) 34961958 ns 35255000 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s) 9529500 ns 9570000.5 ns 1.00
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s) 2959 ns 2541 ns 1.16
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s) 2542 ns 2292 ns 1.11
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s) 2959 ns 3000 ns 0.99
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s) 2520.5 ns 2333 ns 1.08
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA 24587 ns 25379.5 ns 0.97
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s) 7167 ns 7125 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s) 7084 ns 7083 ns 1.00
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s) 7458 ns 7375 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s) 7167 ns 7270.5 ns 0.99
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA 185970 ns 193729.5 ns 0.96
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s) 8208 ns 8334 ns 0.98
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s) 8375 ns 8500 ns 0.99
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s) 8416 ns 8417 ns 1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s) 6000 ns 6084 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s) 10208 ns 10375.5 ns 0.98
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s) 14937.5 ns 14916 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s) 10916 ns 11854 ns 0.92
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s) 8833 ns 7625 ns 1.16
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA 24920 ns 25646 ns 0.97
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s) 21542 ns 21708 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s) 21625 ns 21500 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s) 21916 ns 21750 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s) 21916.5 ns 21875 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA 195496 ns 203851 ns 0.96
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s) 53500 ns 53417 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s) 56875 ns 56583.5 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s) 53625 ns 53583.5 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s) 55083 ns 51333 ns 1.07
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s) 28541 ns 26895.5 ns 1.06
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s) 28437.5 ns 28333.5 ns 1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s) 28250 ns 29000 ns 0.97
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s) 46083 ns 48291 ns 0.95
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA 25773 ns 26739 ns 0.96
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s) 224458 ns 220875 ns 1.02
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s) 44229.5 ns 44583 ns 0.99
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s) 4410250 ns 4132667 ns 1.07
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s) 145625 ns 145458 ns 1.00
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA 167043 ns 172310 ns 0.97
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s) 242125 ns 237312.5 ns 1.02
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s) 68875 ns 68625 ns 1.00
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s) 4299458 ns 4360708 ns 0.99
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s) 145667 ns 145917 ns 1.00
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s) 2125 ns 2292 ns 0.93
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s) 1750 ns 1750 ns 1
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s) 2583 ns 2166 ns 1.19
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s) 1917 ns 1520.5 ns 1.26
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA 22918.5 ns 23935 ns 0.96
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s) 5250 ns 5125 ns 1.02
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s) 5292 ns 5042 ns 1.05
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s) 5417 ns 5458 ns 0.99
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s) 5250 ns 5084 ns 1.03
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA 171129 ns 176841 ns 0.97
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s) 7583 ns 7292 ns 1.04
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s) 8125 ns 8166 ns 0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s) 7500 ns 7541 ns 0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s) 5125 ns 5167 ns 0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 81032541 ns 80940833 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 39920458 ns 41092709 ns 0.97
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 45590917 ns 45570541 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 153513167 ns 153559792 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 2660470 ns 2660311 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 675206709 ns 621714834 ns 1.09
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 319221521 ns 421739375 ns 0.76
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 412689584 ns 414510667 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 704326792 ns 697568292 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 15217384 ns 15148414 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 875714312.5 ns 872377937.5 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 502738834 ns 706482291.5 ns 0.71
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 1160733354 ns 1162546146 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 1210583500 ns 1175739375 ns 1.03

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.