Skip to content

Commit

Permalink
Merge pull request #553 from LuxDL/ap/standardize_docs
Browse files Browse the repository at this point in the history
Standardize the handling of states
  • Loading branch information
avik-pal authored Mar 20, 2024
2 parents 76f0ff2 + 435aaa9 commit 32a909a
Show file tree
Hide file tree
Showing 22 changed files with 244 additions and 297 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Downgrade.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
version: ['1.9']
version: ['1.10']
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
25 changes: 12 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.5.26"
version = "0.5.27"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -18,12 +18,10 @@ LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
Expand All @@ -33,6 +31,7 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -44,6 +43,7 @@ LuxComponentArraysExt = "ComponentArrays"
LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"]
LuxFluxExt = "Flux"
LuxLuxAMDGPUExt = "LuxAMDGPU"
LuxOptimisersExt = "Optimisers"
LuxReverseDiffExt = "ReverseDiff"
LuxSimpleChainsExt = "SimpleChains"
LuxTrackerExt = "Tracker"
Expand All @@ -62,34 +62,33 @@ ConstructionBase = "1.5"
Flux = "0.14.11"
Functors = "0.4.4"
GPUArraysCore = "0.1.6"
LinearAlgebra = "1.9"
Logging = "1.9"
LinearAlgebra = "1.10"
Logging = "1.10"
LuxAMDGPU = "0.2.2"
LuxCUDA = "0.3.2"
LuxCore = "0.1.12"
LuxDeviceUtils = "0.1.14"
LuxLib = "0.3.10"
LuxTestUtils = "0.1.15"
MacroTools = "0.5.13"
Markdown = "1.9"
Markdown = "1.10"
Optimisers = "0.3"
Pkg = "1.9"
Pkg = "1.10"
PrecompileTools = "1.2"
Random = "1.9"
Random = "1.10"
ReTestItems = "1.23.1"
Reexport = "1"
ReverseDiff = "1.15"
Setfield = "1"
SimpleChains = "0.4.6"
SparseArrays = "1.9"
StableRNGs = "1"
Statistics = "1.9"
Test = "1.9"
Statistics = "1.10"
Test = "1.10"
Tracker = "0.2.31"
TruncatedStacktraces = "1.1"
TruncatedStacktraces = "1.4"
WeightInitializers = "0.1.5"
Zygote = "0.6.69"
julia = "1.9"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
68 changes: 0 additions & 68 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,74 +74,6 @@ Checkout our [Ecosystem](http://lux.csail.mit.edu/dev/ecosystem/) page for more

For usage related questions, please use [Github Discussions](https://github.com/LuxDL/Lux.jl/discussions) or [JuliaLang Discourse (machine learning domain)](https://discourse.julialang.org/c/domain/ml/) which allows questions and answers to be indexed. To report bugs use [github issues](https://github.com/LuxDL/Lux.jl/issues) or even better send in a [pull request](https://github.com/LuxDL/Lux.jl/pulls).


## Package Ecosystem Structure

Structure of the packages part of the `Lux.jl` Universe[^1]: (Rounded Rectangles denote packages maintained by `Lux.jl` developers)

[^1]: These packages only constitute a subset of the ecosystem. Specifically these are the packages which the maintainers of Lux.jl have personally tested out. If you want a new package to be listed here, please open an issue.

```mermaid
flowchart LR
subgraph Interface
LuxCore(LuxCore)
end
subgraph Backend
LuxLib(LuxLib)
NNlib
CUDA
end
subgraph ExternalML[External ML Packages]
Flux
Metalhead
end
subgraph CompViz[Computer Vision]
Boltz(Boltz)
end
subgraph SciML[Scientific Machine Learning]
DeepEquilibriumNetworks(DeepEquilibriumNetworks)
DiffEqFlux(DiffEqFlux)
NeuralPDE[Neural PDE: PINNs]
end
subgraph AD[Automatic Differentiation]
Zygote
Enzyme["Enzyme (experimental)"]
end
subgraph Dist[Distributed Training]
FluxMPI(FluxMPI)
end
subgraph SerializeModels[Serialize Models]
Serial[Serialization]
JLD2
BSON
end
subgraph Opt[Optimization]
Optimisers
Optimization
end
subgraph Parameters
ComponentArrays
end
Lux(Lux)
Parameters --> Lux
LuxCore --> Lux
Backend --> Lux
Lux --> SciML
AD --> Lux
Lux --> Dist
Lux --> SerializeModels
Lux --> Opt
Lux --> CompViz
ExternalML -.-> CompViz
```

## Related Projects

* [Flux.jl](https://github.com/FluxML/Flux.jl) -- We share most of the backend infrastructure with Flux ([Roadmap](https://github.com/FluxML/Flux.jl/issues/1829) hints towards making Flux explicit-parameter first)
* [Knet.jl](https://github.com/denizyuret/Knet.jl) -- One of the mature and OG Julia Deep Learning Frameworks
* [SimpleChains.jl](https://github.com/PumasAI/SimpleChains.jl) -- Extremely Efficient for Small Neural Networks on CPU
* [Avalon.jl](https://github.com/dfdx/Avalon.jl) -- Uses tracing based AD [Yota.jl](https://github.com/dfdx/Yota.jl)

## Citation

If you found this library to be useful in academic work, then please cite:
Expand Down
22 changes: 13 additions & 9 deletions examples/Basics/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ f(x) = x' * x / 2
∇f(x) = x # `∇` can be typed as `\nabla<TAB>`
v = randn(rng, Float32, 4)

# Let's use AbstractDifferentiation and Zygote to compute the gradients.
# Let's use ForwardDiff and Zygote to compute the gradients.

println("Actual Gradient: ", ∇f(v))
println("Computed Gradient via Reverse Mode AD (Zygote): ", only(Zygote.gradient(f, v)))
Expand Down Expand Up @@ -313,7 +313,10 @@ opt = Optimisers.Descent(0.01f0)
opt_state = Optimisers.setup(opt, ps)

# Define the loss function
mse(model, ps, st, X, y) = sum(abs2, model(X, ps, st)[1] .- y)
function mse(model, ps, st, X, y)
y_pred, st_new = model(X, ps, st)
return sum(abs2, y_pred .- y), st_new
end
mse(weight, bias, X, y) = sum(abs2, weight * X .+ bias .- y)
loss_function(ps, X, y) = mse(model, ps, st, X, y)

Expand All @@ -323,12 +326,13 @@ for i in 1:100
## In actual code, don't use globals. But here I will simply for the sake of
## demonstration
global ps, st, opt_state
## Compute the gradient
gs = gradient(loss_function, ps, x_samples, y_samples)[1]
## Compute the gradient using the pullback API to update the states
(loss, st), pb_f = Zygote.pullback(loss_function, ps, x_samples, y_samples)
## We pass nothing as the seed for `st`, since we don't want to propagate any gradient
## for st
gs = pb_f((one(loss), nothing))[1]
## Update model parameters
opt_state, ps = Optimisers.update(opt_state, ps, gs)
if i % 10 == 1 || i == 100
println(
"Loss Value after $i iterations: ", mse(model, ps, st, x_samples, y_samples))
end
## `Optimisers.update` can be used if mutation is not desired
opt_state, ps = Optimisers.update!(opt_state, ps, gs)
(i % 10 == 1 || i == 100) && println(lazy"Loss Value after $i iterations: $loss")
end
2 changes: 0 additions & 2 deletions examples/BayesianNN/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MakiePublication = "dde8697e-0d61-460d-88dd-856f66710dd1"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Expand All @@ -17,7 +16,6 @@ Functors = "0.2, 0.3, 0.4"
LinearAlgebra = "1"
Literate = "2"
Lux = "0.5"
MakiePublication = "0.3"
Random = "1"
Tracker = "0.2"
Turing = "0.30"
Expand Down
32 changes: 15 additions & 17 deletions examples/BayesianNN/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Pkg.instantiate(; io=pkg_io) #hide
Pkg.develop(; path=joinpath(__DIR, "..", ".."), io=pkg_io) #hide
Pkg.precompile(; io=pkg_io) #hide
close(pkg_io) #hide
using Lux, Turing, CairoMakie, Random, Tracker, Functors, MakiePublication, LinearAlgebra
using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra

## Sampling progress
Turing.setprogress!(true);
Expand Down Expand Up @@ -60,15 +60,11 @@ function plot_data()
x2 = first.(xt0s)
y2 = last.(xt0s)

fig = with_theme(theme_web()) do
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

scatter!(ax, x1, y1; markersize=8, color=:red, strokecolor=:black, strokewidth=1)
scatter!(ax, x2, y2; markersize=8, color=:blue, strokecolor=:black, strokewidth=1)

return fig
end
scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)

return fig
end
Expand Down Expand Up @@ -117,19 +113,21 @@ function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
return fmap(get_ps, ps)
end

# To interface with external libraries it is often desirable to use the
# [`StatefulLuxLayer`](@ref) to automatically handle the neural network states.
const model = StatefulLuxLayer(nn, st)

## Specify the probabilistic model.
@model function bayes_nn(xs, ts)
global st

## Sample the parameters
nparameters = Lux.parameterlength(nn)
parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

## Forward NN to make predictions
preds, st = Lux.apply(nn, xs, vector_to_parameters(parameters, ps), st)
preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

## Observe each prediction.
for i in 1:length(ts)
for i in eachindex(ts)
ts[i] ~ Bernoulli(preds[i])
end
end
Expand All @@ -150,7 +148,7 @@ ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()),
# ## Prediction Visualization

## A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = first(nn(x, vector_to_parameters(θ, ps), st))
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))

## Plot the data we have.
fig = plot_data()
Expand All @@ -165,7 +163,7 @@ i = i.I[1]
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z)
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig

# The contour plot above shows that the MAP method is not too bad at classifying our data.
Expand All @@ -192,7 +190,7 @@ n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z)
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig

# Suppose we are interested in how the predictive power of our Bayesian neural network
Expand All @@ -201,7 +199,7 @@ fig

fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z)
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
fig.current_axis[].title = "Iteration: $i"
Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
Expand Down
3 changes: 1 addition & 2 deletions examples/GravitationalWaveForm/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MakiePublication = "dde8697e-0d61-460d-88dd-856f66710dd1"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"

Expand All @@ -22,7 +22,6 @@ Literate = "2"
Lux = "0.5"
LuxAMDGPU = "0.1, 0.2"
LuxCUDA = "0.2, 0.3"
MakiePublication = "0.3"
Optimization = "3"
OptimizationOptimJL = "0.1, 0.2"
OrdinaryDiffEq = "6"
Expand Down
Loading

2 comments on commit 32a909a

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/103289

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.27 -m "<description of version>" 32a909ae3e32f708cf52f344e87e21f3f2b12bd1
git push origin v0.5.27

Please sign in to comment.