Skip to content

Commit

Permalink
perf: use the permuted formulation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 26, 2024
1 parent 994f53a commit c2439c6
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions bench/lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...)
function train!(loss, backend, model, ps, st, data; epochs=10)
l1 = loss(model, ps, st, first(data))

tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.01f0))
tstate = Training.TrainState(model, ps, st, Adam(0.01f0))
for _ in 1:epochs, (x, y) in data
_, _, _, tstate = Lux.Experimental.single_train_step!(backend, loss, (x, y), tstate)
_, _, _, tstate = Training.single_train_step!(backend, loss, (x, y), tstate)
end

l2 = loss(model, ps, st, first(data))
Expand All @@ -25,14 +25,14 @@ end
n_points = 128
batch_size = 64

x = rand(Float32, 1, n_points, batch_size);
y = rand(Float32, 1, n_points, batch_size);
x = rand(Float32, n_points, 1, batch_size);
y = rand(Float32, n_points, 1, batch_size);
data = [(x, y)];
t_fwd = zeros(5)
t_train = zeros(5)
for i in 1:5
chs = (1, 128, fill(64, i)..., 128, 1)
model = FourierNeuralOperator(gelu; chs=chs, modes=(16,))
model = FourierNeuralOperator(gelu; chs, modes=(16,), permuted=Val(true))
ps, st = Lux.setup(rng, model)
model(x, ps, st) # TTFX

Expand Down

0 comments on commit c2439c6

Please sign in to comment.