Skip to content

Commit

Permalink
Merge pull request #584 from LuxDL/ap/promote_compact
Browse files Browse the repository at this point in the history
Improvement to the `@compact` API
  • Loading branch information
avik-pal authored Apr 13, 2024
2 parents fc591bd + 72fc49f commit 9f1d902
Show file tree
Hide file tree
Showing 26 changed files with 267 additions and 125 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/SpellCheck.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: Spell Check

on: [pull_request]

jobs:
typos-check:
name: Spell Check with Typos
runs-on: ubuntu-latest
steps:
- name: Checkout Actions Repository
uses: actions/checkout@v4
- name: Check spelling
uses: crate-ci/typos@v1.18.0
2 changes: 2 additions & 0 deletions .typos.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[default.extend-words]
numer = "numer"
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.5.34"
version = "0.5.35"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
16 changes: 11 additions & 5 deletions docs/src/api/Lux/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ CurrentModule = Lux

All features listed on this page are **experimental** which means:

1. No SemVer Guarantees. We use code here to iterate fast and most users should wait for
these features to be marked non-experimental.
1. No SemVer Guarantees. We use code here to iterate fast. That said, historically we have
never broken any code in this module and have always provided a deprecation period.
2. Expect edge-cases and report them. It will help us move these features out of
experimental sooner.
3. None of the features are exported.
Expand Down Expand Up @@ -74,8 +74,14 @@ Lux.Experimental.DebugLayer
Lux.Experimental.share_parameters
```

## StatefulLuxLayer

[`Lux.StatefulLuxLayer`](@ref) used to be part of experimental features, but has been
promoted to stable API. It is now available via `Lux.StatefulLuxLayer`. Change all uses of
`Lux.Experimental.StatefulLuxLayer` to `Lux.StatefulLuxLayer`.

## Compact Layer API

```@docs
Lux.Experimental.@compact
```
[`Lux.@compact`](@ref) used to be part of experimental features, but has been promoted to
stable API. It is now available via `Lux.@compact`. Change all uses of
`Lux.Experimental.@compact` to `Lux.@compact`.
5 changes: 5 additions & 0 deletions docs/src/api/Lux/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ Lux.f64
StatefulLuxLayer
```

## Compact Layer

```@docs
@compact
```

## Truncated Stacktraces

Expand Down
11 changes: 5 additions & 6 deletions docs/src/introduction/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@ standard AD and Optimisers API.

```@example quickstart
# Get the device determined by Lux
device = gpu_device()
dev = gpu_device()
# Parameter and State Variables
ps, st = Lux.setup(rng, model) .|> device
ps, st = Lux.setup(rng, model) .|> dev
# Dummy Input
x = rand(rng, Float32, 128, 2) |> device
x = rand(rng, Float32, 128, 2) |> dev
# Run the model
y, st = Lux.apply(model, x, ps, st)
# Gradients
## Pullback API to capture change in state
(l, st_), pb = pullback(p -> Lux.apply(model, x, p, st), ps)
gs = pb((one.(l), nothing))[1]
(l, st_), pb = pullback(Lux.apply, model, x, ps, st)
gs = pb((one.(l), nothing))[3]
# Optimization
st_opt = Optimisers.setup(Adam(0.0001f0), ps)
Expand All @@ -74,7 +74,6 @@ st_opt, ps = Optimisers.update(st_opt, ps, gs)
```@example custom_compact
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, LuxAMDGPU, Metal # Optional packages for GPU support
import Lux.Experimental: @compact
using Printf # For pretty printing
```

Expand Down
13 changes: 10 additions & 3 deletions docs/src/manual/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ First let's set the expectations straight.
functionality in the core library (and officially supported ones) **must** adhere to
the interface

!!! tip

While writing out a custom struct and defining dispatches manually is a good way to
understand the interface, it is not the most concise way. We recommend using the
[`Lux.@compact`](@ref) macro to define layers which makes handling the states and
parameters downright trivial.

## Layer Interface

### Singular Layer
Expand All @@ -35,8 +42,8 @@ architecture cannot change.

!!! tip

For people coming from Flux.jl background this might be weird. We recommend checking out
[the Flux to Lux migration guide](@ref migrate-from-flux) first before proceeding.
For people coming from Flux.jl background, this might be weird. We recommend checking
out [the Flux to Lux migration guide](@ref migrate-from-flux) first before proceeding.

```@example layer_interface
using Lux, Random
Expand Down Expand Up @@ -80,7 +87,7 @@ reconstruction of the parameters and states.
println("Parameter Length: ", Lux.parameterlength(l), "; State Length: ",
Lux.statelength(l))
# But still recommened to define these
# But still recommended to define these
Lux.parameterlength(l::Linear) = l.out_dims * l.in_dims + l.out_dims
Lux.statelength(::Linear) = 0
Expand Down
2 changes: 1 addition & 1 deletion docs/src/manual/migrate_from_flux.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ end
# `A` is not trainable
Optimisers.trainable(f::FluxLinear) = (B=f.B,)

