Skip to content

Commit

Permalink
docs: added to Nested AD example how to use batched_jacobian (#964)
Browse files Browse the repository at this point in the history
* Added to Nested AD example how to use `batched_jacobian`

* Complete example with loss function and tests

* Update docs/src/manual/nested_autodiff.md
  • Loading branch information
facusapienza21 authored Oct 1, 2024
1 parent e4bd1af commit dcb6c6d
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions docs/src/manual/nested_autodiff.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,53 @@ nothing; # hide
That's pretty good, of course you will have some error from the finite differences
calculation.

### Using Batched Jacobian for Multiple Inputs

Notice that in this example the Jacobian `J` consists on the full matrix of derivatives of `smodel` with respect
the different inputs in `x`. In many cases, we are interested in computing the Jacobian with respect to each
input individually, avoiding the unnecessary calculation of zero entries of the Jacobian. This can be achived with
[`batched_jacobian`](@ref) to parse the calculation of the Jacobian per each single input. Using the same example
from the previous section:

```@example nested_ad
model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
ps, st = Lux.setup(StableRNG(0), model)
x = randn(StableRNG(0), Float32, 2, 10)
y = randn(StableRNG(11), Float32, 2, 10)
function loss_function_batched(model, x, ps, st, y)
# Make it a stateful layer
smodel = StatefulLuxLayer{true}(model, ps, st)
ŷ = smodel(x)
loss_emp = sum(abs2, ŷ .- y)
# You can use `AutoZygote()` as well but `AutoForwardDiff()` tends to be more efficient here
J = batched_jacobian(smodel, AutoForwardDiff(), x)
loss_reg = abs2(norm(J .* 0.01f0))
return loss_emp + loss_reg
end
loss_function_batched(model, x, ps, st, y)
```

Notice that in this last example we removed `BatchNorm()` from the neural network. This is done so outputs corresponding
to differern inputs don't have an algebraic dependency due to the batch normalization happening in the neural network.
We can now verify again the value of the Jacobian:

```@example nested_ad
∂x_fd = FiniteDiff.finite_difference_gradient(x -> loss_function_batched(model, x, ps, st, y), x)
∂ps_fd = FiniteDiff.finite_difference_gradient(ps -> loss_function_batched(model, x, ps, st, y),
ComponentArray(ps))
_, ∂x_b, ∂ps_b, _, _ = Zygote.gradient(loss_function_batched, model, x, ps, st, y)
println("∞-norm(∂x_b - ∂x_fd): ", norm(∂x_b .- ∂x_fd, Inf))
@assert norm(∂x_b .- ∂x_fd, Inf) < 1e-2 # hide
println("∞-norm(∂ps_b - ∂ps_fd): ", norm(ComponentArray(∂ps_b) .- ∂ps_fd, Inf))
@assert norm(ComponentArray(∂ps_b) .- ∂ps_fd, Inf) < 1e-2 # hide
```

In this example, it is important to remark that now `batched_jacobian` returns a 3D array with the Jacobian calculation
for each independent input value in `x`.

## Loss Function contains Gradient Computation

Ok here I am going to cheat a bit. This comes from a discussion on nested AD for PINNs
Expand Down

1 comment on commit dcb6c6d

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lux Benchmarks

Benchmark suite Current: dcb6c6d Previous: e4bd1af Ratio
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s) 415083 ns 412125 ns 1.01
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s) 243562.5 ns 322375 ns 0.76
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s) 243917 ns 321625 ns 0.76
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s) 740187.5 ns 739375 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA 43145 ns 44132 ns 0.98
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s) 1349145.5 ns 647917 ns 2.08
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s) 1217021 ns 2404667 ns 0.51
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s) 16523666 ns 13901084 ns 1.19
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s) 2260375 ns 2211917 ns 1.02
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA 198205.5 ns 201549 ns 0.98
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s) 1319125 ns 740042 ns 1.78
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s) 1304979 ns 2593084 ns 0.50
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s) 16162208.5 ns 14418542 ns 1.12
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s) 2198917 ns 2199209 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1670458 ns 1526583 ns 1.09
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1107375 ns 1096708 ns 1.01
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1527771 ns 1529625 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3019125 ns 3028083 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA 211316 ns 210375.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12175041 ns 12223834 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 8824145.5 ns 8813167 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9233625 ns 9206687.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18591583 ns 18597853.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1930057 ns 1948580 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17307313 ns 17338770.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 13969291.5 ns 13950583.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14519583 ns 14476791.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 21863458 ns 21850833 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 250175667 ns 124925271 ns 2.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148788625 ns 148389000 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 116216917 ns 115877562.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 446783750 ns 447112875 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5483992 ns 5460574 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1221582792 ns 600322042 ns 2.03
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 934823708 ns 930867334 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 825393979 ns 825580604 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1634434500 ns 1687470250.5 ns 0.97
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 31104295 ns 31224338 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1147938166 ns 706851312.5 ns 1.62
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 996908396 ns 988058125.5 ns 1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1315038312.5 ns 1348418729 ns 0.98
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1733258437.5 ns 1806342854 ns 0.96
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s) 1124250 ns 863834 ns 1.30
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s) 1648541.5 ns 1622583.5 ns 1.02
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s) 3458500 ns 3450625 ns 1.00
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s) 790708 ns 784875 ns 1.01
lenet(28, 28, 1, 32)/forward/GPU/CUDA 276890 ns 267055.5 ns 1.04
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s) 2989917 ns 2714792 ns 1.10
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s) 4140375 ns 4119812 ns 1.00
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s) 10581541.5 ns 10424458 ns 1.02
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s) 3136958 ns 3144166 ns 1.00
lenet(28, 28, 1, 32)/zygote/GPU/CUDA 1129684 ns 1090149.5 ns 1.04
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 2390166 ns 2166312 ns 1.10
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1353000 ns 1479000 ns 0.91
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1581708 ns 1744292 ns 0.91
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 4332708 ns 4339875 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 210207 ns 208596 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 20303291.5 ns 20428875 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 16973958 ns 16963479 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 18209958 ns 17405708 ns 1.05
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 26748042 ns 26734729 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 2004316 ns 2018993 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 44366000 ns 45033583 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 40975041.5 ns 40993666.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 41237167 ns 41173500 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 47733416.5 ns 47738437 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 4673042 ns 4301666.5 ns 1.09
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2607958 ns 2844667 ns 0.92
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2740083 ns 2996709 ns 0.91
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 8646250 ns 8653334 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 471597 ns 472874 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 40513208 ns 40060542 ns 1.01
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 33898583 ns 33920959 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 34004896 ns 33907687.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 53682375 ns 53575541.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 3025195 ns 3254220 ns 0.93
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 109957125 ns 90139000 ns 1.22
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 136423624.5 ns 135574958.5 ns 1.01
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 249203917 ns 249787833 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 96417375 ns 96223792 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 270485625 ns 142522459 ns 1.90
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 157422417 ns 161123167 ns 0.98
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 125021063 ns 128478042 ns 0.97
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 489717917 ns 493238750 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 6887253.5 ns 7031961.5 ns 0.98
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1500178749.5 ns 881412625 ns 1.70
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 1209776166 ns 1203181667 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 1101673604 ns 1089986000.5 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 2033012896.5 ns 2129205729 ns 0.95
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 34855481.5 ns 34708690 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 2031056270.5 ns 1668841500 ns 1.22
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1850536958 ns 1865068750 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 2173376541.5 ns 2075940833.5 ns 1.05
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 2563569208 ns 2608730625 ns 0.98
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s) 2043208 ns 1545708 ns 1.32
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s) 3056708 ns 3042541 ns 1.00
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s) 8256479.5 ns 7339916 ns 1.12
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s) 2476666 ns 2318125 ns 1.07
lenet(28, 28, 1, 128)/forward/GPU/CUDA 276146.5 ns 277569.5 ns 0.99
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s) 9654583 ns 7874959 ns 1.23
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s) 12054625 ns 12022125 ns 1.00
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s) 24288042 ns 23765959 ns 1.02
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s) 11746854.5 ns 11654708 ns 1.01
lenet(28, 28, 1, 128)/zygote/GPU/CUDA 1181147.5 ns 1196174 ns 0.99
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s) 381419291.5 ns 186253812 ns 2.05
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s) 308744166.5 ns 283266353.5 ns 1.09
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s) 262197666.5 ns 242835500 ns 1.08
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s) 453805292 ns 463794333 ns 0.98
vgg16(32, 32, 3, 32)/forward/GPU/CUDA 4853504 ns 4830735 ns 1.00
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s) 1144266542 ns 630927250 ns 1.81
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s) 964566583 ns 990257541 ns 0.97
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s) 971379334 ns 1035740417 ns 0.94
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s) 1404606542 ns 1415342041 ns 0.99
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA 16465783 ns 16300060 ns 1.01
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s) 1058521 ns 1085229 ns 0.98
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s) 1665374.5 ns 2098166 ns 0.79
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s) 6526666 ns 4972000 ns 1.31
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s) 1370042 ns 1299500 ns 1.05
lenet(28, 28, 1, 64)/forward/GPU/CUDA 274033.5 ns 278783 ns 0.98
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s) 6516541 ns 6008145.5 ns 1.08
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s) 13102708.5 ns 12421208 ns 1.05
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s) 18363000 ns 20005041 ns 0.92
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s) 6084354.5 ns 6082792 ns 1.00
lenet(28, 28, 1, 64)/zygote/GPU/CUDA 1233343.5 ns 1220466 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70574042 ns 23693938 ns 2.98
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43797125 ns 43500791.5 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39782958.5 ns 39526833.5 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 132781271 ns 132823145.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1956000 ns 1948314 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 355154708.5 ns 184396041 ns 1.93
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 270770334 ns 270116291 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 254052708 ns 253589145.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 534690875 ns 534281562.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 13245522.5 ns 13222993 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 396827750 ns 297123437 ns 1.34
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 372318834 ns 404377895.5 ns 0.92
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 671683959 ns 696065958 ns 0.96
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 713207834 ns 713613916 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s) 1189840458 ns 656595541 ns 1.81
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s) 834600270.5 ns 689413604.5 ns 1.21
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s) 643996000 ns 634330625 ns 1.02
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s) 1771218270.5 ns 1789031312.5 ns 0.99
vgg16(32, 32, 3, 128)/forward/GPU/CUDA 12386792 ns 12386066 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s) 3632394041.5 ns 1908648333.5 ns 1.90
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s) 2819490917 ns 2827932125 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s) 2703852750 ns 2698654250 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s) 5046837084 ns 5716413416 ns 0.88
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA 49275819 ns 49345511 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3417875 ns 3047688 ns 1.12
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2080042 ns 2062437 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2540459 ns 2519583 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6037792 ns 6053042 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 571807.5 ns 574063 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 25947667 ns 25654333 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 18971396.5 ns 19054583.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 19516791.5 ns 19323500 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 39348958.5 ns 39330000 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 3001343 ns 3195551.5 ns 0.94
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 55429166.5 ns 35130041.5 ns 1.58
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 81557583 ns 82097417 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 172942167 ns 168348625 ns 1.03
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 45661541.5 ns 45591875 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1786354.5 ns 1644375 ns 1.09
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1106458 ns 1090250 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1570978.5 ns 1572750 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3033375 ns 3038167 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 214775.5 ns 214850 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12557750 ns 12701083 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 9236583.5 ns 9189625 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9630708 ns 9640458.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 19044937.5 ns 18968854.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1985531 ns 1987617.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17664084 ns 17682875 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 14332709 ns 14327834 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14595146 ns 14625958 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 22201042 ns 22177500 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70526583 ns 23739833.5 ns 2.97
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43708417 ns 43469541 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39735812.5 ns 39647750 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 132615771 ns 132812271.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1938634 ns 1879600 ns 1.03
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 360222063 ns 189384875 ns 1.90
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 348659791.5 ns 346944938 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 302374833.5 ns 303748958 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 727881666 ns 748909417 ns 0.97
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 14325162 ns 14283912.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 419531958.5 ns 302085833 ns 1.39
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 434088375 ns 421708625 ns 1.03
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 691688416.5 ns 689499625 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 717541625 ns 719890000 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s) 1673625 ns 1926375 ns 0.87
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s) 1384958 ns 1579042 ns 0.88
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s) 1378083 ns 1571792 ns 0.88
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s) 2664374.5 ns 2497917 ns 1.07
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA 568730 ns 573991 ns 0.99
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s) 9240188 ns 6186000 ns 1.49
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s) 14792541.5 ns 13018375 ns 1.14
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s) 32052875 ns 31151958 ns 1.03
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s) 10208834 ns 9378042 ns 1.09
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA 1422888 ns 1403069 ns 1.01
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s) 22285625 ns 18793000 ns 1.19
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s) 28463000 ns 27709979.5 ns 1.03
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s) 56517729 ns 49574542 ns 1.14
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s) 18854687.5 ns 18852542 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s) 699792 ns 68959 ns 10.15
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s) 644209 ns 541125 ns 1.19
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s) 1065062.5 ns 1011562 ns 1.05
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s) 728292 ns 728542 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA 47086.5 ns 47294 ns 1.00
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s) 1513416 ns 277500 ns 5.45
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s) 1010604 ns 988020.5 ns 1.02
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s) 1606083 ns 1388416.5 ns 1.16
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s) 2291666 ns 2250812 ns 1.02
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA 226725.5 ns 225164 ns 1.01
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s) 1516750.5 ns 407500 ns 3.72
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s) 1076208 ns 1045583 ns 1.03
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s) 1449125 ns 1418917 ns 1.02
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s) 2256125 ns 2256958 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3415417 ns 3042083 ns 1.12
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2053167 ns 2062771 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2513229.5 ns 2510104.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6017583.5 ns 6011000 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA 568598 ns 564983 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 24077208 ns 23609021 ns 1.02
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 17182291.5 ns 17178792 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 17150417 ns 17120458 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 37549833 ns 37462729 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2938820 ns 3146695 ns 0.93
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 53630958.5 ns 33304750 ns 1.61
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 81466625 ns 83679583.5 ns 0.97
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 169486084 ns 167872042 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 44624500 ns 44785187.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 250522209 ns 120247125 ns 2.08
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148626708 ns 148479500 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 116110708.5 ns 115610813 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 447858917 ns 447816417 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5427690.5 ns 5450922 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1104123000 ns 470730291 ns 2.35
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 859505875 ns 856712645.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 829538646 ns 825513875.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1754815708 ns 1750589417 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 28735758 ns 28864938 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1018403979.5 ns 640143291 ns 1.59
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 983568208 ns 964190458 ns 1.02
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1335719333 ns 1286413958 ns 1.04
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1728379395.5 ns 1842051438 ns 0.94
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s) 1082292 ns 1241583 ns 0.87
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s) 764959 ns 917166 ns 0.83
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s) 682709 ns 906584 ns 0.75
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s) 2044125 ns 1938583 ns 1.05
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA 554259.5 ns 553409.5 ns 1.00
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s) 5934375 ns 2941333 ns 2.02
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s) 9162896 ns 6314437.5 ns 1.45
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s) 26061854.5 ns 24719833.5 ns 1.05
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s) 7111479 ns 7090125 ns 1.00
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA 1357512.5 ns 1346593.5 ns 1.01
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s) 9683542 ns 6639250 ns 1.46
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s) 16162959 ns 13128667 ns 1.23
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s) 33355375 ns 30481375 ns 1.09
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s) 7620375 ns 7632854 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s) 388541 ns 39042 ns 9.95
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s) 518208.5 ns 372792 ns 1.39
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s) 3052583 ns 1833875 ns 1.66
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s) 89500 ns 91792 ns 0.98
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA 27832 ns 27047.5 ns 1.03
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s) 404666 ns 175458 ns 2.31
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s) 454791 ns 455792 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s) 4601375 ns 4338875 ns 1.06
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s) 280000 ns 272792 ns 1.03
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA 213087 ns 210187.5 ns 1.01
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s) 677583 ns 441709 ns 1.53
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s) 726708.5 ns 728375 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s) 4653542 ns 4896125 ns 0.95
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s) 522959 ns 511041.5 ns 1.02
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s) 334437.5 ns 12416.5 ns 26.93
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s) 451521 ns 303334 ns 1.49
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s) 774437.5 ns 721771 ns 1.07
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s) 52833 ns 55209 ns 0.96
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA 28056 ns 27615.5 ns 1.02
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s) 352584 ns 25917 ns 13.60
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s) 333875 ns 336500 ns 0.99
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s) 902834 ns 850083 ns 1.06
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s) 151959 ns 151500 ns 1.00
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA 199603.5 ns 198567.5 ns 1.01
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s) 367333 ns 45208.5 ns 8.13
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s) 348125 ns 351625 ns 0.99
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s) 945562.5 ns 712459 ns 1.33
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s) 151375 ns 151084 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s) 601502916 ns 318202459 ns 1.89
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s) 430191604 ns 430387020.5 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s) 390437000 ns 368378458.5 ns 1.06
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s) 871755417 ns 883484291 ns 0.99
vgg16(32, 32, 3, 64)/forward/GPU/CUDA 7623148 ns 7628205 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s) 1994407979.5 ns 1097576562.5 ns 1.82
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s) 1636880541.5 ns 1620619666.5 ns 1.01
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s) 1572982645.5 ns 1583682354 ns 0.99
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s) 2658913333 ns 2698758083 ns 0.99
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA 26625956 ns 26674131 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s) 525833 ns 189813 ns 2.77
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s) 401229.5 ns 443792 ns 0.90
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s) 2770750 ns 1747875 ns 1.59
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s) 872645.5 ns 873374.5 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA 46979 ns 46821 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s) 1876563 ns 1205958.5 ns 1.56
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s) 1830166.5 ns 2354667 ns 0.78
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s) 16303459 ns 14475333.5 ns 1.13
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s) 2794834 ns 2826417 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA 240187.5 ns 237435.5 ns 1.01
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s) 2919520.5 ns 2299604.5 ns 1.27
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s) 5015167 ns 5735750 ns 0.87
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s) 16524271 ns 14836917 ns 1.11
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s) 3743292 ns 3683375 ns 1.02
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s) 1368417 ns 1579292 ns 0.87
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s) 979958 ns 1180250 ns 0.83
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s) 930917 ns 1174479 ns 0.79
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s) 2342208.5 ns 2370125 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA 565552 ns 570253.5 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s) 5910334 ns 3184000 ns 1.86
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s) 8430229 ns 4719584 ns 1.79
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s) 25837625 ns 24816709 ns 1.04
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s) 7325812 ns 7307438 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA 1327441 ns 1344428.5 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s) 11696354 ns 8830562.5 ns 1.32
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s) 18020208.5 ns 15640333.5 ns 1.15
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s) 39373729 ns 34223791 ns 1.15
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s) 9553833 ns 9547375 ns 1.00
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s) 2459 ns 2209 ns 1.11
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s) 2416 ns 2167 ns 1.11
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s) 2792 ns 3541 ns 0.79
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s) 4583 ns 2625 ns 1.75
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA 24428 ns 24463 ns 1.00
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s) 7291 ns 7000 ns 1.04
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s) 6958 ns 6833 ns 1.02
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s) 7333 ns 7292 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s) 6750 ns 7167 ns 0.94
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA 200289 ns 202989.5 ns 0.99
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s) 8416 ns 8334 ns 1.01
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s) 8333 ns 8250 ns 1.01
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s) 8250 ns 8375 ns 0.99
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s) 5625 ns 6041 ns 0.93
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s) 10459 ns 10583 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s) 12958 ns 15875 ns 0.82
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s) 11333.5 ns 10333 ns 1.10
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s) 7791 ns 7625.5 ns 1.02
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA 24856 ns 24500 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s) 21709 ns 21542 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s) 21459 ns 21625 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s) 21792 ns 21750 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s) 21167 ns 21667 ns 0.98
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA 220349.5 ns 221414.5 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s) 53584 ns 56833 ns 0.94
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s) 53583 ns 53708 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s) 53770.5 ns 53625 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s) 51125 ns 51583.5 ns 0.99
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s) 28750 ns 28834 ns 1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s) 28916 ns 28584 ns 1.01
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s) 28875 ns 28458 ns 1.01
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s) 45875 ns 46708 ns 0.98
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA 26054 ns 25617 ns 1.02
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s) 228541 ns 44375 ns 5.15
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s) 275333 ns 274708 ns 1.00
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s) 4217667 ns 4275000 ns 0.99
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s) 145250 ns 145000 ns 1.00
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA 199681 ns 206652.5 ns 0.97
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s) 246459 ns 68542 ns 3.60
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s) 293145.5 ns 292958 ns 1.00
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s) 4145854 ns 4229958 ns 0.98
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s) 145542 ns 145666 ns 1.00
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s) 1959 ns 1833 ns 1.07
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s) 2000 ns 1750 ns 1.14
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s) 2000 ns 2500 ns 0.80
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s) 1708 ns 1666 ns 1.03
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA 22940 ns 22972 ns 1.00
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s) 5334 ns 5208 ns 1.02
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s) 5125 ns 5167 ns 0.99
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s) 5166 ns 5208 ns 0.99
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s) 4792 ns 5250 ns 0.91
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA 232790 ns 244140 ns 0.95
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s) 7417 ns 8208 ns 0.90
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s) 7375 ns 7375 ns 1
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s) 7459 ns 7542 ns 0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s) 5250 ns 5292 ns 0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 81082749.5 ns 34124291 ns 2.38
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 48527208 ns 49799333 ns 0.97
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 43737084 ns 45669229.5 ns 0.96
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 153734041 ns 153888625 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 2717702 ns 2656121 ns 1.02
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 621583083 ns 481321500.5 ns 1.29
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 427560417 ns 424493583 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 412343333.5 ns 412050834 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 697842291 ns 724714916 ns 0.96
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 15532428 ns 15594271 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 851105979 ns 744920541 ns 1.14
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 840062312.5 ns 840757958.5 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 1156974917 ns 1131213854 ns 1.02
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 1177103062.5 ns 1186689479.5 ns 0.99

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.