# Needed so that both `A` and `B` can be transfered between devices
# Needed so that both `A` and `B` can be transferred between devices
Flux.@functor FluxLinear

(l::FluxLinear)(x) = l.A * l.B * x
Expand Down
4 changes: 2 additions & 2 deletions docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ const advanced = [
}
];
const thrid_party = [
const third_party = [
{
href: "https://docs.sciml.ai/Overview/stable/showcase/pinngpu/",
src: "../pinn.gif",
Expand Down Expand Up @@ -114,7 +114,7 @@ of them are non-functional and we will try to get them updated.
:::
<Gallery :images="thrid_party" />
<Gallery :images="third_party" />
::: tip
Expand Down
2 changes: 1 addition & 1 deletion examples/DDIM/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The model generates images from Gaussian noises by <em>denoising</em> iterativel
# Usage
Install Julia and instantiate `Project.toml`.

Follwoing scripts are tested on a single NVIDIA Tesla T4 instance.
Following scripts are tested on a single NVIDIA Tesla T4 instance.
## Dataset
Download and extract `Dataset images` from [102 Category Flower Dataset](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/).

Expand Down
2 changes: 1 addition & 1 deletion examples/GravitationalWaveForm/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {
m₁ = mass_ratio * m₂

orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
waveform = h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)
waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
else
waveform = h_22_strain_one_body(dt, orbit)
end
Expand Down
40 changes: 16 additions & 24 deletions examples/HyperNet/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,25 @@ function load_datasets(n_train=1024, n_eval=32, batchsize=256)
end

# ## Implement a HyperNet Layer
struct HyperNet{W <: Lux.AbstractExplicitLayer, C <: Lux.AbstractExplicitLayer, A} <:
Lux.AbstractExplicitContainerLayer{(:weight_generator, :core_network)}
weight_generator::W
core_network::C
ca_axes::A
end

function HyperNet(w::Lux.AbstractExplicitLayer, c::Lux.AbstractExplicitLayer)
ca_axes = Lux.initialparameters(Random.default_rng(), c) |> ComponentArray |> getaxes
return HyperNet(w, c, ca_axes)
end

function Lux.initialparameters(rng::AbstractRNG, h::HyperNet)
return (weight_generator=Lux.initialparameters(rng, h.weight_generator),)
function HyperNet(weight_generator::Lux.AbstractExplicitLayer,
core_network::Lux.AbstractExplicitLayer)
ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
ComponentArray |>
getaxes
return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
## Generate the weights
ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
return core_network(y, ps_new)
end
end

function (hn::HyperNet)(x, ps, st::NamedTuple)
ps_new, st_ = hn.weight_generator(x, ps.weight_generator, st.weight_generator)
@set! st.weight_generator = st_
return ComponentArray(vec(ps_new), hn.ca_axes), st
end
# Defining functions on the CompactLuxLayer requires some understanding of how the layer
# is structured, as such we don't recommend doing it unless you are familiar with the
# internals. In this case, we simply write it to ignore the initialization of the
# `core_network` parameters.

function (hn::HyperNet)((x, y)::T, ps, st::NamedTuple) where {T <: Tuple}
ps_ca, st = hn(x, ps, st)
pred, st_ = hn.core_network(y, ps_ca, st.core_network)
@set! st.core_network = st_
return pred, st
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end

# ## Create and Initialize the HyperNet
Expand Down
33 changes: 29 additions & 4 deletions examples/NeuralODE/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,23 @@ function loadmnist(batchsize, train_split)
end

# ## Define the Neural ODE Layer
#
#
# First we will use the [`@compact`](@ref) macro to define the Neural ODE Layer.

function NeuralODECompact(
model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
return @compact(; model, solver, tspan, kwargs...) do x, p
dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
## Note the `p.model` here
prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
return solve(prob, solver; kwargs...)
end
end

# We recommend using the compact macro for creating custom layers. The below implementation
# exists mostly for historical reasons when `@compact` was not part of the stable API. Also,
# it helps users understand how the layer interface of Lux works.

# The NeuralODE is a ContainerLayer, which stores a `model`. The parameters and states of
# the NeuralODE are same as those of the underlying model.
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <:
Expand Down Expand Up @@ -154,6 +170,8 @@ function train(model_function; cpu::Bool=false, kwargs...)
end
end

train(NeuralODECompact)

train(NeuralODE)

# We can also change the sensealg and train the model! `GaussAdjoint` allows you to use
Expand All @@ -173,8 +191,9 @@ train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)

# ## Alternate Implementation using Stateful Layer

# Starting `v0.5.5`, Lux provides a `Lux.Experimental.StatefulLuxLayer` which can be used
# to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276).
# Starting `v0.5.5`, Lux provides a [`StatefulLuxLayer`](@ref) which can be used
# to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276). Using
# the `@compact` API avoids this problem entirely.
struct StatefulNeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <:
Lux.AbstractExplicitContainerLayer{(:model,)}
model::M
Expand All @@ -189,7 +208,7 @@ function StatefulNeuralODE(
end

function (n::StatefulNeuralODE)(x, ps, st)
st_model = Lux.StatefulLuxLayer(n.model, ps, st)
st_model = StatefulLuxLayer(n.model, ps, st)
dudt(u, p, t) = st_model(u, p)
prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
return solve(prob, n.solver; n.kwargs...), st_model.st
Expand Down Expand Up @@ -219,3 +238,9 @@ x = gpu_device()(ones(Float32, 28, 28, 1, 3));

# Note, that we still recommend using this layer internally and not exposing this as the
# default API to the users.

# Finally checking the compact model

model_compact, ps_compact, st_compact = create_model(NeuralODECompact)

@code_warntype model_compact(x, ps_compact, st_compact)
30 changes: 27 additions & 3 deletions examples/SimpleRNN/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ function (s::SpiralClassifier)(
return vec(y), st
end

# ## Using the `@compact` API

# We can also define the model using the [`Lux.@compact`](@ref) API, which is a more concise
# way of defining models. This macro automatically handles the boilerplate code for you and
# as such we recommend this way of defining custom layers

function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
lstm_cell = LSTMCell(in_dims => hidden_dims)
classifier = Dense(hidden_dims => out_dims, sigmoid)
return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
y, carry = lstm_cell(x_init)
for x in x_rest
y, carry = lstm_cell((x, carry))
end
return vec(classifier(y))
end
end

# ## Defining Accuracy, Loss and Optimiser

# Now let's define the binarycrossentropy loss. Typically it is recommended to use
Expand All @@ -125,12 +144,12 @@ accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

# ## Training the Model

function main()
function main(model_type)
## Get the dataloaders
(train_loader, val_loader) = get_dataloaders()

## Create the model
model = SpiralClassifier(2, 8, 1)
model = model_type(2, 8, 1)
rng = Xoshiro(0)

dev = gpu_device()
Expand Down Expand Up @@ -164,7 +183,12 @@ function main()
return (train_state.parameters, train_state.states) |> cpu_device()
end

ps_trained, st_trained = main()
ps_trained, st_trained = main(SpiralClassifier)
nothing #hide

# We can also train the compact model with the exact same code!

ps_trained2, st_trained2 = main(SpiralClassifierCompact)
nothing #hide

# ## Saving the Model
Expand Down
9 changes: 8 additions & 1 deletion src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ using PrecompileTools: @recompile_invalidations
inputsize, outputsize, update_state, trainmode, testmode, setup, apply,
display_name, replicate
using LuxDeviceUtils: get_device

# @compact specific
using MacroTools: block, combinedef, splitdef
using ConstructionBase: ConstructionBase
end

@reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers
Expand Down Expand Up @@ -56,6 +60,7 @@ include("contrib/contrib.jl")

# Helpful Functionalities
include("helpers/stateful.jl")
include("helpers/compact.jl")

# Transform to and from other frameworks
include("transform/types.jl")
Expand All @@ -70,7 +75,8 @@ include("distributed/public_api.jl")
include("deprecated.jl")

# Layers
export cpu, gpu
export cpu, gpu # deprecated

export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer
export Bilinear, Dense, Embedding, Scale
export Conv, ConvTranspose, CrossCor, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool,
Expand All @@ -83,6 +89,7 @@ export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell
export SamePad, TimeLastIndex, BatchLastIndex

export StatefulLuxLayer
export @compact, CompactLuxLayer

export f16, f32, f64

Expand Down
Loading

1 comment on commit 9f1d902

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 9f1d902 Previous: fc591bd Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3643 ns 3653 ns 1.00
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7953.166666666667 ns 7729.5 ns 1.03
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 14727 ns 14106 ns 1.04
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9702 ns 9916 ns 0.98
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8947 ns 8698.75 ns 1.03
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4445.75 ns 4506.5625 ns 0.99
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1981.7 ns 1971.7 ns 1.01
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1634.3680555555557 ns 1648.5314685314686 ns 0.99
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1772.7592592592594 ns 1824.8510638297873 ns 0.97
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.9111424541608 ns 179.4728789986092 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17322 ns 17333 ns 1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 18544 ns 18394 ns 1.01
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 34985 ns 35396 ns 0.99
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 28292 ns 28633 ns 0.99
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 20618 ns 19607 ns 1.05
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 16972 ns 17032 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 4747.428571428572 ns 4768.857142857143 ns 1.00
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 4751.714285714285 ns 4800.428571428572 ns 0.99
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4766 ns 4800.428571428572 ns 0.99
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1661.1 ns 1659.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 47236878 ns 48367699 ns 0.98
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 79390814 ns 90662926 ns 0.88
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 110354486.5 ns 97653785.5 ns 1.13
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 107007516 ns 107727588 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 90584786 ns 108249388 ns 0.84
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 12157594 ns 12110710 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 18467350 ns 18210910.5 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 12094742 ns 18544073 ns 0.65
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 18057739 ns 18466654 ns 0.98
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6406764 ns 6396982 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 1) 103302825 ns 106620467.5 ns 0.97
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 915303511 ns 763640160 ns 1.20
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 3001164643 ns 2762978316 ns 1.09
vgg16/cpu/reverse/Tracker/(32, 32, 3, 1) 161191117 ns 163403619 ns 0.99
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 1077040118 ns 1198898689 ns 0.90
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3836744884 ns 3765767577 ns 1.02
vgg16/cpu/reverse/Flux/(32, 32, 3, 1) 84401417 ns 85276372.5 ns 0.99
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 821989995 ns 840374369 ns 0.98
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 3595869124 ns 3347793443 ns 1.07
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 1) 24852621 ns 25080614.5 ns 0.99
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 221149106.5 ns 232258093 ns 0.95
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 885666932.5 ns 1019038431 ns 0.87
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 1) 26159258.5 ns 25059892 ns 1.04
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 223180012 ns 236184814.5 ns 0.94
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 902076527.5 ns 999233211 ns 0.90
vgg16/cpu/forward/Flux/(32, 32, 3, 1) 24166346 ns 24562440.5 ns 0.98
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 212906366 ns 211748278 ns 1.01
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 824038154.5 ns 712431369.5 ns 1.16
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1052403840 ns 1132641019 ns 0.93
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1950974267 ns 1842889677.5 ns 1.06
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2254863718 ns 2124383065.5 ns 1.06
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2569062884 ns 2365462129 ns 1.09
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1987293572 ns 1854224454.5 ns 1.07
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 372672192.5 ns 456010240 ns 0.82
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 384292918.5 ns 359691595 ns 1.07
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 375958081.5 ns 359652717.5 ns 1.05
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11991204 ns 11966091 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 18085257 ns 18076793 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19537014 ns 19252254 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23942477 ns 23893264 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 18064912 ns 18061934 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1147017 ns 1158025 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2077290 ns 2075109 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2088844.5 ns 2081892 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2092047.5 ns 2071516.5 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 201450.5 ns 200054 ns 1.01
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 297764.5 ns 298147 ns 1.00
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 271320.5 ns 273642 ns 0.99
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 359580 ns 365467.5 ns 0.98
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 409393 ns 414444.5 ns 0.99
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 273400 ns 275154 ns 0.99
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 409113 ns 410968 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 88825 ns 89371.5 ns 0.99
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 91651 ns 89357.5 ns 1.03
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86871.5 ns 87022 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104455 ns 104495 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 210883198 ns 197534448 ns 1.07
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 411576886 ns 372121710 ns 1.11
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 445918896 ns 403011132 ns 1.11
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 474241975 ns 482377826 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 409801025 ns 371969112 ns 1.10
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 346098127 ns 334264188.5 ns 1.04
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 57379896.5 ns 59961589 ns 0.96
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 57238569.5 ns 53644168 ns 1.07
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 57607353 ns 56527647 ns 1.02
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28476232 ns 29291598.5 ns 0.97
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19438952.5 ns 19730534 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19717994 ns 19802579.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23319997 ns 23663463 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24172745 ns 24349385 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19750594.5 ns 19922312 ns 0.99
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6603171 ns 6620742 ns 1.00
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6609978 ns 6614070 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6584581 ns 6529781 ns 1.01

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.