diff --git a/previews/PR474/api/Accelerator_Support/LuxAMDGPU.md b/previews/PR474/api/Accelerator_Support/LuxAMDGPU.md new file mode 100644 index 000000000..7c4cc8c98 --- /dev/null +++ b/previews/PR474/api/Accelerator_Support/LuxAMDGPU.md @@ -0,0 +1,35 @@ + + + +# LuxAMDGPU + + + + +`LuxAMDGPU` is meant to be used as a trigger package for all `AMDGPU` dependencies in `Lux`. Users requiring AMDGPU support should install `LuxAMDGPU` and load it alongside `Lux`. + + + + +## Index + +- [`LuxAMDGPU.functional`](#LuxAMDGPU.functional-Tuple{}) + + + + +## API + +
+# LuxAMDGPU.functionalMethod. + + + +```julia +functional() +``` + +Check if LuxAMDGPU is functional. + +
+
diff --git a/previews/PR474/api/Accelerator_Support/LuxCUDA.md b/previews/PR474/api/Accelerator_Support/LuxCUDA.md new file mode 100644 index 000000000..e8df00654 --- /dev/null +++ b/previews/PR474/api/Accelerator_Support/LuxCUDA.md @@ -0,0 +1,35 @@ + + + + + +# LuxCUDA + + +`LuxCUDA` is meant to be used as a trigger package for all `CUDA` dependencies in `Lux`. Users requiring CUDA support should install `LuxCUDA` and load it alongside `Lux`. + + + + +## Index + +- [`LuxCUDA.functional`](#LuxCUDA.functional-Tuple{}) + + + + +## API Reference + +
+# LuxCUDA.functionalMethod. + + + +```julia +functional() +``` + +Check if LuxCUDA is functional. + +
+
diff --git a/previews/PR474/api/Accelerator_Support/LuxDeviceUtils.md b/previews/PR474/api/Accelerator_Support/LuxDeviceUtils.md new file mode 100644 index 000000000..98b79a25e --- /dev/null +++ b/previews/PR474/api/Accelerator_Support/LuxDeviceUtils.md @@ -0,0 +1,125 @@ + + + + + +# LuxDeviceUtils + + +`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across devices. Most users should directly use Lux.jl instead. + + + + +## Index + +- [`LuxDeviceUtils.cpu_device`](#LuxDeviceUtils.cpu_device) +- [`LuxDeviceUtils.gpu_backend!`](#LuxDeviceUtils.gpu_backend!) +- [`LuxDeviceUtils.gpu_device`](#LuxDeviceUtils.gpu_device) +- [`LuxDeviceUtils.reset_gpu_device!`](#LuxDeviceUtils.reset_gpu_device!) +- [`LuxDeviceUtils.supported_gpu_backends`](#LuxDeviceUtils.supported_gpu_backends) + + + + +## Preferences + +
+# LuxDeviceUtils.gpu_backend!Function. + + + +```julia +gpu_backend!() = gpu_backend!("") +gpu_backend!(backend) = gpu_backend!(string(backend)) +gpu_backend!(backend::AbstractLuxGPUDevice) +gpu_backend!(backend::String) +``` + +Creates a `LocalPreferences.toml` file with the desired GPU backend. + +If `backend == ""`, then the `gpu_backend` preference is deleted. Otherwise, `backend` is validated to be one of the possible backends and the preference is set to `backend`. + +If a new backend is successfully set, then the Julia session must be restarted for the change to take effect. + +
+
+ + + +## Data Transfer + +
+# LuxDeviceUtils.cpu_deviceFunction. + + + +```julia +cpu_device() -> LuxCPUDevice() +``` + +Return a `LuxCPUDevice` object which can be used to transfer data to CPU. + +
+
+
+# LuxDeviceUtils.gpu_deviceFunction. + + + +```julia +gpu_device(; force_gpu_usage::Bool=false) -> AbstractLuxDevice() +``` + +Selects GPU device based on the following criteria: + +1. If `gpu_backend` preference is set and the backend is functional on the system, then that device is selected. +2. Otherwise, an automatic selection algorithm is used. We go over possible device backends in the order specified by `supported_gpu_backends()` and select the first functional backend. +3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is invoked. +4. If nothing works, an error is thrown. + +
+
+ + + +## Miscellaneous + +
+# LuxDeviceUtils.reset_gpu_device!Function. + + + +```julia +reset_gpu_device!() +``` + +Resets the selected GPU device. This is useful when automatic GPU selection needs to be run again. + +
+
+
+# LuxDeviceUtils.supported_gpu_backendsFunction. + + + +```julia +supported_gpu_backends() -> Tuple{String, ...} +``` + +Return a tuple of supported GPU backends. + +::: warning + +This is not the list of functional backends on the system, but rather backends which `Lux.jl` supports. + +::: + +::: danger + +`Metal.jl` support is **extremely** experimental and most things are not expected to work. + +::: + +
+
diff --git a/previews/PR474/api/Building_Blocks/LuxCore.md b/previews/PR474/api/Building_Blocks/LuxCore.md new file mode 100644 index 000000000..a132b7958 --- /dev/null +++ b/previews/PR474/api/Building_Blocks/LuxCore.md @@ -0,0 +1,270 @@ + + + +# LuxCore + + +`LuxCore.jl` defines the abstract layers for Lux. Allows users to be compatible with the entirely of `Lux.jl` without having such a heavy dependency. If you are depending on `Lux.jl` directly, you do not need to depend on `LuxCore.jl` (all the functionality is exported via `Lux.jl`). + + + + +## Index + +- [`LuxCore.AbstractExplicitContainerLayer`](#LuxCore.AbstractExplicitContainerLayer) +- [`LuxCore.AbstractExplicitLayer`](#LuxCore.AbstractExplicitLayer) +- [`LuxCore.apply`](#LuxCore.apply) +- [`LuxCore.check_fmap_condition`](#LuxCore.check_fmap_condition) +- [`LuxCore.contains_lux_layer`](#LuxCore.contains_lux_layer) +- [`LuxCore.display_name`](#LuxCore.display_name) +- [`LuxCore.initialparameters`](#LuxCore.initialparameters) +- [`LuxCore.initialstates`](#LuxCore.initialstates) +- [`LuxCore.parameterlength`](#LuxCore.parameterlength) +- [`LuxCore.setup`](#LuxCore.setup) +- [`LuxCore.statelength`](#LuxCore.statelength) +- [`LuxCore.testmode`](#LuxCore.testmode) +- [`LuxCore.trainmode`](#LuxCore.trainmode) +- [`LuxCore.update_state`](#LuxCore.update_state) + + + + +## Abstract Types + +
+# LuxCore.AbstractExplicitLayerType. + + + +```julia +abstract type AbstractExplicitLayer +``` + +Abstract Type for all Lux Layers + +Users implementing their custom layer, **must** implement + + * `initialparameters(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` – This returns a `NamedTuple` containing the trainable parameters for the layer. + * `initialstates(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` – This returns a NamedTuple containing the current state for the layer. For most layers this is typically empty. Layers that would potentially contain this include `BatchNorm`, `LSTM`, `GRU` etc. + +Optionally: + + * `parameterlength(layer::CustomAbstractExplicitLayer)` – These can be automatically calculated, but it is recommended that the user defines these. + * `statelength(layer::CustomAbstractExplicitLayer)` – These can be automatically calculated, but it is recommended that the user defines these. + +See also [`AbstractExplicitContainerLayer`](LuxCore#LuxCore.AbstractExplicitContainerLayer) + +
+
+
+# LuxCore.AbstractExplicitContainerLayerType. + + + +```julia +abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer +``` + +Abstract Container Type for certain Lux Layers. `layers` is a tuple containing fieldnames for the layer, and constructs the parameters and states using those. + +Users implementing their custom layer can extend the same functions as in [`AbstractExplicitLayer`](LuxCore#LuxCore.AbstractExplicitLayer). + +::: tip + +Advanced structure manipulation of these layers post construction is possible via `Functors.fmap`. For a more flexible interface, we recommend using the experimental feature [`Lux.Experimental.@layer_map`](@ref). + +::: + +
+
+ + + +## General + +
+# LuxCore.applyFunction. + + + +```julia +apply(model, x, ps, st) +``` + +Simply calls `model(x, ps, st)` + +
+
+
+# LuxCore.check_fmap_conditionFunction. + + + +```julia +check_fmap_condition(cond, tmatch, x) -> Bool +``` + +`fmap`s into the structure `x` and see if `cond` is statisfied for any of the leaf elements. + +**Arguments** + +``` +* `cond` - A function that takes a single argument and returns a `Bool`. +* `tmatch` - A shortcut to check if `x` is of type `tmatch`. Can be disabled by passing + `nothing`. +* `x` - The structure to check. +``` + +**Returns** + +A Boolean Value + +
+
+
+# LuxCore.contains_lux_layerFunction. + + + +```julia +contains_lux_layer(l) -> Bool +``` + +Check if the structure `l` is a Lux AbstractExplicitLayer or a container of such a layer. + +
+
+
+# LuxCore.display_nameFunction. + + + +```julia +display_name(layer::AbstractExplicitLayer) +``` + +Printed Name of the `layer`. If the `layer` has a field `name` that is used, else the type name is used. + +
+
+
+# LuxCore.setupFunction. + + + +```julia +setup(rng::AbstractRNG, layer) +``` + +Shorthand for getting the parameters and states of the layer `l`. Is equivalent to `(initialparameters(rng, l), initialstates(rng, l))`. + +::: warning + +This function is not pure, it mutates `rng`. + +::: + +
+
+ + + +## Parameters + +
+# LuxCore.initialparametersFunction. + + + +```julia +initialparameters(rng::AbstractRNG, layer) +``` + +Generate the initial parameters of the layer `l`. + +
+
+
+# LuxCore.parameterlengthFunction. + + + +```julia +parameterlength(layer) +``` + +Return the total number of parameters of the layer `l`. + +
+
+ + + +## States + +
+# LuxCore.initialstatesFunction. + + + +```julia +initialstates(rng::AbstractRNG, layer) +``` + +Generate the initial states of the layer `l`. + +
+
+
+# LuxCore.statelengthFunction. + + + +```julia +statelength(layer) +``` + +Return the total number of states of the layer `l`. + +
+
+
+# LuxCore.testmodeFunction. + + + +```julia +testmode(st::NamedTuple) +``` + +Make all occurances of `training` in state `st` – `Val(false)`. + +
+
+
+# LuxCore.trainmodeFunction. + + + +```julia +trainmode(st::NamedTuple) +``` + +Make all occurances of `training` in state `st` – `Val(true)`. + +
+
+
+# LuxCore.update_stateFunction. + + + +```julia +update_state(st::NamedTuple, key::Symbol, value; + layer_check=_default_layer_check(key)) +``` + +Recursively update all occurances of the `key` in the state `st` with the `value`. + +
+
diff --git a/previews/PR474/api/Building_Blocks/LuxLib.md b/previews/PR474/api/Building_Blocks/LuxLib.md new file mode 100644 index 000000000..333e46140 --- /dev/null +++ b/previews/PR474/api/Building_Blocks/LuxLib.md @@ -0,0 +1,260 @@ + + + +# LuxLib + + +Backend for Lux.jl + + + + + + +## Index + +- [`LuxLib.alpha_dropout`](#LuxLib.alpha_dropout) +- [`LuxLib.batchnorm`](#LuxLib.batchnorm) +- [`LuxLib.dropout`](#LuxLib.dropout) +- [`LuxLib.groupnorm`](#LuxLib.groupnorm) +- [`LuxLib.instancenorm`](#LuxLib.instancenorm) +- [`LuxLib.layernorm`](#LuxLib.layernorm) + + + + +## Dropout + +
+# LuxLib.alpha_dropoutFunction. + + + +```julia +alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}) +alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}, α, A, B) +``` + +Alpha Dropout: Dropout ensuring that the mean and variance of the output remains same as the input. For details see [1]. Use the second call signature to avoid recomputing the constants for a fixed dropout probability. + +**Arguments** + + * `rng`: Random number generator + * `x`: Input Array + * `p`: Probability of an element to be dropped out + * `Val(training)`: If `true` then dropout is applied on `x` with probability `p`. Else, `x` is returned + * `α`: `-1.7580993408473766`. Computed at limit x tends to infinity, `selu(x) = -λβ = α` + * `A`: Scaling factor for the mean + * `B`: Scaling factor for the variance + +**Returns** + + * Output Array after applying alpha dropout + * Updated state for the random number generator + +**References** + +[1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017). + +
+
+
+# LuxLib.dropoutFunction. + + + +```julia +dropout(rng::AbstractRNG, x, p, ::Val{training}, invp; dims) +dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}, invp; + dims) +``` + +Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1]. + +**Arguments** + + * `rng`: Random number generator + * `x`: Input Array + * `mask`: Dropout Mask. If not used then it is constructed automatically + * `p`: Probability of an element to be dropped out + * `Val(training)`: If `true` then dropout is applied on `x` with probability `p` along `dims`. Else, `x` is returned + * `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask` provided is directly used + * `invp`: Inverse of the probability + +**Keyword Arguments** + + * `dims`: Dimensions along which dropout is applied + * `invp`: Inverse of the probability ($\frac{1}{p}$) + +**Returns** + + * Output Array after applying dropout + * Dropout Mask (if `training == false`, the returned value is meaningless) + * Updated state for the random number generator + +**References** + +[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. + +
+
+ + + +## Normalization + +
+# LuxLib.batchnormFunction. + + + +```julia +batchnorm(x, scale, bias, running_mean, running_var; momentum, epsilon, training) +``` + +Batch Normalization. For details see [1]. + +Batch Normalization computes the mean and variance for each $D_1 \times ... \times D_{N - 2} \times 1 \times D_N$ input slice and normalises the input accordingly. + +**Arguments** + + * `x`: Input to be Normalized + * `scale`: Scale factor ($\gamma$) (can be `nothing`) + * `bias`: Bias factor ($\beta$) (can be `nothing`) + * `running_mean`: Running mean (can be `nothing`) + * `running_var`: Running variance (can be `nothing`) + +**Keyword Arguments** + + * `momentum`: Momentum for updating running mean and variance + * `epsilon`: Value added to the denominator for numerical stability + * `training`: Set to `Val(true)` if running in training mode + +**Returns** + +Normalized Array of same size as `x`. And a Named Tuple containing the updated running mean and variance. + +**Performance Considerations** + +If the input array is `2D`, `4D`, or `5D` `CuArray` with element types `Float16`, `Float32` and `Float64`, then the CUDNN code path will be used. In all other cases, a broadcasting fallback is used which is not highly optimized. + +**References** + +[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015. + +
+
+
+# LuxLib.groupnormFunction. + + + +```julia +groupnorm(x, scale, bias; groups, epsilon) +``` + +Group Normalization. For details see [1]. + +This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. + +**Arguments** + + * `x`: Input to be Normalized + * `scale`: Scale factor ($\gamma$) (can be `nothing`) + * `bias`: Bias factor ($\beta$) (can be `nothing`) + +**Keyword Arguments** + + * `groups`: Number of groups + * `epsilon`: Value added to the denominator for numerical stability + +**Returns** + +The normalized array is returned. + +**Performance Considerations** + +The most common case of this Op – `x` is a 4D array – is optimized using KernelAbstractions and has a fast custom backwards pass implemented. All other cases have a fallback implementation which is not especially optimized. + +We have tested the code path for `Float16` and it works, but gradient accumulation is extremely fragile. Hence, for `Float16` inputs, it uses the fallback implementation. + +If the batch size is small (< 16), then the fallback implementation will be faster than the KA version. However, this customization is not possible using the direct `groupnorm` interface. + +**References** + +[1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018. + +
+
+
+# LuxLib.instancenormFunction. + + + +```julia +instancenorm(x, scale, bias; epsilon, training) +``` + +Instance Normalization. For details see [1]. + +Instance Normalization computes the mean and variance for each $D_1 \times ... \times D_{N - 2} \times 1 \times 1$ input slice and normalises the input accordingly. + +**Arguments** + + * `x`: Input to be Normalized (must be atleast 3D) + * `scale`: Scale factor ($\gamma$) (can be `nothing`) + * `bias`: Bias factor ($\beta$) (can be `nothing`) + +**Keyword Arguments** + + * `epsilon`: Value added to the denominator for numerical stability + * `training`: Set to `Val(true)` if running in training mode + +**Returns** + +Normalized Array of same size as `x`. And a Named Tuple containing the updated running mean and variance. + +**References** + +[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). + +
+
+
+# LuxLib.layernormFunction. + + + +```julia +layernorm(x, scale, bias; dims, epsilon) +``` + +Layer Normalization. For details see [1]. + +Given an input array $x$, this layer computes + +$$ +y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta +$$ + +**Arguments** + + * `x`: Input to be Normalized + * `scale`: Scale factor ($\gamma$) (can be `nothing`) + * `bias`: Bias factor ($\beta$) (can be `nothing`) + +**Keyword Arguments** + + * `dims`: Dimensions along which the mean and std of `x` is computed + * `epsilon`: Value added to the denominator for numerical stability + +**Returns** + +Normalized Array of same size as `x`. + +**References** + +[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). + +
+
diff --git a/previews/PR474/api/Building_Blocks/WeightInitializers.md b/previews/PR474/api/Building_Blocks/WeightInitializers.md new file mode 100644 index 000000000..42877d44d --- /dev/null +++ b/previews/PR474/api/Building_Blocks/WeightInitializers.md @@ -0,0 +1,458 @@ + + + + + +# WeightInitializers + + +This package is a light dependency providing common weight initialization schemes for deep learning models. + + + + +## Index + +- [`WeightInitializers.glorot_normal`](#WeightInitializers.glorot_normal) +- [`WeightInitializers.glorot_uniform`](#WeightInitializers.glorot_uniform) +- [`WeightInitializers.kaiming_normal`](#WeightInitializers.kaiming_normal) +- [`WeightInitializers.kaiming_uniform`](#WeightInitializers.kaiming_uniform) +- [`WeightInitializers.ones16`](#WeightInitializers.ones16) +- [`WeightInitializers.ones32`](#WeightInitializers.ones32) +- [`WeightInitializers.ones64`](#WeightInitializers.ones64) +- [`WeightInitializers.onesC16`](#WeightInitializers.onesC16) +- [`WeightInitializers.onesC32`](#WeightInitializers.onesC32) +- [`WeightInitializers.onesC64`](#WeightInitializers.onesC64) +- [`WeightInitializers.rand16`](#WeightInitializers.rand16) +- [`WeightInitializers.rand32`](#WeightInitializers.rand32) +- [`WeightInitializers.rand64`](#WeightInitializers.rand64) +- [`WeightInitializers.randC16`](#WeightInitializers.randC16) +- [`WeightInitializers.randC32`](#WeightInitializers.randC32) +- [`WeightInitializers.randC64`](#WeightInitializers.randC64) +- [`WeightInitializers.randn16`](#WeightInitializers.randn16) +- [`WeightInitializers.randn32`](#WeightInitializers.randn32) +- [`WeightInitializers.randn64`](#WeightInitializers.randn64) +- [`WeightInitializers.randnC16`](#WeightInitializers.randnC16) +- [`WeightInitializers.randnC32`](#WeightInitializers.randnC32) +- [`WeightInitializers.randnC64`](#WeightInitializers.randnC64) +- [`WeightInitializers.truncated_normal`](#WeightInitializers.truncated_normal) +- [`WeightInitializers.zeros16`](#WeightInitializers.zeros16) +- [`WeightInitializers.zeros32`](#WeightInitializers.zeros32) +- [`WeightInitializers.zeros64`](#WeightInitializers.zeros64) +- [`WeightInitializers.zerosC16`](#WeightInitializers.zerosC16) +- [`WeightInitializers.zerosC32`](#WeightInitializers.zerosC32) +- [`WeightInitializers.zerosC64`](#WeightInitializers.zerosC64) + + + + +## API Reference + + + + +### Main Functions + +
+# WeightInitializers.glorot_normalFunction. + + + +```julia +glorot_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; + gain = 1) -> AbstractArray{T, length(size)} +``` + +Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a normal distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. This method is described in [1] and also known as Xavier initialization. + +**References** + +[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." *Proceedings of the thirteenth international conference on artificial intelligence and statistics*. 2010. + +
+
+
+# WeightInitializers.glorot_uniformFunction. + + + +```julia +glorot_uniform([::AbstractRNG=_default_rng()], [T=Float32], size...; + gain = 1) -> AbstractArray{T, length(size)} +``` + +Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a uniform distribution on the interval $[-x, x]$, where `x = gain * sqrt(6 / (fan_in + fan_out))`. This method is described in [1] and also known as Xavier initialization. + +**References** + +[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." *Proceedings of the thirteenth international conference on artificial intelligence and statistics*. 2010. + +
+
+
+# WeightInitializers.kaiming_normalFunction. + + + +```julia +kaiming_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; + gain = √T(2)) -> AbstractArray{T, length(size)} +``` + +Return an `AbstractArray{T}` of the given `size` containing random numbers taken from a normal distribution standard deviation `gain / sqrt(fan_in)` + +**References** + +[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." *Proceedings of the IEEE international conference on computer vision*. 2015. + +
+
+
+# WeightInitializers.kaiming_uniformFunction. + + + +```julia +kaiming_uniform([::AbstractRNG=_default_rng()], [T=Float32], size...; + gain = √T(2)) -> AbstractArray{T, length(size)} +``` + +Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a uniform distribution on the interval `[-x, x]`, where `x = gain * sqrt(3/fan_in)`. + +**References** + +[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." *Proceedings of the IEEE international conference on computer vision*. 2015. + +
+
+
+# WeightInitializers.truncated_normalFunction. + + + +```julia +truncated_normal([::AbstractRNG=_default_rng()], [T=Float32], size...; mean = 0, + std = 1, lo = -2, hi = 2) -> AbstractArray{T, length(size)} +``` + +Return an `AbstractArray{T}` of the given `size` where each element is drawn from a truncated normal distribution. The numbers are distributed like `filter(x -> lo ≤ x ≤ hi, mean .+ std .* randn(100))`. + +
+
+ + + +### Commonly Used Wrappers + +
+# WeightInitializers.zeros16Function. + + + +```julia +zeros16([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float16, length(size)} +``` + +Return an `AbstractArray{Float16}` of the given `size` containing an AbstractArray of zeros. + +
+
+
+# WeightInitializers.ones16Function. + + + +```julia +ones16([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float16, length(size)} +``` + +Return an `AbstractArray{Float16}` of the given `size` containing an AbstractArray of ones. + +
+
+
+# WeightInitializers.rand16Function. + + + +```julia +rand16([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float16, length(size)} +``` + +Return an `AbstractArray{Float16}` of the given `size` containing random numbers from a uniform distribution. + +
+
+
+# WeightInitializers.randn16Function. + + + +```julia +randn16([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float16, length(size)} +``` + +Return an `AbstractArray{Float16}` of the given `size` containing random numbers from a standard normal distribution. + +
+
+
+# WeightInitializers.zeros32Function. + + + +```julia +zeros32([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float32, length(size)} +``` + +Return an `AbstractArray{Float32}` of the given `size` containing an AbstractArray of zeros. + +
+
+
+# WeightInitializers.ones32Function. + + + +```julia +ones32([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float32, length(size)} +``` + +Return an `AbstractArray{Float32}` of the given `size` containing an AbstractArray of ones. + +
+
+
+# WeightInitializers.rand32Function. + + + +```julia +rand32([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float32, length(size)} +``` + +Return an `AbstractArray{Float32}` of the given `size` containing random numbers from a uniform distribution. + +
+
+
+# WeightInitializers.randn32Function. + + + +```julia +randn32([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float32, length(size)} +``` + +Return an `AbstractArray{Float32}` of the given `size` containing random numbers from a standard normal distribution. + +
+
+
+# WeightInitializers.zeros64Function. + + + +```julia +zeros64([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float64, length(size)} +``` + +Return an `AbstractArray{Float64}` of the given `size` containing an AbstractArray of zeros. + +
+
+
+# WeightInitializers.ones64Function. + + + +```julia +ones64([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float64, length(size)} +``` + +Return an `AbstractArray{Float64}` of the given `size` containing an AbstractArray of ones. + +
+
+
+# WeightInitializers.rand64Function. + + + +```julia +rand64([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float64, length(size)} +``` + +Return an `AbstractArray{Float64}` of the given `size` containing random numbers from a uniform distribution. + +
+
+
+# WeightInitializers.randn64Function. + + + +```julia +randn64([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{Float64, length(size)} +``` + +Return an `AbstractArray{Float64}` of the given `size` containing random numbers from a standard normal distribution. + +
+
+
+# WeightInitializers.zerosC16Function. + + + +```julia +zerosC16([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF16, length(size)} +``` + +Return an `AbstractArray{ComplexF16}` of the given `size` containing an AbstractArray of zeros. + +
+
+
+# WeightInitializers.onesC16Function. + + + +```julia +onesC16([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF16, length(size)} +``` + +Return an `AbstractArray{ComplexF16}` of the given `size` containing an AbstractArray of ones. + +
+
+
+# WeightInitializers.randC16Function. + + + +```julia +randC16([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF16, length(size)} +``` + +Return an `AbstractArray{ComplexF16}` of the given `size` containing random numbers from a uniform distribution. + +
+
+
+# WeightInitializers.randnC16Function. + + + +```julia +randnC16([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF16, length(size)} +``` + +Return an `AbstractArray{ComplexF16}` of the given `size` containing random numbers from a standard normal distribution. + +
+
+
+# WeightInitializers.zerosC32Function. + + + +```julia +zerosC32([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF32, length(size)} +``` + +Return an `AbstractArray{ComplexF32}` of the given `size` containing an AbstractArray of zeros. + +
+
+
+# WeightInitializers.onesC32Function. + + + +```julia +onesC32([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF32, length(size)} +``` + +Return an `AbstractArray{ComplexF32}` of the given `size` containing an AbstractArray of ones. + +
+
+
+# WeightInitializers.randC32Function. + + + +```julia +randC32([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF32, length(size)} +``` + +Return an `AbstractArray{ComplexF32}` of the given `size` containing random numbers from a uniform distribution. + +
+
+
+# WeightInitializers.randnC32Function. + + + +```julia +randnC32([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF32, length(size)} +``` + +Return an `AbstractArray{ComplexF32}` of the given `size` containing random numbers from a standard normal distribution. + +
+
+
+# WeightInitializers.zerosC64Function. + + + +```julia +zerosC64([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF64, length(size)} +``` + +Return an `AbstractArray{ComplexF64}` of the given `size` containing an AbstractArray of zeros. + +
+
+
+# WeightInitializers.onesC64Function. + + + +```julia +onesC64([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF64, length(size)} +``` + +Return an `AbstractArray{ComplexF64}` of the given `size` containing an AbstractArray of ones. + +
+
+
+# WeightInitializers.randC64Function. + + + +```julia +randC64([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF64, length(size)} +``` + +Return an `AbstractArray{ComplexF64}` of the given `size` containing random numbers from a uniform distribution. + +
+
+
+# WeightInitializers.randnC64Function. + + + +```julia +randnC64([::AbstractRNG=_default_rng()], size...; kwargs...) -> AbstractArray{ComplexF64, length(size)} +``` + +Return an `AbstractArray{ComplexF64}` of the given `size` containing random numbers from a standard normal distribution. + +
+
diff --git a/previews/PR474/api/Domain_Specific_Modeling/Boltz.md b/previews/PR474/api/Domain_Specific_Modeling/Boltz.md new file mode 100644 index 000000000..42971990f --- /dev/null +++ b/previews/PR474/api/Domain_Specific_Modeling/Boltz.md @@ -0,0 +1,300 @@ + + + + + +# Boltz + + +Accelerate ⚡ your ML research using pre-built Deep Learning Models with Lux. + + + + +## Index + +- [`Boltz.ClassTokens`](#Boltz.ClassTokens) +- [`Boltz.MultiHeadAttention`](#Boltz.MultiHeadAttention) +- [`Boltz.ViPosEmbedding`](#Boltz.ViPosEmbedding) +- [`Boltz._fast_chunk`](#Boltz._fast_chunk) +- [`Boltz._flatten_spatial`](#Boltz._flatten_spatial) +- [`Boltz._seconddimmean`](#Boltz._seconddimmean) +- [`Boltz._vgg_block`](#Boltz._vgg_block) +- [`Boltz._vgg_classifier_layers`](#Boltz._vgg_classifier_layers) +- [`Boltz._vgg_convolutional_layers`](#Boltz._vgg_convolutional_layers) +- [`Boltz.transformer_encoder`](#Boltz.transformer_encoder) +- [`Boltz.vgg`](#Boltz.vgg) + + + + +# Computer Vision Models + + + + +## Classification Models: Native Lux Models + + +| MODEL NAME | FUNCTION | NAME | PRETRAINED | TOP 1 ACCURACY (%) | TOP 5 ACCURACY (%) | +| ------------------:| --------------------:| -----------:|:----------:|:------------------:|:------------------:| +| VGG | `vgg` | `:vgg11` | ✅ | 67.35 | 87.91 | +| VGG | `vgg` | `:vgg13` | ✅ | 68.40 | 88.48 | +| VGG | `vgg` | `:vgg16` | ✅ | 70.24 | 89.80 | +| VGG | `vgg` | `:vgg19` | ✅ | 71.09 | 90.27 | +| VGG | `vgg` | `:vgg11_bn` | ✅ | 69.09 | 88.94 | +| VGG | `vgg` | `:vgg13_bn` | ✅ | 69.66 | 89.49 | +| VGG | `vgg` | `:vgg16_bn` | ✅ | 72.11 | 91.02 | +| VGG | `vgg` | `:vgg19_bn` | ✅ | 72.95 | 91.32 | +| Vision Transformer | `vision_transformer` | `:tiny` | 🚫 | | | +| Vision Transformer | `vision_transformer` | `:small` | 🚫 | | | +| Vision Transformer | `vision_transformer` | `:base` | 🚫 | | | +| Vision Transformer | `vision_transformer` | `:large` | 🚫 | | | +| Vision Transformer | `vision_transformer` | `:huge` | 🚫 | | | +| Vision Transformer | `vision_transformer` | `:giant` | 🚫 | | | +| Vision Transformer | `vision_transformer` | `:gigantic` | 🚫 | | | + + + + +## Building Blocks + +
+# Boltz.ClassTokensType. + + + +```julia +ClassTokens(dim; init=Lux.zeros32) +``` + +Appends class tokens to an input with embedding dimension `dim` for use in many vision transformer namels. + +
+
+
+# Boltz.MultiHeadAttentionType. + + + +```julia +MultiHeadAttention(in_planes::Int, number_heads::Int; qkv_bias::Bool=false, + attention_dropout_rate::T=0.0f0, + projection_dropout_rate::T=0.0f0) where {T} +``` + +Multi-head self-attention layer + +
+
+
+# Boltz.ViPosEmbeddingType. + + + +```julia +ViPosEmbedding(embedsize, npatches; + init = (rng, dims...) -> randn(rng, Float32, dims...)) +``` + +Positional embedding layer used by many vision transformer-like namels. + +
+
+
+# Boltz.transformer_encoderFunction. + + + +```julia +transformer_encoder(in_planes, depth, number_heads; mlp_ratio = 4.0f0, dropout = 0.0f0) +``` + +Transformer as used in the base ViT architecture. ([reference](https://arxiv.org/abs/2010.11929)). + +**Arguments** + + * `in_planes`: number of input channels + * `depth`: number of attention blocks + * `number_heads`: number of attention heads + * `mlp_ratio`: ratio of MLP layers to the number of input channels + * `dropout_rate`: dropout rate + +
+
+
+# Boltz.vggFunction. + + + +```julia +vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) +``` + +Create a VGG model ([reference](https://arxiv.org/abs/1409.1556v6)). + +**Arguments** + + * `imsize`: input image width and height as a tuple + * `config`: the configuration for the convolution layers + * `inchannels`: number of input channels + * `batchnorm`: set to `true` to use batch normalization after each convolution + * `nclasses`: number of output classes + * `fcsize`: intermediate fully connected layer size (see `Metalhead._vgg_classifier_layers`) + * `dropout`: dropout level between fully connected layers + +
+
+ + + +### Non-Public API + +
+# Boltz._seconddimmeanFunction. + + + +```julia +_seconddimmean(x) +``` + +Computes the mean of `x` along dimension `2` + +
+
+
+# Boltz._fast_chunkFunction. + + + +```julia +_fast_chunk(x::AbstractArray, ::Val{n}, ::Val{dim}) +``` + +Type-stable and faster version of `MLUtils.chunk` + +
+
+
+# Boltz._flatten_spatialFunction. + + + +```julia +_flatten_spatial(x::AbstractArray{T, 4}) +``` + +Flattens the first 2 dimensions of `x`, and permutes the remaining dimensions to (2, 1, 3) + +
+
+
+# Boltz._vgg_blockFunction. + + + +```julia +_vgg_block(input_filters, output_filters, depth, batchnorm) +``` + +A VGG block of convolution layers ([reference](https://arxiv.org/abs/1409.1556v6)). + +**Arguments** + + * `input_filters`: number of input feature maps + * `output_filters`: number of output feature maps + * `depth`: number of convolution/convolution + batch norm layers + * `batchnorm`: set to `true` to include batch normalization after each convolution + +
+
+
+# Boltz._vgg_classifier_layersFunction. + + + +```julia +_vgg_classifier_layers(imsize, nclasses, fcsize, dropout) +``` + +Create VGG classifier (fully connected) layers ([reference](https://arxiv.org/abs/1409.1556v6)). + +**Arguments** + + * `imsize`: tuple `(width, height, channels)` indicating the size after the convolution layers (see `Metalhead._vgg_convolutional_layers`) + * `nclasses`: number of output classes + * `fcsize`: input and output size of the intermediate fully connected layer + * `dropout`: the dropout level between each fully connected layer + +
+
+
+# Boltz._vgg_convolutional_layersFunction. + + + +```julia +_vgg_convolutional_layers(config, batchnorm, inchannels) +``` + +Create VGG convolution layers ([reference](https://arxiv.org/abs/1409.1556v6)). + +**Arguments** + + * `config`: vector of tuples `(output_channels, num_convolutions)` for each block (see `Metalhead._vgg_block`) + * `batchnorm`: set to `true` to include batch normalization after each convolution + * `inchannels`: number of input channels + +
+
+ + + +## Classification Models: Imported from Metalhead.jl + + +:::tip + + +You need to load `Flux` and `Metalhead` before using these models. + + +::: + + +| MODEL NAME | FUNCTION | NAME | PRETRAINED | TOP 1 ACCURACY (%) | TOP 5 ACCURACY (%) | +| ----------:| -----------:| ---------------------:|:----------:|:------------------:|:------------------:| +| AlexNet | `alexnet` | `:alexnet` | ✅ | 54.48 | 77.72 | +| ResNet | `resnet` | `:resnet18` | 🚫 | 68.08 | 88.44 | +| ResNet | `resnet` | `:resnet34` | 🚫 | 72.13 | 90.91 | +| ResNet | `resnet` | `:resnet50` | 🚫 | 74.55 | 92.36 | +| ResNet | `resnet` | `:resnet101` | 🚫 | 74.81 | 92.36 | +| ResNet | `resnet` | `:resnet152` | 🚫 | 77.63 | 93.84 | +| ConvMixer | `convmixer` | `:small` | 🚫 | | | +| ConvMixer | `convmixer` | `:base` | 🚫 | | | +| ConvMixer | `convmixer` | `:large` | 🚫 | | | +| DenseNet | `densenet` | `:densenet121` | 🚫 | | | +| DenseNet | `densenet` | `:densenet161` | 🚫 | | | +| DenseNet | `densenet` | `:densenet169` | 🚫 | | | +| DenseNet | `densenet` | `:densenet201` | 🚫 | | | +| GoogleNet | `googlenet` | `:googlenet` | 🚫 | | | +| MobileNet | `mobilenet` | `:mobilenet_v1` | 🚫 | | | +| MobileNet | `mobilenet` | `:mobilenet_v2` | 🚫 | | | +| MobileNet | `mobilenet` | `:mobilenet_v3_small` | 🚫 | | | +| MobileNet | `mobilenet` | `:mobilenet_v3_large` | 🚫 | | | +| ResNeXT | `resnext` | `:resnext50` | 🚫 | | | +| ResNeXT | `resnext` | `:resnext101` | 🚫 | | | +| ResNeXT | `resnext` | `:resnext152` | 🚫 | | | + + +These models can be created using `(; pretrained = )` + + + + +### Preprocessing + + +All the pretrained models require that the images be normalized with the parameters `mean = [0.485f0, 0.456f0, 0.406f0]` and `std = [0.229f0, 0.224f0, 0.225f0]`. + diff --git a/previews/PR474/api/Lux/contrib.md b/previews/PR474/api/Lux/contrib.md new file mode 100644 index 000000000..37027a59d --- /dev/null +++ b/previews/PR474/api/Lux/contrib.md @@ -0,0 +1,667 @@ + + + +# Experimental Features + + + + +All features listed on this page are **experimental** which means: + + +1. No SemVer Guarantees. We use code here to iterate fast and most users should wait for these features to be marked non-experimental. +2. The code will probably be moved into a separate repository in the future. +3. Expect edge-cases and report them. It will help us move these features out of experimental sooner. +4. None of the features are exported. + + +:::warning + + +Starting v"0.5.2" all Experimental features need to be accessed via `Lux.Experimental.`. Direct access via `Lux.` will be removed in v"0.6". + + +::: + + + + +## Index + +- [`Lux.Experimental.DebugLayer`](#Lux.Experimental.DebugLayer) +- [`Lux.Experimental.FrozenLayer`](#Lux.Experimental.FrozenLayer) +- [`Lux.Experimental.StatefulLuxLayer`](#Lux.Experimental.StatefulLuxLayer) +- [`Lux.Experimental.TrainState`](#Lux.Experimental.TrainState) +- [`Lux.Experimental.apply_gradients`](#Lux.Experimental.apply_gradients) +- [`Lux.Experimental.compute_gradients`](#Lux.Experimental.compute_gradients) +- [`Lux.Experimental.freeze`](#Lux.Experimental.freeze) +- [`Lux.Experimental.layer_map`](#Lux.Experimental.layer_map) +- [`Lux.Experimental.share_parameters`](#Lux.Experimental.share_parameters) +- [`Lux.Experimental.unfreeze`](#Lux.Experimental.unfreeze) +- [`Lux.Experimental.@compact`](#Lux.Experimental.@compact) +- [`Lux.Experimental.@debug_mode`](#Lux.Experimental.@debug_mode) +- [`Lux.Experimental.@layer_map`](#Lux.Experimental.@layer_map) + + + + +## Training + + +Helper Functions making it easier to train `Lux.jl` models. + + +Lux.Training is meant to be simple and provide extremely basic functionality. We provide basic building blocks which can be seamlessly composed to create complex training pipelines. + +
+# Lux.Experimental.TrainStateType. + + + +```julia +TrainState +``` + +Training State containing: + + * `model`: `Lux` model. + * `parameters`: Trainable Variables of the `model`. + * `states`: Non-trainable Variables of the `model`. + * `optimizer_state`: Optimizer State. + * `step`: Number of updates of the parameters made. + + +source
+ +
+
+
+# Lux.Experimental.compute_gradientsFunction. + + + +```julia +compute_gradients(ad::ADTypes.AbstractADType, objective_function::Function, data, + ts::TrainState) +``` + +Compute the gradients of the objective function wrt parameters stored in `ts`. + +**Arguments** + + * `ad`: Backend (from [ADTypes.jl](https://github.com/SciML/ADTypes.jl)) used to compute the gradients. + * `objective_function`: Objective function. The function must take 4 inputs – model, parameters, states and data. The function must return 3 values – loss, updated_state, and any computed statistics. + * `data`: Data used to compute the gradients. + * `ts`: Current Training State. See [`TrainState`](contrib#Lux.Experimental.TrainState). + +**Return** + +A 4-Tuple containing: + + * `grads`: Computed Gradients. + * `loss`: Loss from the objective function. + * `stats`: Any computed statistics from the objective function. + * `ts`: Updated Training State. + + +source
+ +
+
+
+# Lux.Experimental.apply_gradientsFunction. + + + +```julia +apply_gradients(ts::TrainState, grads) +``` + +Update the parameters stored in `ts` using the gradients `grads`. + +**Arguments** + + * `ts`: `TrainState` object. + * `grads`: Gradients of the loss function wrt `ts.params`. + +**Returns** + +Updated `TrainState` object. + + +source
+ +
+
+ + + +## Parameter Freezing + + +:::info + + +In the long term, this will be supported via [Optimisers.jl](https://github.com/FluxML/Optimisers.jl/pull/49). + + +::: + +
+# Lux.Experimental.FrozenLayerType. + + + +```julia +FrozenLayer(l::AbstractExplicitLayer, which_params::Union{Tuple, Nothing}) +``` + +Freeze the parameters with name `which_params` of the layer `l`. + +:::tip + +It is always recommended to use the [`Lux.Experimental.freeze`](contrib#Lux.Experimental.freeze) function instead of directly using the `FrozenLayer` constructor. + +::: + +:::warning + +There are no checks for `which_params`. For example, if the original layer has parameters named `(:weight, :bias)`, and `which_params`is set to`(:myweight,)` then none of the parameters are frozen and no error is thrown. + +::: + +**Arguments** + + * `l`: Lux AbstractExplicitLayer. + * `which_params`: Parameter Names to be Frozen. Can be set to `nothing`, in which case all parameters are frozen. + +**Input** + + * `x`: Input to the layer `l`. + +**Returns** + + * Output of the inner layer `l` + * Updated State + +**Parameters** + + * Parameters of the layer `l` excluding `which_params`. + +**States** + + * `frozen_params`: Parameters that are frozen, i.e., `which_params`. + * `states`: The state of the inner layer `l`. + +**Note on Internal Layer Implementation** + +The inner layer should work with `NamedTuple` parameters. In order to support custom parameter types, users need to implement `Lux._merge(::CustomParamType, ::NamedTuple)`. + +**Example** + +```julia +m = Lux.Experimental.FrozenLayer(Dense(2 => 2), (:weight,)) +``` + +See also [`Lux.Experimental.freeze`](contrib#Lux.Experimental.freeze), [`Lux.Experimental.unfreeze`](contrib#Lux.Experimental.unfreeze). + + +source
+ +
+
+
+# Lux.Experimental.freezeFunction. + + + +```julia +freeze(l::AbstractExplicitLayer, which_params::Union{Tuple, Nothing} = nothing) +``` + +Constructs a version of `l` with `which_params` frozen. If `which_params` is nothing, then all parameters are frozen. + + +source
+ + +``` +freeze(l::AbstractExplicitLayer, ps, st::NamedTuple, + which_params::Union{Tuple, Nothing} = nothing) +``` + +Construct a [`Lux.Experimental.FrozenLayer`](contrib#Lux.Experimental.FrozenLayer) for `l` with the current parameters and states. If `which_params` is nothing, then all parameters are frozen. + + +source
+ +
+
+
+# Lux.Experimental.unfreezeFunction. + + + +```julia +unfreeze(l::FrozenLayer) +``` + +Unfreezes the layer `l`. + + +source
+ + +``` +unfreeze(l::FrozenLayer, ps, st::NamedTuple) +``` + +Unwraps a [`Lux.Experimental.FrozenLayer`](contrib#Lux.Experimental.FrozenLayer) `l` with the current parameters and states. + + +source
+ +
+
+ +For detailed usage example look at the [manual page](../../manual/freezing_model_parameters). + + + + +## Map over Layer + +
+# Lux.Experimental.layer_mapFunction. + + + +```julia +layer_map(f::Function, l::AbstractExplicitLayer, ps, st::NamedTuple, + name::String="model") +``` + +Map the function `f` over the model `l`, with the parameters `ps` and states `st`. This is different from `Functors.fmap` since it zips the layers, parameters, and states and invokes the function on all of them together. + +**Call Signature for `f`** + + * Must take 4 inputs – `AbstractExplicitLayer`, Corresponding Parameters, Corresponding States, and the name of the layer. + * Must return a tuple of 3 elements – `AbstractExplicitLayer`, new parameters and the new states. + +:::tip + +We recommend using the macro `Lux.@layer_map` instead of this function. It automatically sets the `name` of the layer to be the variable name. + +::: + +**Example** + +```julia +using Lux, Random, Setfield + +c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), + dense_2=Dense(3 => 5)), + dense_3=Dense(5 => 1)) + +rng = Random.default_rng() +ps, st = Lux.setup(rng, c) + +# Makes parameters of Dense Layers inside Chain zero +function zero_dense_params(l, ps, st, name) + if l isa Dense + println("zeroing params of $name") + @set! ps.weight = zero.(ps.weight) + @set! ps.bias = zero.(ps.bias) + end + return l, ps, st +end + +Lux.layer_map(zero_dense_params, c, ps, st) +``` + + +source
+ +
+
+
+# Lux.Experimental.@layer_mapMacro. + + + +```julia +@layer_map func layer ps st +``` + +See the documentation of [`Lux.Experimental.layer_map`](contrib#Lux.Experimental.layer_map) for more details. This macro eliminates the need to the set the layer name, and uses the variable name as the starting point. + +**Example** + +```julia +using Lux, Random, Setfield + +c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), + dense_2=Dense(3 => 5)), + dense_3=Dense(5 => 1)) + +rng = Random.default_rng() +ps, st = Lux.setup(rng, c) + +# Makes parameters of Dense Layers inside Chain zero +function zero_dense_params(l, ps, st, name) + if l isa Dense + println("zeroing params of $name") + @set! ps.weight = zero.(ps.weight) + @set! ps.bias = zero.(ps.bias) + end + return l, ps, st +end + +Lux.@layer_map zero_dense_params c ps st +``` + + +source
+ +
+
+ + + +## Debugging Functionality + + +Model not working properly! Here are some functionalities to help you debug you Lux model. + +
+# Lux.Experimental.@debug_modeMacro. + + + +```julia +@debug_mode layer kwargs... +``` + +Recurses into the `layer` and replaces the inner most non Container Layers with a [`Lux.Experimental.DebugLayer`](contrib#Lux.Experimental.DebugLayer). + +See [`Lux.Experimental.DebugLayer`](contrib#Lux.Experimental.DebugLayer) for details about the Keyword Arguments. + + +source
+ +
+
+
+# Lux.Experimental.DebugLayerType. + + + +```julia +DebugLayer(layer::AbstractExplicitLayer; nan_check::Symbol=:both, + error_check::Bool=true, location::String="") +``` + +::: danger + +This layer is only meant to be used for debugging. If used for actual training or inference, will lead to extremely bad performance. + +::: + +A wrapper over Lux layers that adds checks for NaNs and errors. This is useful for debugging. + +**Arguments** + + * `layer`: The layer to be wrapped. + +**Keyword Arguments** + + * `nan_check`: Whether to check for NaNs in the input, parameters, and states. Can be `:both`, `:forward`, `:backward`, or `:none`. + * `error_check`: Whether to check for errors in the layer. If `true`, will throw an error if the layer fails. + * `location`: The location of the layer. Use [`Lux.Experimental.@debug_mode`](contrib#Lux.Experimental.@debug_mode) to construct this layer to populate this value correctly. + +**Inputs** + + * `x`: The input to the layer. + +**Outputs** + + * `y`: The output of the layer. + * `st`: The updated states of the layer. + +If `nan_check` is enabled and NaNs are detected then a `DomainError` is thrown. If `error_check` is enabled, then any errors in the layer are thrown with useful information to track where the error originates. + +::: warning + +`nan_check` for the backward mode only works with ChainRules Compatible Reverse Mode AD Tools currently. + +::: + +See [`Lux.Experimental.@debug_mode`](contrib#Lux.Experimental.@debug_mode) to construct this layer. + + +source
+ +
+
+ + + +## Tied Parameters + +
+# Lux.Experimental.share_parametersFunction. + + + +```julia +share_parameters(ps, sharing) +share_parameters(ps, sharing, new_parameters) +``` + +Updates the parameters in `ps` with a common set of parameters `new_parameters` that are shared between each list in the nested list `sharing`. (That was kind of a mouthful, the example should make it clear). + +**Arguments** + + * `ps`: Original parameters. + * `sharing`: A nested list of lists of accessors of `ps` which need to shate the parameters (See the example for details). (Each list in the list must be disjoint) + * `new_parameters`: If passed the length of `new_parameters` must be equal to the length of `sharing`. For each vector in `sharing` the corresponding parameter in `new_parameters` will be used. (If not passed, the parameters corresponding to the first element of each vector in `sharing` will be used). + +**Returns** + +Updated Parameters having the same structure as `ps`. + +**Example** + +```julia +model = Chain(; + d1=Dense(2 => 4, tanh), + d3=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)), + d2=Dense(4 => 2)) + +ps, st = Lux.setup(Xoshiro(0), model) + +# share parameters of (d1 and d3.l1) and (d3.l2 and d2) +ps = Lux.share_parameters(ps, (("d3.l2", "d1"), ("d2", "d3.l1"))) +``` + + +source
+ +
+
+ + + +## Stateful Layer + +
+# Lux.Experimental.StatefulLuxLayerType. + + + +```julia +StatefulLuxLayer(model, ps, st) +``` + +::: warning + +This is not a Lux.AbstractExplicitLayer + +::: + +A convenience wrapper over Lux layers which stores the parameters and states internally. Most users should not be using this version. This comes handy when Lux internally uses the `@compact` to construct models and in SciML codebases where propagating state might involving [`Box`ing](https://github.com/JuliaLang/julia/issues/15276). + +For a motivating example, see the Neural ODE tutorial. + +::: warning + +State is mutated in place. An additional caveat is that the updated state from `model` must have the same type as `st`. + +::: + +**Arguments** + + * `model`: A Lux layer + * `ps`: The parameters of the layer. This can be set to `nothing`, if the user provides the parameters on function call + * `st`: The state of the layer + +**Inputs** + + * `x`: The input to the layer + * `ps`: The parameters of the layer. Optional, defaults to `s.ps` + +**Outputs** + + * `y`: The output of the layer + + +source
+ +
+
+ + + +## Compact Layer API + +
+# Lux.Experimental.@compactMacro. + + + +```julia +@compact(kw...) do x + ... +end +@compact(forward::Function; name=nothing, dispatch=nothing, parameters...) +``` + +Creates a layer by specifying some `parameters`, in the form of keywords, and (usually as a `do` block) a function for the forward pass. You may think of `@compact` as a specialized `let` block creating local variables that are trainable in Lux. Declared variable names may be used within the body of the `forward` function. Note that unlike typical Lux models, the forward function doesn't need to explicitly manage states. + +**Reserved Kwargs:** + +1. `name`: The name of the layer. +2. `dispatch`: The constructed layer has the type `Lux.Experimental.CompactLuxLayer{dispatch}` which can be used for custom dispatches. + +**Examples** + +Here is a linear model: + +```julia +using Lux, Random +import Lux.Experimental: @compact + +r = @compact(w=rand(3)) do x + return w .* x +end +ps, st = Lux.setup(Xoshiro(0), r) +r([1, 1, 1], ps, st) # x is set to [1, 1, 1]. +``` + +Here is a linear model with bias and activation: + +```julia +d_in = 5 +d_out = 7 +d = @compact(W=randn(d_out, d_in), b=zeros(d_out), act=relu) do x + y = W * x + return act.(y .+ b) +end +ps, st = Lux.setup(Xoshiro(0), d) +d(ones(5, 10), ps, st) # 7×10 Matrix as output. + +ps_dense = (; weight=ps.W, bias=ps.b) +first(d([1, 2, 3, 4, 5], ps, st)) ≈ +first(Dense(d_in => d_out, relu)([1, 2, 3, 4, 5], ps_dense, NamedTuple())) # Equivalent to a dense layer +``` + +Finally, here is a simple MLP: + +```julia +n_in = 1 +n_out = 1 +nlayers = 3 + +model = @compact(w1=Dense(n_in, 128), + w2=[Dense(128, 128) for i in 1:nlayers], + w3=Dense(128, n_out), + act=relu) do x + embed = act(w1(x)) + for w in w2 + embed = act(w(embed)) + end + out = w3(embed) + return out +end + +ps, st = Lux.setup(Xoshiro(0), model) + +model(randn(n_in, 32), ps, st) # 1×32 Matrix as output. +``` + +We can train this model just like any Lux model: + +```julia +using Optimisers, Zygote + +x_data = collect(-2.0f0:0.1f0:2.0f0)' +y_data = 2 .* x_data .- x_data .^ 3 +optim = Optimisers.setup(Adam(), ps) + +for epoch in 1:1000 + loss, gs = Zygote.withgradient(ps -> sum(abs2, first(model(x_data, ps, st)) .- y_data), + ps) + @show epoch, loss + Optimisers.update!(optim, ps, gs[1]) +end +``` + +You may also specify a `name` for the model, which will be used instead of the default printout, which gives a verbatim representation of the code used to construct the model: + +```julia +model = @compact(w=rand(3), name="Linear(3 => 1)") do x + return sum(w .* x) +end + +println(model) # "Linear(3 => 1)()" +``` + +This can be useful when using `@compact` to hierarchically construct complex models to be used inside a `Chain`. + +:::tip Type Stability + +If your input function `f` is type-stable but the generated model is not type stable, it should be treated as a bug. We will appreciate issues if you find such cases. + +::: + +:::warning Parameter Count + +Array Parameter don't print the number of parameters on the side. However, they do account for the total number of parameters printed at the bottom. + +::: + + +source
+ +
+
diff --git a/previews/PR474/api/Lux/flux_to_lux.md b/previews/PR474/api/Lux/flux_to_lux.md new file mode 100644 index 000000000..abe9b7cc4 --- /dev/null +++ b/previews/PR474/api/Lux/flux_to_lux.md @@ -0,0 +1,111 @@ + + + +# Flux Models to Lux Models + + + + +Accessing these functions require manually loading `Flux`, i.e., `using Flux` must be present somewhere in the code for these to be used. + + + + +## Index + +- [`Lux.FluxLayer`](#Lux.FluxLayer) +- [`Lux.transform`](#Lux.transform) + + + + +## Functions + +
+# Lux.transformFunction. + + + +```julia +transform(l; preserve_ps_st::Bool=false, force_preserve::Bool=false) +``` + +Convert a Flux Model to Lux Model. + +:::warning + +`transform` always ingores the `active` field of some of the Flux layers. This is almost never going to be supported. + +::: + +**Arguments** + + * `l`: Flux l or any generic Julia function / object. + +**Keyword Arguments** + + * `preserve_ps_st`: Set to `true` to preserve the states and parameters of the l. This attempts the best possible way to preserve the original model. But it might fail. If you need to override possible failures, set `force_preserve` to `true`. + * `force_preserve`: Some of the transformations with state and parameters preservation haven't been implemented yet, in these cases, if `force_transform` is `false` a warning will be printed and a core Lux layer will be returned. Else, it will create a [`FluxLayer`](flux_to_lux#Lux.FluxLayer). + +**Examples** + +```julia +import Flux +using Lux, Metalhead, Random + +m = ResNet(18) +m2 = transform(m.layers) + +x = randn(Float32, 224, 224, 3, 1); + +ps, st = Lux.setup(Random.default_rng(), m2); + +m2(x, ps, st) +``` + + +source
+ +
+
+ + + +## Layers + +
+# Lux.FluxLayerType. + + + +```julia +FluxLayer(layer) +``` + +Serves as a compatibility layer between Flux and Lux. This uses `Optimisers.destructure` API internally. + +:::warning + +Lux was written to overcome the limitations of `destructure` + `Flux`. It is recommended to rewrite your l in Lux instead of using this layer. + +::: + +:::warning + +Introducing this Layer in your model will lead to type instabilities, given the way `Optimisers.destructure` works. + +::: + +**Arguments** + + * `layer`: Flux layer + +**Parameters** + + * `p`: Flattened parameters of the `layer` + + +source
+ +
+
diff --git a/previews/PR474/api/Lux/layers.md b/previews/PR474/api/Lux/layers.md new file mode 100644 index 000000000..cf0859073 --- /dev/null +++ b/previews/PR474/api/Lux/layers.md @@ -0,0 +1,2053 @@ + + + +# Built-In Layers + + + + + + +## Index + +- [`Lux.AdaptiveMaxPool`](#Lux.AdaptiveMaxPool) +- [`Lux.AdaptiveMeanPool`](#Lux.AdaptiveMeanPool) +- [`Lux.AlphaDropout`](#Lux.AlphaDropout) +- [`Lux.BatchNorm`](#Lux.BatchNorm) +- [`Lux.Bilinear`](#Lux.Bilinear) +- [`Lux.BranchLayer`](#Lux.BranchLayer) +- [`Lux.Chain`](#Lux.Chain) +- [`Lux.Conv`](#Lux.Conv) +- [`Lux.ConvTranspose`](#Lux.ConvTranspose) +- [`Lux.CrossCor`](#Lux.CrossCor) +- [`Lux.Dense`](#Lux.Dense) +- [`Lux.Dropout`](#Lux.Dropout) +- [`Lux.Embedding`](#Lux.Embedding) +- [`Lux.FlattenLayer`](#Lux.FlattenLayer) +- [`Lux.GRUCell`](#Lux.GRUCell) +- [`Lux.GlobalMaxPool`](#Lux.GlobalMaxPool) +- [`Lux.GlobalMeanPool`](#Lux.GlobalMeanPool) +- [`Lux.GroupNorm`](#Lux.GroupNorm) +- [`Lux.InstanceNorm`](#Lux.InstanceNorm) +- [`Lux.LSTMCell`](#Lux.LSTMCell) +- [`Lux.LayerNorm`](#Lux.LayerNorm) +- [`Lux.MaxPool`](#Lux.MaxPool) +- [`Lux.Maxout`](#Lux.Maxout) +- [`Lux.MeanPool`](#Lux.MeanPool) +- [`Lux.NoOpLayer`](#Lux.NoOpLayer) +- [`Lux.PairwiseFusion`](#Lux.PairwiseFusion) +- [`Lux.Parallel`](#Lux.Parallel) +- [`Lux.RNNCell`](#Lux.RNNCell) +- [`Lux.Recurrence`](#Lux.Recurrence) +- [`Lux.RepeatedLayer`](#Lux.RepeatedLayer) +- [`Lux.ReshapeLayer`](#Lux.ReshapeLayer) +- [`Lux.Scale`](#Lux.Scale) +- [`Lux.SelectDim`](#Lux.SelectDim) +- [`Lux.SkipConnection`](#Lux.SkipConnection) +- [`Lux.StatefulRecurrentCell`](#Lux.StatefulRecurrentCell) +- [`Lux.Upsample`](#Lux.Upsample) +- [`Lux.VariationalHiddenDropout`](#Lux.VariationalHiddenDropout) +- [`Lux.WeightNorm`](#Lux.WeightNorm) +- [`Lux.WrappedFunction`](#Lux.WrappedFunction) +- [`Lux.PixelShuffle`](#Lux.PixelShuffle) + + + + +## Containers + +
+# Lux.BranchLayerType. + + + +```julia +BranchLayer(layers...) +BranchLayer(; name=nothing, layers...) +``` + +Takes an input `x` and passes it through all the `layers` and returns a tuple of the outputs. + +**Arguments** + + * Layers can be specified in two formats: + + * A list of `N` Lux layers + * Specified as `N` keyword arguments. + +**Keyword Arguments** + + * `name`: Name of the layer (optional) + +**Inputs** + + * `x`: Will be directly passed to each of the `layers` + +**Returns** + + * Tuple: `(layer_1(x), layer_2(x), ..., layer_N(x))` (naming changes if using the kwargs API) + * Updated state of the `layers` + +**Parameters** + + * Parameters of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +**States** + + * States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +:::tip Comparison with Parallel + +This is slightly different from [`Parallel(nothing, layers...)`](layers#Lux.Parallel) + + * If the input is a tuple, `Parallel` will pass each element individually to each layer. + * `BranchLayer` essentially assumes 1 input comes in and is branched out into `N` outputs. + +::: + +**Example** + +An easy way to replicate an input to an NTuple is to do + +```julia +l = BranchLayer(NoOpLayer(), NoOpLayer(), NoOpLayer()) +``` + + +source
+ +
+
+
+# Lux.ChainType. + + + +```julia +Chain(layers...; name=nothing, disable_optimizations::Bool = false) +Chain(; layers..., name=nothing, disable_optimizations::Bool = false) +``` + +Collects multiple layers / functions to be called in sequence on a given input. + +**Arguments** + + * Layers can be specified in two formats: + + * A list of `N` Lux layers + * Specified as `N` keyword arguments. + +**Keyword Arguments** + + * `disable_optimizations`: Prevents any structural optimization + * `name`: Name of the layer (optional) + +**Inputs** + +Input `x` is passed sequentially to each layer, and must conform to the input requirements of the internal layers. + +**Returns** + + * Output after sequentially applying all the layers to `x` + * Updated model states + +**Parameters** + + * Parameters of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +**States** + + * States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +**Optimizations** + +Performs a few optimizations to generate reasonable architectures. Can be disabled using keyword argument `disable_optimizations`. + + * All sublayers are recursively optimized. + * If a function `f` is passed as a layer and it doesn't take 3 inputs, it is converted to a [`WrappedFunction`](layers#Lux.WrappedFunction)(`f`) which takes only one input. + * If the layer is a Chain, it is flattened. + * [`NoOpLayer`](layers#Lux.NoOpLayer)s are removed. + * If there is only 1 layer (left after optimizations), then it is returned without the `Chain` wrapper. + * If there are no layers (left after optimizations), a [`NoOpLayer`](layers#Lux.NoOpLayer) is returned. + +**Miscellaneous Properties** + + * Allows indexing. We can access the `i`th layer using `m[i]`. We can also index using ranges or arrays. + +**Example** + +```julia +c = Chain(Dense(2, 3, relu), BatchNorm(3), Dense(3, 2)) +``` + + +source
+ +
+
+
+# Lux.PairwiseFusionType. + + + +```julia +PairwiseFusion(connection, layers...; name=nothing) +PairwiseFusion(connection; name=nothing, layers...) +``` + +``` +x1 → layer1 → y1 ↘ + connection → layer2 → y2 ↘ + x2 ↗ connection → y3 + x3 ↗ +``` + +**Arguments** + + * `connection`: Takes 2 inputs and combines them + * `layers`: `AbstractExplicitLayer`s. Layers can be specified in two formats: + + * A list of `N` Lux layers + * Specified as `N` keyword arguments. + +**Keyword Arguments** + + * `name`: Name of the layer (optional) + +**Inputs** + +Layer behaves differently based on input type: + +1. If the input `x` is a tuple of length `N + 1`, then the `layers` must be a tuple of length `N`. The computation is as follows + +```julia +y = x[1] +for i in 1:N + y = connection(x[i + 1], layers[i](y)) +end +``` + +2. Any other kind of input + +```julia +y = x +for i in 1:N + y = connection(x, layers[i](y)) +end +``` + +**Returns** + + * See Inputs section for how the return value is computed + * Updated model state for all the contained layers + +**Parameters** + + * Parameters of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +**States** + + * States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + + +source
+ +
+
+
+# Lux.ParallelType. + + + +```julia +Parallel(connection, layers...; name=nothing) +Parallel(connection; name=nothing, layers...) +``` + +Create a layer which passes an input to each path in `layers`, before reducing the output with `connection`. + +**Arguments** + + * `connection`: An `N`-argument function that is called after passing the input through each layer. If `connection = nothing`, we return a tuple `Parallel(nothing, f, g)(x, y) = (f(x), g(y))` + * Layers can be specified in two formats: + + * A list of `N` Lux layers + * Specified as `N` keyword arguments. + +**Keyword Arguments** + + * `name`: Name of the layer (optional) + +**Inputs** + + * `x`: If `x` is not a tuple, then return is computed as `connection([l(x) for l in layers]...)`. Else one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`. + +**Returns** + + * See the Inputs section for how the output is computed + * Updated state of the `layers` + +**Parameters** + + * Parameters of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +**States** + + * States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +See also [`SkipConnection`](layers#Lux.SkipConnection) which is `Parallel` with one identity. + + +source
+ +
+
+
+# Lux.SkipConnectionType. + + + +```julia +SkipConnection(layer, connection; name=nothing) +``` + +Create a skip connection which consists of a layer or [`Chain`](layers#Lux.Chain) of consecutive layers and a shortcut connection linking the block's input to the output through a user-supplied 2-argument callable. The first argument to the callable will be propagated through the given `layer` while the second is the unchanged, "skipped" input. + +The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`. + +**Arguments** + + * `layer`: Layer or `Chain` of layers to be applied to the input + * `connection`: + + * A 2-argument function that takes `layer(input)` and the input OR + * An AbstractExplicitLayer that takes `(layer(input), input)` as input + +**Keyword Arguments** + + * `name`: Name of the layer (optional) + +**Inputs** + + * `x`: Will be passed directly to `layer` + +**Returns** + + * Output of `connection(layer(input), input)` + * Updated state of `layer` + +**Parameters** + + * Parameters of `layer` OR + * If `connection` is an AbstractExplicitLayer, then NamedTuple with fields `:layers` and `:connection` + +**States** + + * States of `layer` OR + * If `connection` is an AbstractExplicitLayer, then NamedTuple with fields `:layers` and `:connection` + +See [`Parallel`](layers#Lux.Parallel) for a more general implementation. + + +source
+ +
+
+
+# Lux.RepeatedLayerType. + + + +```julia +RepeatedLayer(model; repeats::Val = Val(10), input_injection::Val = Val(false)) +``` + +Iteratively applies `model` for `repeats` number of times. The initial input is passed into the model repeatedly if `input_injection = Val(true)`. This layer unrolls the computation, however, semantically this is same as: + +1. `input_injection = Val(false)` + +```julia +res = x +for i in 1:repeats + res, st = model(res, ps, st) +end +``` + +2. `input_injection = Val(true)` + +```julia +res = x +for i in 1:repeats + res, st = model((res, x), ps, st) +end +``` + +It is expected that `repeats` will be a reasonable number below `20`, beyond that compile times for gradients might be unreasonably high. + +**Arguments** + + * `model` must be an `AbstractExplicitLayer` + +**Keyword Arguments** + + * `repeats`: Number of times to apply the model + * `input_injection`: If `true`, then the input is passed to the model along with the output + +**Inputs** + + * `x`: Input as described above + +**Returns** + + * Output is computed by as described above + * Updated state of the `model` + +**Parameters** + + * Parameters of `model` + +**States** + + * State of `model` + + +source
+ +
+
+ + + +## Convolutional Layers + +
+# Lux.ConvType. + + + +```julia +Conv(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, + activation=identity; init_weight=glorot_uniform, init_bias=zeros32, stride=1, + pad=0, dilation=1, groups=1, use_bias=true) +``` + +Standard convolutional layer. + +Image data should be stored in WHCN order (width, height, channels, batch). In other words, a `100 x 100` RGB image would be a `100 x 100 x 3 x 1` array, and a batch of 50 would be a `100 x 100 x 3 x 50` array. This has `N = 2` spatial dimensions, and needs a kernel size like `(5, 5)`, a 2-tuple of integers. To take convolutions along `N` feature dimensions, this layer expects as input an array with `ndims(x) == N + 2`, where `size(x, N + 1) == in_chs` is the number of input channels, and `size(x, ndims(x))` is the number of observations in a batch. + +:::warning + +Frameworks like [`Pytorch`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) perform cross-correlation in their convolution layers + +::: + +**Arguments** + + * `k`: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D convolutions `length(k) == 2` + * `in_chs`: Number of input channels + * `out_chs`: Number of input and output channels + * `activation`: Activation Function + +**Keyword Arguments** + + * `init_weight`: Controls the initialization of the weight parameter + * `init_bias`: Controls the initialization of the bias parameter + * `stride`: Should each be either single integer, or a tuple with `N` integers + * `dilation`: Should each be either single integer, or a tuple with `N` integers + * `pad`: Specifies the number of elements added to the borders of the data array. It can be + + * a single integer for equal padding all around, + * a tuple of `N` integers, to apply the same padding at begin/end of each spatial dimension, + * a tuple of `2*N` integers, for asymmetric padding, or + * the singleton `SamePad()`, to calculate padding such that `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial dimension. + * Periodic padding can achieved by pre-empting the layer with a `WrappedFunction(x -> NNlib.circular_pad(x, N_pad; dims=pad_dims))` + * `groups`: Expected to be an `Int`. It specifies the number of groups to divide a convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` and `out_chs` must be divisible by `groups`. + * `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. + * `allow_fast_activation`: If `true`, then certain activations can be approximated with a faster version. The new activation function will be given by `NNlib.fast_act(activation)` + +**Inputs** + + * `x`: Data satisfying `ndims(x) == N + 2 && size(x, N - 1) == in_chs`, i.e. `size(x) = (I_N, ..., I_1, C_in, N)` + +**Returns** + + * Output of the convolution `y` of size `(O_N, ..., O_1, C_out, N)` where + +$$ +O_i = floor\left(\frac{I_i + pad[i] + pad[(i + N) \% length(pad)] - dilation[i] \times (k[i] - 1)}{stride[i]} + 1\right) +$$ + + * Empty `NamedTuple()` + +**Parameters** + + * `weight`: Convolution kernel + * `bias`: Bias (present if `use_bias=true`) + + +source
+ +
+
+
+# Lux.ConvTransposeType. + + + +```julia +ConvTranspose(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, + activation=identity; init_weight=glorot_uniform, init_bias=zeros32, + stride=1, pad=0, dilation=1, groups=1, use_bias=true) +``` + +Standard convolutional transpose layer. + +**Arguments** + + * `k`: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D convolutions `length(k) == 2` + * `in_chs`: Number of input channels + * `out_chs`: Number of input and output channels + * `activation`: Activation Function + +**Keyword Arguments** + + * `init_weight`: Controls the initialization of the weight parameter + * `init_bias`: Controls the initialization of the bias parameter + * `stride`: Should each be either single integer, or a tuple with `N` integers + * `dilation`: Should each be either single integer, or a tuple with `N` integers + * `pad`: Specifies the number of elements added to the borders of the data array. It can be + + * a single integer for equal padding all around, + * a tuple of `N` integers, to apply the same padding at begin/end of each spatial dimension, + * a tuple of `2*N` integers, for asymmetric padding, or + * the singleton `SamePad()`, to calculate padding such that `size(output,d) == size(x,d) * stride` (possibly rounded) for each spatial dimension. + * `groups`: Expected to be an `Int`. It specifies the number of groups to divide a convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` and `out_chs` must be divisible by `groups`. + * `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. + * `allow_fast_activation`: If `true`, then certain activations can be approximated with a faster version. The new activation function will be given by `NNlib.fast_act(activation)` + +**Inputs** + + * `x`: Data satisfying `ndims(x) == N + 2 && size(x, N - 1) == in_chs`, i.e. `size(x) = (I_N, ..., I_1, C_in, N)` + +**Returns** + + * Output of the convolution transpose `y` of size `(O_N, ..., O_1, C_out, N)` where + * Empty `NamedTuple()` + +**Parameters** + + * `weight`: Convolution Transpose kernel + * `bias`: Bias (present if `use_bias=true`) + + +source
+ +
+
+
+# Lux.CrossCorType. + + + +```julia +CrossCor(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, + activation=identity; init_weight=glorot_uniform, init_bias=zeros32, stride=1, + pad=0, dilation=1, use_bias=true) +``` + +Cross Correlation layer. + +Image data should be stored in WHCN order (width, height, channels, batch). In other words, a `100 x 100` RGB image would be a `100 x 100 x 3 x 1` array, and a batch of 50 would be a `100 x 100 x 3 x 50` array. This has `N = 2` spatial dimensions, and needs a kernel size like `(5, 5)`, a 2-tuple of integers. To take convolutions along `N` feature dimensions, this layer expects as input an array with `ndims(x) == N + 2`, where `size(x, N + 1) == in_chs` is the number of input channels, and `size(x, ndims(x))` is the number of observations in a batch. + +**Arguments** + + * `k`: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D convolutions `length(k) == 2` + * `in_chs`: Number of input channels + * `out_chs`: Number of input and output channels + * `activation`: Activation Function + +**Keyword Arguments** + + * `init_weight`: Controls the initialization of the weight parameter + * `init_bias`: Controls the initialization of the bias parameter + * `stride`: Should each be either single integer, or a tuple with `N` integers + * `dilation`: Should each be either single integer, or a tuple with `N` integers + * `pad`: Specifies the number of elements added to the borders of the data array. It can be + + * a single integer for equal padding all around, + * a tuple of `N` integers, to apply the same padding at begin/end of each spatial dimension, + * a tuple of `2*N` integers, for asymmetric padding, or + * the singleton `SamePad()`, to calculate padding such that `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial dimension. + * `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. + * `allow_fast_activation`: If `true`, then certain activations can be approximated with a faster version. The new activation function will be given by `NNlib.fast_act(activation)` + +**Inputs** + + * `x`: Data satisfying `ndims(x) == N + 2 && size(x, N - 1) == in_chs`, i.e. `size(x) = (I_N, ..., I_1, C_in, N)` + +**Returns** + + * Output of the convolution `y` of size `(O_N, ..., O_1, C_out, N)` where + +$$ +O_i = floor\left(\frac{I_i + pad[i] + pad[(i + N) \% length(pad)] - dilation[i] \times (k[i] - 1)}{stride[i]} + 1\right) +$$ + + * Empty `NamedTuple()` + +**Parameters** + + * `weight`: Convolution kernel + * `bias`: Bias (present if `use_bias=true`) + + +source
+ +
+
+ + + +## Dropout Layers + +
+# Lux.AlphaDropoutType. + + + +```julia +AlphaDropout(p::Real) +``` + +AlphaDropout layer. + +**Arguments** + + * `p`: Probability of Dropout + + * if `p = 0` then [`NoOpLayer`](layers#Lux.NoOpLayer) is returned. + * if `p = 1` then `WrappedLayer(Base.Fix1(broadcast, zero))` is returned. + +**Inputs** + + * `x`: Must be an AbstractArray + +**Returns** + + * `x` with dropout mask applied if `training=Val(true)` else just `x` + * State with updated `rng` + +**States** + + * `rng`: Pseudo Random Number Generator + * `training`: Used to check if training/inference mode + +Call [`Lux.testmode`](../Building_Blocks/LuxCore#LuxCore.testmode) to switch to test mode. + +See also [`Dropout`](layers#Lux.Dropout), [`VariationalHiddenDropout`](layers#Lux.VariationalHiddenDropout) + + +source
+ +
+
+
+# Lux.DropoutType. + + + +```julia +Dropout(p; dims=:) +``` + +Dropout layer. + +**Arguments** + + * `p`: Probability of Dropout (if `p = 0` then [`NoOpLayer`](layers#Lux.NoOpLayer) is returned) + +**Keyword Arguments** + + * To apply dropout along certain dimension(s), specify the `dims` keyword. e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input (also called 2D dropout). + +**Inputs** + + * `x`: Must be an AbstractArray + +**Returns** + + * `x` with dropout mask applied if `training=Val(true)` else just `x` + * State with updated `rng` + +**States** + + * `rng`: Pseudo Random Number Generator + * `training`: Used to check if training/inference mode + +Call [`Lux.testmode`](../Building_Blocks/LuxCore#LuxCore.testmode) to switch to test mode. + +See also [`AlphaDropout`](layers#Lux.AlphaDropout), [`VariationalHiddenDropout`](layers#Lux.VariationalHiddenDropout) + + +source
+ +
+
+
+# Lux.VariationalHiddenDropoutType. + + + +```julia +VariationalHiddenDropout(p; dims=:) +``` + +VariationalHiddenDropout layer. The only difference from Dropout is that the `mask` is retained until [`Lux.update_state(l, :update_mask, Val(true))`](../Building_Blocks/LuxCore#LuxCore.update_state) is called. + +**Arguments** + + * `p`: Probability of Dropout (if `p = 0` then [`NoOpLayer`](layers#Lux.NoOpLayer) is returned) + +**Keyword Arguments** + + * To apply dropout along certain dimension(s), specify the `dims` keyword. e.g. `VariationalHiddenDropout(p; dims = 3)` will randomly zero out entire channels on WHCN input (also called 2D dropout). + +**Inputs** + + * `x`: Must be an AbstractArray + +**Returns** + + * `x` with dropout mask applied if `training=Val(true)` else just `x` + * State with updated `rng` + +**States** + + * `rng`: Pseudo Random Number Generator + * `training`: Used to check if training/inference mode + * `mask`: Dropout mask. Initilly set to nothing. After every run, contains the mask applied in that call + * `update_mask`: Stores whether new mask needs to be generated in the current call + +Call [`Lux.testmode`](../Building_Blocks/LuxCore#LuxCore.testmode) to switch to test mode. + +See also [`AlphaDropout`](layers#Lux.AlphaDropout), [`Dropout`](layers#Lux.Dropout) + + +source
+ +
+
+ + + +## Pooling Layers + +
+# Lux.AdaptiveMaxPoolType. + + + +```julia +AdaptiveMaxPool(out::NTuple) +``` + +Adaptive Max Pooling layer. Calculates the necessary window size such that its output has `size(y)[1:N] == out`. + +**Arguments** + + * `out`: Size of the first `N` dimensions for the output + +**Inputs** + + * `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch dimensions, after the `N` feature dimensions, where `N = length(out)`. + +**Returns** + + * Output of size `(out..., C, N)` + * Empty `NamedTuple()` + +See also [`MaxPool`](layers#Lux.MaxPool), [`AdaptiveMeanPool`](layers#Lux.AdaptiveMeanPool). + + +source
+ +
+
+
+# Lux.AdaptiveMeanPoolType. + + + +```julia +AdaptiveMeanPool(out::NTuple) +``` + +Adaptive Mean Pooling layer. Calculates the necessary window size such that its output has `size(y)[1:N] == out`. + +**Arguments** + + * `out`: Size of the first `N` dimensions for the output + +**Inputs** + + * `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch dimensions, after the `N` feature dimensions, where `N = length(out)`. + +**Returns** + + * Output of size `(out..., C, N)` + * Empty `NamedTuple()` + +See also [`MeanPool`](layers#Lux.MeanPool), [`AdaptiveMaxPool`](layers#Lux.AdaptiveMaxPool). + + +source
+ +
+
+
+# Lux.GlobalMaxPoolType. + + + +```julia +GlobalMaxPool() +``` + +Global Max Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output, by performing max pooling on the complete (w,h)-shaped feature maps. + +**Inputs** + + * `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + +**Returns** + + * Output of the pooling `y` of size `(1, ..., 1, C, N)` + * Empty `NamedTuple()` + +See also [`MaxPool`](layers#Lux.MaxPool), [`AdaptiveMaxPool`](layers#Lux.AdaptiveMaxPool), [`GlobalMeanPool`](layers#Lux.GlobalMeanPool) + + +source
+ +
+
+
+# Lux.GlobalMeanPoolType. + + + +```julia +GlobalMeanPool() +``` + +Global Mean Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output, by performing mean pooling on the complete (w,h)-shaped feature maps. + +**Inputs** + + * `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + +**Returns** + + * Output of the pooling `y` of size `(1, ..., 1, C, N)` + * Empty `NamedTuple()` + +See also [`MeanPool`](layers#Lux.MeanPool), [`AdaptiveMeanPool`](layers#Lux.AdaptiveMeanPool), [`GlobalMaxPool`](layers#Lux.GlobalMaxPool) + + +source
+ +
+
+
+# Lux.MaxPoolType. + + + +```julia +MaxPool(window::NTuple; pad=0, stride=window) +``` + +Max pooling layer, which replaces all pixels in a block of size `window` with the maximum value. + +**Arguments** + + * `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling `length(window) == 2` + +**Keyword Arguments** + + * `stride`: Should each be either single integer, or a tuple with `N` integers + * `pad`: Specifies the number of elements added to the borders of the data array. It can be + + * a single integer for equal padding all around, + * a tuple of `N` integers, to apply the same padding at begin/end of each spatial dimension, + * a tuple of `2*N` integers, for asymmetric padding, or + * the singleton `SamePad()`, to calculate padding such that `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial dimension. + +**Inputs** + + * `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + +**Returns** + + * Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where + +$$ + O_i = floor\left(\frac{I_i + pad[i] + pad[(i + N) \% length(pad)] - dilation[i] \times (k[i] - 1)}{stride[i]} + 1\right) +$$ + + * Empty `NamedTuple()` + +See also [`Conv`](layers#Lux.Conv), [`MeanPool`](layers#Lux.MeanPool), [`GlobalMaxPool`](layers#Lux.GlobalMaxPool), [`AdaptiveMaxPool`](layers#Lux.AdaptiveMaxPool) + + +source
+ +
+
+
+# Lux.MeanPoolType. + + + +```julia +MeanPool(window::NTuple; pad=0, stride=window) +``` + +Mean pooling layer, which replaces all pixels in a block of size `window` with the mean value. + +**Arguments** + + * `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling `length(window) == 2` + +**Keyword Arguments** + + * `stride`: Should each be either single integer, or a tuple with `N` integers + * `pad`: Specifies the number of elements added to the borders of the data array. It can be + + * a single integer for equal padding all around, + * a tuple of `N` integers, to apply the same padding at begin/end of each spatial dimension, + * a tuple of `2*N` integers, for asymmetric padding, or + * the singleton `SamePad()`, to calculate padding such that `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial dimension. + +**Inputs** + + * `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + +**Returns** + + * Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where + +$$ + O_i = floor\left(\frac{I_i + pad[i] + pad[(i + N) \% length(pad)] - dilation[i] \times (k[i] - 1)}{stride[i]} + 1\right) +$$ + + * Empty `NamedTuple()` + +See also [`Conv`](layers#Lux.Conv), [`MaxPool`](layers#Lux.MaxPool), [`GlobalMeanPool`](layers#Lux.GlobalMeanPool), [`AdaptiveMeanPool`](layers#Lux.AdaptiveMeanPool) + + +source
+ +
+
+ + + +## Recurrent Layers + +
+# Lux.GRUCellType. + + + +```julia +GRUCell((in_dims, out_dims)::Pair{<:Int,<:Int}; use_bias=true, train_state::Bool=false, + init_weight::Tuple{Function,Function,Function}=(glorot_uniform, glorot_uniform, + glorot_uniform), + init_bias::Tuple{Function,Function,Function}=(zeros32, zeros32, zeros32), + init_state::Function=zeros32) +``` + +Gated Recurrent Unit (GRU) Cell + +$$ +\begin{align} + r &= \sigma(W_{ir} \times x + W_{hr} \times h_{prev} + b_{hr})\\ + z &= \sigma(W_{iz} \times x + W_{hz} \times h_{prev} + b_{hz})\\ + n &= \tanh(W_{in} \times x + b_{in} + r \cdot (W_{hn} \times h_{prev} + b_{hn}))\\ + h_{new} &= (1 - z) \cdot n + z \cdot h_{prev} +\end{align} +$$ + +**Arguments** + + * `in_dims`: Input Dimension + * `out_dims`: Output (Hidden State) Dimension + * `use_bias`: Set to false to deactivate bias + * `train_state`: Trainable initial hidden state can be activated by setting this to `true` + * `init_bias`: Initializer for bias. Must be a tuple containing 3 functions + * `init_weight`: Initializer for weight. Must be a tuple containing 3 functions + * `init_state`: Initializer for hidden state + +**Inputs** + + * Case 1a: Only a single input `x` of shape `(in_dims, batch_size)`, `train_state` is set to `false` - Creates a hidden state using `init_state` and proceeds to Case 2. + * Case 1b: Only a single input `x` of shape `(in_dims, batch_size)`, `train_state` is set to `true` - Repeats `hidden_state` from parameters to match the shape of `x` and proceeds to Case 2. + * Case 2: Tuple `(x, (h, ))` is provided, then the output and a tuple containing the updated hidden state is returned. + +**Returns** + + * Tuple containing + + * Output $h_{new}$ of shape `(out_dims, batch_size)` + * Tuple containing new hidden state $h_{new}$ + * Updated model state + +**Parameters** + + * `weight_i`: Concatenated Weights to map from input space $\\left\\\{ W_{ir}, W_{iz}, W_{in} \\right\\\}$. + * `weight_h`: Concatenated Weights to map from hidden space $\\left\\\{ W_{hr}, W_{hz}, W_{hn} \\right\\\}$. + * `bias_i`: Bias vector ($b_{in}$; not present if `use_bias=false`). + * `bias_h`: Concatenated Bias vector for the hidden space $\\left\\\{ b_{hr}, b_{hz}, b_{hn} \\right\\\}$ (not present if `use_bias=false`). + * `hidden_state`: Initial hidden state vector (not present if `train_state=false`) $\\left\\\{ b_{hr}, b_{hz}, b_{hn} \\right\\\}$. + +**States** + + * `rng`: Controls the randomness (if any) in the initial state generation + + +source
+ +
+
+
+# Lux.LSTMCellType. + + + +```julia +LSTMCell(in_dims => out_dims; use_bias::Bool=true, train_state::Bool=false, + train_memory::Bool=false, + init_weight=(glorot_uniform, glorot_uniform, glorot_uniform, glorot_uniform), + init_bias=(zeros32, zeros32, ones32, zeros32), init_state=zeros32, + init_memory=zeros32) +``` + +Long Short-Term (LSTM) Cell + +$$ +\begin{align} + i &= \sigma(W_{ii} \times x + W_{hi} \times h_{prev} + b_{i})\\ + f &= \sigma(W_{if} \times x + W_{hf} \times h_{prev} + b_{f})\\ + g &= tanh(W_{ig} \times x + W_{hg} \times h_{prev} + b_{g})\\ + o &= \sigma(W_{io} \times x + W_{ho} \times h_{prev} + b_{o})\\ + c_{new} &= f \cdot c_{prev} + i \cdot g\\ + h_{new} &= o \cdot tanh(c_{new}) +\end{align} +$$ + +**Arguments** + + * `in_dims`: Input Dimension + * `out_dims`: Output (Hidden State & Memory) Dimension + * `use_bias`: Set to false to deactivate bias + * `train_state`: Trainable initial hidden state can be activated by setting this to `true` + * `train_memory`: Trainable initial memory can be activated by setting this to `true` + * `init_bias`: Initializer for bias. Must be a tuple containing 4 functions + * `init_weight`: Initializer for weight. Must be a tuple containing 4 functions + * `init_state`: Initializer for hidden state + * `init_memory`: Initializer for memory + +**Inputs** + + * Case 1a: Only a single input `x` of shape `(in_dims, batch_size)`, `train_state` is set to `false`, `train_memory` is set to `false` - Creates a hidden state using `init_state`, hidden memory using `init_memory` and proceeds to Case 2. + * Case 1b: Only a single input `x` of shape `(in_dims, batch_size)`, `train_state` is set to `true`, `train_memory` is set to `false` - Repeats `hidden_state` vector from the parameters to match the shape of `x`, creates hidden memory using `init_memory` and proceeds to Case 2. + * Case 1c: Only a single input `x` of shape `(in_dims, batch_size)`, `train_state` is set to `false`, `train_memory` is set to `true` - Creates a hidden state using `init_state`, repeats the memory vector from parameters to match the shape of `x` and proceeds to Case 2. + * Case 1d: Only a single input `x` of shape `(in_dims, batch_size)`, `train_state` is set to `true`, `train_memory` is set to `true` - Repeats the hidden state and memory vectors from the parameters to match the shape of `x` and proceeds to Case 2. + * Case 2: Tuple `(x, (h, c))` is provided, then the output and a tuple containing the updated hidden state and memory is returned. + +**Returns** + + * Tuple Containing + + * Output $h_{new}$ of shape `(out_dims, batch_size)` + * Tuple containing new hidden state $h_{new}$ and new memory $c_{new}$ + * Updated model state + +**Parameters** + + * `weight_i`: Concatenated Weights to map from input space $\{ W_{ii}, W_{if}, W_{ig}, W_{io} \}$. + * `weight_h`: Concatenated Weights to map from hidden space $\{ W_{hi}, W_{hf}, W_{hg}, W_{ho} \}$ + * `bias`: Bias vector (not present if `use_bias=false`) + * `hidden_state`: Initial hidden state vector (not present if `train_state=false`) + * `memory`: Initial memory vector (not present if `train_memory=false`) + +**States** + + * `rng`: Controls the randomness (if any) in the initial state generation + + +source
+ +
+
+
+# Lux.RNNCellType. + + + +```julia +RNNCell(in_dims => out_dims, activation=tanh; bias::Bool=true, + train_state::Bool=false, init_bias=zeros32, init_weight=glorot_uniform, + init_state=ones32) +``` + +An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). + +$h_{new} = activation(weight_{ih} \times x + weight_{hh} \times h_{prev} + bias)$ + +**Arguments** + + * `in_dims`: Input Dimension + * `out_dims`: Output (Hidden State) Dimension + * `activation`: Activation function + * `bias`: Set to false to deactivate bias + * `train_state`: Trainable initial hidden state can be activated by setting this to `true` + * `init_bias`: Initializer for bias + * `init_weight`: Initializer for weight + * `init_state`: Initializer for hidden state + +**Inputs** + + * Case 1a: Only a single input `x` of shape `(in_dims, batch_size)`, `train_state` is set to `false` - Creates a hidden state using `init_state` and proceeds to Case 2. + * Case 1b: Only a single input `x` of shape `(in_dims, batch_size)`, `train_state` is set to `true` - Repeats `hidden_state` from parameters to match the shape of `x` and proceeds to Case 2. + * Case 2: Tuple `(x, (h, ))` is provided, then the output and a tuple containing the updated hidden state is returned. + +**Returns** + + * Tuple containing + + * Output $h_{new}$ of shape `(out_dims, batch_size)` + * Tuple containing new hidden state $h_{new}$ + * Updated model state + +**Parameters** + + * `weight_ih`: Maps the input to the hidden state. + * `weight_hh`: Maps the hidden state to the hidden state. + * `bias`: Bias vector (not present if `use_bias=false`) + * `hidden_state`: Initial hidden state vector (not present if `train_state=false`) + +**States** + + * `rng`: Controls the randomness (if any) in the initial state generation + + +source
+ +
+
+
+# Lux.RecurrenceType. + + + +```julia +Recurrence(cell; + ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex(), + return_sequence::Bool=false) +``` + +Wraps a recurrent cell (like [`RNNCell`](layers#Lux.RNNCell), [`LSTMCell`](layers#Lux.LSTMCell), [`GRUCell`](layers#Lux.GRUCell)) to automatically operate over a sequence of inputs. + +:::warning + +This is completely distinct from `Flux.Recur`. It doesn't make the `cell` stateful, rather allows operating on an entire sequence of inputs at once. See [`StatefulRecurrentCell`](layers#Lux.StatefulRecurrentCell) for functionality similar to `Flux.Recur`. + +::: + +**Arguments** + + * `cell`: A recurrent cell. See [`RNNCell`](layers#Lux.RNNCell), [`LSTMCell`](layers#Lux.LSTMCell), [`GRUCell`](layers#Lux.GRUCell), for how the inputs/outputs of a recurrent cell must be structured. + +**Keyword Arguments** + + * `return_sequence`: If `true` returns the entire sequence of outputs, else returns only the last output. Defaults to `false`. + * `ordering`: The ordering of the batch and time dimensions in the input. Defaults to `BatchLastIndex()`. Alternatively can be set to `TimeLastIndex()`. + +**Inputs** + + * If `x` is a + + * Tuple or Vector: Each element is fed to the `cell` sequentially. + * Array (except a Vector): It is spliced along the penultimate dimension and each slice is fed to the `cell` sequentially. + +**Returns** + + * Output of the `cell` for the entire sequence. + * Update state of the `cell`. + +**Parameters** + + * Same as `cell`. + +**States** + + * Same as `cell`. + + +source
+ +
+
+
+# Lux.StatefulRecurrentCellType. + + + +```julia +StatefulRecurrentCell(cell) +``` + +Wraps a recurrent cell (like [`RNNCell`](layers#Lux.RNNCell), [`LSTMCell`](layers#Lux.LSTMCell), [`GRUCell`](layers#Lux.GRUCell)) and makes it stateful. + +:::tip + +This is very similar to `Flux.Recur` + +::: + +To avoid undefined behavior, once the processing of a single sequence of data is complete, update the state with `Lux.update_state(st, :carry, nothing)`. + +**Arguments** + + * `cell`: A recurrent cell. See [`RNNCell`](layers#Lux.RNNCell), [`LSTMCell`](layers#Lux.LSTMCell), [`GRUCell`](layers#Lux.GRUCell), for how the inputs/outputs of a recurrent cell must be structured. + +**Inputs** + + * Input to the `cell`. + +**Returns** + + * Output of the `cell` for the entire sequence. + * Update state of the `cell` and updated `carry`. + +**Parameters** + + * Same as `cell`. + +**States** + + * NamedTuple containing: + + * `cell`: Same as `cell`. + * `carry`: The carry state of the `cell`. + + +source
+ +
+
+ + + +## Linear Layers + +
+# Lux.BilinearType. + + + +```julia +Bilinear((in1_dims, in2_dims) => out, activation=identity; init_weight=glorot_uniform, + init_bias=zeros32, use_bias::Bool=true, allow_fast_activation::Bool=true) +Bilinear(in12_dims => out, activation=identity; init_weight=glorot_uniform, + init_bias=zeros32, use_bias::Bool=true, allow_fast_activation::Bool=true) +``` + +Create a fully connected layer between two inputs and an output, and otherwise similar to [`Dense`](layers#Lux.Dense). Its output, given vectors `x` & `y`, is another vector `z` with, for all `i in 1:out`: + +`z[i] = activation(x' * W[i, :, :] * y + bias[i])` + +If `x` and `y` are matrices, then each column of the output `z = B(x, y)` is of this form, with `B` the Bilinear layer. + +**Arguments** + + * `in1_dims`: number of input dimensions of `x` + * `in2_dims`: number of input dimensions of `y` + * `in12_dims`: If specified, then `in1_dims = in2_dims = in12_dims` + * `out`: number of output dimensions + * `activation`: activation function + +**Keyword Arguments** + + * `init_weight`: initializer for the weight matrix (`weight = init_weight(rng, out_dims, in1_dims, in2_dims)`) + * `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) + * `use_bias`: Trainable bias can be disabled entirely by setting this to `false` + * `allow_fast_activation`: If `true`, then certain activations can be approximated with a faster version. The new activation function will be given by `NNlib.fast_act(activation)` + +**Input** + + * A 2-Tuple containing + + * `x` must be an AbstractArray with `size(x, 1) == in1_dims` + * `y` must be an AbstractArray with `size(y, 1) == in2_dims` + * If the input is an AbstractArray, then `x = y` + +**Returns** + + * AbstractArray with dimensions `(out_dims, size(x, 2))` + * Empty `NamedTuple()` + +**Parameters** + + * `weight`: Weight Matrix of size `(out_dims, in1_dims, in2_dims)` + * `bias`: Bias of size `(out_dims, 1)` (present if `use_bias=true`) + + +source
+ +
+
+
+# Lux.DenseType. + + + +```julia +Dense(in_dims => out_dims, activation=identity; init_weight=glorot_uniform, + init_bias=zeros32, bias::Bool=true) +``` + +Create a traditional fully connected layer, whose forward pass is given by: `y = activation.(weight * x .+ bias)` + +**Arguments** + + * `in_dims`: number of input dimensions + * `out_dims`: number of output dimensions + * `activation`: activation function + +**Keyword Arguments** + + * `init_weight`: initializer for the weight matrix (`weight = init_weight(rng, out_dims, in_dims)`) + * `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) + * `use_bias`: Trainable bias can be disabled entirely by setting this to `false` + * `allow_fast_activation`: If `true`, then certain activations can be approximated with a faster version. The new activation function will be given by `NNlib.fast_act(activation)` + +**Input** + + * `x` must be an AbstractArray with `size(x, 1) == in_dims` + +**Returns** + + * AbstractArray with dimensions `(out_dims, ...)` where `...` are the dimensions of `x` + * Empty `NamedTuple()` + +**Parameters** + + * `weight`: Weight Matrix of size `(out_dims, in_dims)` + * `bias`: Bias of size `(out_dims, 1)` (present if `use_bias=true`) + + +source
+ +
+
+
+# Lux.EmbeddingType. + + + +```julia +Embedding(in_dims => out_dims; init_weight=randn32) +``` + +A lookup table that stores embeddings of dimension `out_dims` for a vocabulary of size `in_dims`. + +This layer is often used to store word embeddings and retrieve them using indices. + +:::warning + +Unlike `Flux.Embedding`, this layer does not support using `OneHotArray` as an input. + +::: + +**Arguments** + + * `in_dims`: number of input dimensions + * `out_dims`: number of output dimensions + +**Keyword Arguments** + + * `init_weight`: initializer for the weight matrix (`weight = init_weight(rng, out_dims, in_dims)`) + +**Input** + + * Integer OR + * Abstract Vector of Integers OR + * Abstract Array of Integers + +**Returns** + + * Returns the embedding corresponding to each index in the input. For an N dimensional input, an N + 1 dimensional output is returned. + * Empty `NamedTuple()` + + +source
+ +
+
+
+# Lux.ScaleType. + + + +```julia +Scale(dims, activation=identity; init_weight=ones32, init_bias=zeros32, bias::Bool=true) +``` + +Create a Sparsely Connected Layer with a very specific structure (only Diagonal Elements are non-zero). The forward pass is given by: `y = activation.(weight .* x .+ bias)` + +**Arguments** + + * `dims`: size of the learnable scale and bias parameters. + * `activation`: activation function + +**Keyword Arguments** + + * `init_weight`: initializer for the weight matrix (`weight = init_weight(rng, out_dims, in_dims)`) + * `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) + * `use_bias`: Trainable bias can be disabled entirely by setting this to `false` + * `allow_fast_activation`: If `true`, then certain activations can be approximated with a faster version. The new activation function will be given by `NNlib.fast_act(activation)` + +**Input** + + * `x` must be an Array of size `(dims..., B)` or `(dims...[0], ..., dims[k])` for `k ≤ size(dims)` + +**Returns** + + * Array of size `(dims..., B)` or `(dims...[0], ..., dims[k])` for `k ≤ size(dims)` + * Empty `NamedTuple()` + +**Parameters** + + * `weight`: Weight Array of size `(dims...)` + * `bias`: Bias of size `(dims...)` + + +source
+ +
+
+ + + +## Misc. Helper Layers + +
+# Lux.FlattenLayerType. + + + +```julia +FlattenLayer() +``` + +Flattens the passed array into a matrix. + +**Inputs** + + * `x`: AbstractArray + +**Returns** + + * AbstractMatrix of size `(:, size(x, ndims(x)))` + * Empty `NamedTuple()` + + +source
+ +
+
+
+# Lux.MaxoutType. + + + +```julia +Maxout(layers...) +Maxout(; layers...) +Maxout(f::Function, n_alts::Int) +``` + +This contains a number of internal layers, each of which receives the same input. Its output is the elementwise maximum of the the internal layers' outputs. + +Maxout over linear dense layers satisfies the univeral approximation theorem. See [1]. + +See also [`Parallel`](layers#Lux.Parallel) to reduce with other operators. + +**Arguments** + + * Layers can be specified in three formats: + + * A list of `N` Lux layers + * Specified as `N` keyword arguments. + * A no argument function `f` and an integer `n_alts` which specifies the number of layers. + +**Inputs** + + * `x`: Input that is passed to each of the layers + +**Returns** + + * Output is computed by taking elementwise `max` of the outputs of the individual layers. + * Updated state of the `layers` + +**Parameters** + + * Parameters of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +**States** + + * States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) + +**References** + +[1] Goodfellow, Warde-Farley, Mirza, Courville & Bengio "Maxout Networks" [https://arxiv.org/abs/1302.4389](https://arxiv.org/abs/1302.4389) + + +source
+ +
+
+
+# Lux.NoOpLayerType. + + + +```julia +NoOpLayer() +``` + +As the name suggests does nothing but allows pretty printing of layers. Whatever input is passed is returned. + + +source
+ +
+
+
+# Lux.ReshapeLayerType. + + + +```julia +ReshapeLayer(dims) +``` + +Reshapes the passed array to have a size of `(dims..., :)` + +**Arguments** + + * `dims`: The new dimensions of the array (excluding the last dimension). + +**Inputs** + + * `x`: AbstractArray of any shape which can be reshaped in `(dims..., size(x, ndims(x)))` + +**Returns** + + * AbstractArray of size `(dims..., size(x, ndims(x)))` + * Empty `NamedTuple()` + + +source
+ +
+
+
+# Lux.SelectDimType. + + + +```julia +SelectDim(dim, i) +``` + +Return a view of all the data of the input `x` where the index for dimension `dim` equals `i`. Equivalent to `view(x,:,:,...,i,:,:,...)` where `i` is in position `d`. + +**Arguments** + + * `dim`: Dimension for indexing + * `i`: Index for dimension `dim` + +**Inputs** + + * `x`: AbstractArray that can be indexed with `view(x,:,:,...,i,:,:,...)` + +**Returns** + + * `view(x,:,:,...,i,:,:,...)` where `i` is in position `d` + * Empty `NamedTuple()` + + +source
+ +
+
+
+# Lux.WrappedFunctionType. + + + +```julia +WrappedFunction(f) +``` + +Wraps a stateless and parameter less function. Might be used when a function is added to `Chain`. For example, `Chain(x -> relu.(x))` would not work and the right thing to do would be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be `Chain(WrappedFunction(Base.Fix1(broadcast, relu)))` + +**Arguments** + + * `f::Function`: A stateless and parameterless function + +**Inputs** + + * `x`: s.t `hasmethod(f, (typeof(x),))` is `true` + +**Returns** + + * Output of `f(x)` + * Empty `NamedTuple()` + + +source
+ +
+
+ + + +## Normalization Layers + +
+# Lux.BatchNormType. + + + +```julia +BatchNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, + affine=true, track_stats=true, epsilon=1f-5, momentum=0.1f0, + allow_fast_activation::Bool=true) +``` + +[Batch Normalization](https://arxiv.org/abs/1502.03167) layer. + +`BatchNorm` computes the mean and variance for each $D_1 × ... × D_{N-2} × 1 × D_N$ input slice and normalises the input accordingly. + +**Arguments** + + * `chs`: Size of the channel dimension in your data. Given an array with `N` dimensions, call the `N-1`th the channel dimension. For a batch of feature vectors this is just the data dimension, for `WHCN` images it's the usual channel dimension. + * `activation`: After normalization, elementwise activation `activation` is applied. + +**Keyword Arguments** + + * If `track_stats=true`, accumulates mean and variance statistics in training phase that will be used to renormalize the input in test phase. + * `epsilon`: a value added to the denominator for numerical stability + * `momentum`: the value used for the `running_mean` and `running_var` computation + * `allow_fast_activation`: If `true`, then certain activations can be approximated with a faster version. The new activation function will be given by `NNlib.fast_act(activation)` + * If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. + + * `init_bias`: Controls how the `bias` is initiliazed + * `init_scale`: Controls how the `scale` is initiliazed + +**Inputs** + + * `x`: Array where `size(x, N - 1) = chs` and `ndims(x) > 2` + +**Returns** + + * `y`: Normalized Array + * Update model state + +**Parameters** + + * `affine=true` + + * `bias`: Bias of shape `(chs,)` + * `scale`: Scale of shape `(chs,)` + * `affine=false` - Empty `NamedTuple()` + +**States** + + * Statistics if `track_stats=true` + + * `running_mean`: Running mean of shape `(chs,)` + * `running_var`: Running variance of shape `(chs,)` + * Statistics if `track_stats=false` + + * `running_mean`: nothing + * `running_var`: nothing + * `training`: Used to check if training/inference mode + +Use `Lux.testmode` during inference. + +**Example** + +```julia +m = Chain(Dense(784 => 64), BatchNorm(64, relu), Dense(64 => 10), BatchNorm(10)) +``` + +:::warning + +Passing a batch size of 1, during training will result in NaNs. + +::: + +See also [`BatchNorm`](layers#Lux.BatchNorm), [`InstanceNorm`](layers#Lux.InstanceNorm), [`LayerNorm`](layers#Lux.LayerNorm), [`WeightNorm`](layers#Lux.WeightNorm) + + +source
+ +
+
+
+# Lux.GroupNormType. + + + +```julia +GroupNorm(chs::Integer, groups::Integer, activation=identity; init_bias=zeros32, + init_scale=ones32, affine=true, epsilon=1f-5, + allow_fast_activation::Bool=true) +``` + +[Group Normalization](https://arxiv.org/abs/1803.08494) layer. + +**Arguments** + + * `chs`: Size of the channel dimension in your data. Given an array with `N` dimensions, call the `N-1`th the channel dimension. For a batch of feature vectors this is just the data dimension, for `WHCN` images it's the usual channel dimension. + * `groups` is the number of groups along which the statistics are computed. The number of channels must be an integer multiple of the number of groups. + * `activation`: After normalization, elementwise activation `activation` is applied. + +**Keyword Arguments** + + * `epsilon`: a value added to the denominator for numerical stability + * `allow_fast_activation`: If `true`, then certain activations can be approximated with a faster version. The new activation function will be given by `NNlib.fast_act(activation)` + * If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. + + * `init_bias`: Controls how the `bias` is initiliazed + * `init_scale`: Controls how the `scale` is initiliazed + +**Inputs** + + * `x`: Array where `size(x, N - 1) = chs` and `ndims(x) > 2` + +**Returns** + + * `y`: Normalized Array + * Update model state + +**Parameters** + + * `affine=true` + + * `bias`: Bias of shape `(chs,)` + * `scale`: Scale of shape `(chs,)` + * `affine=false` - Empty `NamedTuple()` + +**States** + + * `training`: Used to check if training/inference mode + +Use `Lux.testmode` during inference. + +**Example** + +```julia +m = Chain(Dense(784 => 64), GroupNorm(64, 4, relu), Dense(64 => 10), GroupNorm(10, 5)) +``` + +:::warning + +GroupNorm doesn't have CUDNN support. The GPU fallback is not very efficient. + +::: + +See also [`GroupNorm`](layers#Lux.GroupNorm), [`InstanceNorm`](layers#Lux.InstanceNorm), [`LayerNorm`](layers#Lux.LayerNorm), [`WeightNorm`](layers#Lux.WeightNorm) + + +source
+ +
+
+
+# Lux.InstanceNormType. + + + +```julia +InstanceNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, + affine=true, epsilon=1f-5, allow_fast_activation::Bool=true) +``` + +Instance Normalization. For details see [1]. + +Instance Normalization computes the mean and variance for each $D_1 \times ... \times D_{N - 2} \times 1 \times 1$` input slice and normalises the input accordingly. + +**Arguments** + + * `chs`: Size of the channel dimension in your data. Given an array with `N` dimensions, call the `N-1`th the channel dimension. For a batch of feature vectors this is just the data dimension, for `WHCN` images it's the usual channel dimension. + * `activation`: After normalization, elementwise activation `activation` is applied. + +**Keyword Arguments** + + * `epsilon`: a value added to the denominator for numerical stability + * `allow_fast_activation`: If `true`, then certain activations can be approximated with a faster version. The new activation function will be given by `NNlib.fast_act(activation)` + * If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. + + * `init_bias`: Controls how the `bias` is initiliazed + * `init_scale`: Controls how the `scale` is initiliazed + +**Inputs** + + * `x`: Array where `size(x, N - 1) = chs` and `ndims(x) > 2` + +**Returns** + + * `y`: Normalized Array + * Update model state + +**Parameters** + + * `affine=true` + + * `bias`: Bias of shape `(chs,)` + * `scale`: Scale of shape `(chs,)` + * `affine=false` - Empty `NamedTuple()` + +**States** + + * `training`: Used to check if training/inference mode + +Use `Lux.testmode` during inference. + +**Example** + +```julia +m = Chain(Dense(784 => 64), InstanceNorm(64, relu), Dense(64 => 10), InstanceNorm(10, 5)) +``` + +**References** + +[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016). + +:::warning + +InstanceNorm doesn't have CUDNN support. The GPU fallback is not very efficient. + +::: + +See also [`BatchNorm`](layers#Lux.BatchNorm), [`GroupNorm`](layers#Lux.GroupNorm), [`LayerNorm`](layers#Lux.LayerNorm), [`WeightNorm`](layers#Lux.WeightNorm) + + +source
+ +
+
+
+# Lux.LayerNormType. + + + +```julia +LayerNorm(shape::NTuple{N, Int}, activation=identity; epsilon=1f-5, dims=Colon(), + affine::Bool=true, init_bias=zeros32, init_scale=ones32,) +``` + +Computes mean and standard deviation over the whole input array, and uses these to normalize the whole array. Optionally applies an elementwise affine transformation afterwards. + +Given an input array $x$, this layer computes + +$$ +y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta +$$ + +where $\gamma$ & $\beta$ are trainable parameters if `affine=true`. + +:::warning + +As of v0.5.0, the doc used to say `affine::Bool=false`, but the code actually had `affine::Bool=true` as the default. Now the doc reflects the code, so please check whether your assumptions about the default (if made) were invalid. + +::: + +**Arguments** + + * `shape`: Broadcastable shape of input array excluding the batch dimension. + * `activation`: After normalization, elementwise activation `activation` is applied. + +**Keyword Arguments** + + * `allow_fast_activation`: If `true`, then certain activations can be approximated with a faster version. The new activation function will be given by `NNlib.fast_act(activation)` + * `epsilon`: a value added to the denominator for numerical stability. + * `dims`: Dimensions to normalize the array over. + * If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. + + * `init_bias`: Controls how the `bias` is initiliazed + * `init_scale`: Controls how the `scale` is initiliazed + +**Inputs** + + * `x`: AbstractArray + +**Returns** + + * `y`: Normalized Array + * Empty NamedTuple() + +**Parameters** + + * `affine=false`: Empty `NamedTuple()` + * `affine=true` + + * `bias`: Bias of shape `(shape..., 1)` + * `scale`: Scale of shape `(shape..., 1)` + + +source
+ +
+
+
+# Lux.WeightNormType. + + + +```julia +WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N,Symbol}, + dims::Union{Tuple,Nothing}=nothing) +``` + +Applies [weight normalization](https://arxiv.org/abs/1602.07868) to a parameter in the given layer. + +$w = g\frac{v}{\|v\|}$ + +Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This updates the parameters in `which_params` (e.g. `weight`) using two parameters: one specifying the magnitude (e.g. `weight_g`) and one specifying the direction (e.g. `weight_v`). + +**Arguments** + + * `layer` whose parameters are being reparameterized + * `which_params`: parameter names for the parameters being reparameterized + * By default, a norm over the entire array is computed. Pass `dims` to modify the dimension. + +**Inputs** + + * `x`: Should be of valid type for input to `layer` + +**Returns** + + * Output from `layer` + * Updated model state of `layer` + +**Parameters** + + * `normalized`: Parameters of `layer` that are being normalized + * `unnormalized`: Parameters of `layer` that are not being normalized + +**States** + + * Same as that of `layer` + + +source
+ +
+
+ + + +## Upsampling + +
+# Lux.PixelShuffleFunction. + + + +```julia +PixelShuffle(r::Int) +``` + +Pixel shuffling layer with upscale factor `r`. Usually used for generating higher resolution images while upscaling them. + +See `NNlib.pixel_shuffle` for more details. + +PixelShuffle is not a Layer, rather it returns a [`WrappedFunction`](layers#Lux.WrappedFunction) with the function set to `Base.Fix2(pixel_shuffle, r)` + +**Arguments** + + * `r`: Upscale factor + +**Inputs** + + * `x`: For 4D-arrays representing N images, the operation converts input size(x) == (W, H, r^2 x C, N) to output of size (r x W, r x H, C, N). For D-dimensional data, it expects ndims(x) == D+2 with channel and batch dimensions, and divides the number of channels by r^D. + +**Returns** + + * Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)` for D-dimensional data, where `D = ndims(x) - 2` + + +source
+ +
+
+
+# Lux.UpsampleType. + + + +```julia +Upsample(mode = :nearest; [scale, size]) +Upsample(scale, mode = :nearest) +``` + +Upsampling Layer. + +**Layer Construction** + +**Option 1** + + * `mode`: Set to `:nearest`, `:linear`, `:bilinear` or `:trilinear` + +Exactly one of two keywords must be specified: + + * If `scale` is a number, this applies to all but the last two dimensions (channel and batch) of the input. It may also be a tuple, to control dimensions individually. + * Alternatively, keyword `size` accepts a tuple, to directly specify the leading dimensions of the output. + +**Option 2** + + * If `scale` is a number, this applies to all but the last two dimensions (channel and batch) of the input. It may also be a tuple, to control dimensions individually. + * `mode`: Set to `:nearest`, `:bilinear` or `:trilinear` + +Currently supported upsampling `mode`s and corresponding NNlib's methods are: + + * `:nearest` -> `NNlib.upsample_nearest` + * `:bilinear` -> `NNlib.upsample_bilinear` + * `:trilinear` -> `NNlib.upsample_trilinear` + +**Inputs** + + * `x`: For the input dimensions look into the documentation for the corresponding `NNlib` function + + * As a rule of thumb, `:nearest` should work with arrays of arbitrary dimensions + * `:bilinear` works with 4D Arrays + * `:trilinear` works with 5D Arrays + +**Returns** + + * Upsampled Input of size `size` or of size `(I_1 x scale[1], ..., I_N x scale[N], C, N)` + * Empty `NamedTuple()` + + +source
+ +
+
diff --git a/previews/PR474/api/Lux/utilities.md b/previews/PR474/api/Lux/utilities.md new file mode 100644 index 000000000..3e593d70d --- /dev/null +++ b/previews/PR474/api/Lux/utilities.md @@ -0,0 +1,248 @@ + + + +# Utilities + + + + + + +## Index + +- [`Lux.cpu`](#Lux.cpu) +- [`Lux.disable_stacktrace_truncation!`](#Lux.disable_stacktrace_truncation!) +- [`Lux.f16`](#Lux.f16) +- [`Lux.f32`](#Lux.f32) +- [`Lux.f64`](#Lux.f64) +- [`Lux.foldl_init`](#Lux.foldl_init) +- [`Lux.gpu`](#Lux.gpu) +- [`Lux.istraining`](#Lux.istraining) +- [`Lux.multigate`](#Lux.multigate) +- [`Lux.replicate`](#Lux.replicate) + + + + +## Device Management / Data Transfer + +
+# Lux.cpuFunction. + + + +```julia +cpu(x) +``` + +Transfer `x` to CPU. + +::: warning + +This function has been deprecated. Use [`cpu_device`](../Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.cpu_device) instead. + +::: + + +source
+ +
+
+
+# Lux.gpuFunction. + + + +```julia +gpu(x) +``` + +Transfer `x` to GPU determined by the backend set using [`Lux.gpu_backend!`](../Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.gpu_backend!). + +:::warning + +This function has been deprecated. Use [`gpu_device`](../Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.gpu_device) instead. Using this function inside performance critical code will cause massive slowdowns due to type inference failure. + +::: + + +source
+ +
+
+ +:::warning + + +For detailed API documentation on Data Transfer check out the [LuxDeviceUtils.jl](../LuxDeviceUtils/) + + +::: + + + + +## Weight Initialization + + +:::warning + + +For API documentation on Initialization check out the [WeightInitializers.jl](../WeightInitializers/) + + +::: + + + + +## Miscellaneous Utilities + +
+# Lux.foldl_initFunction. + + + +```julia +foldl_init(op, x) +foldl_init(op, x, init) +``` + +Exactly same as `foldl(op, x; init)` in the forward pass. But, gives gradients wrt `init` in the backward pass. + + +source
+ +
+
+
+# Lux.istrainingFunction. + + + +```julia +istraining(::Val{training}) +istraining(st::NamedTuple) +``` + +Returns `true` if `training` is `true` or if `st` contains a `training` field with value `true`. Else returns `false`. + +Method undefined if `st.training` is not of type `Val`. + + +source
+ +
+
+
+# Lux.multigateFunction. + + + +```julia +multigate(x::AbstractArray, ::Val{N}) +``` + +Split up `x` into `N` equally sized chunks (along dimension `1`). + + +source
+ +
+
+
+# Lux.replicateFunction. + + + +```julia +replicate(rng::AbstractRNG) +replicate(rng::CUDA.RNG) +``` + +Creates a copy of the `rng` state depending on its type. + + +source
+ +
+
+ + + +## Updating Floating Point Precision + + +By default, Lux uses Float32 for all parameters and states. To update the precision simply pass the parameters / states / arrays into one of the following functions. + +
+# Lux.f16Function. + + + +```julia +f16(m) +``` + +Converts the `eltype` of `m` *floating point* values to `Float16`. Recurses into structs marked with `Functors.@functor`. + + +source
+ +
+
+
+# Lux.f32Function. + + + +```julia +f32(m) +``` + +Converts the `eltype` of `m` *floating point* values to `Float32`. Recurses into structs marked with `Functors.@functor`. + + +source
+ +
+
+
+# Lux.f64Function. + + + +```julia +f64(m) +``` + +Converts the `eltype` of `m` *floating point* values to `Float64`. Recurses into structs marked with `Functors.@functor`. + + +source
+ +
+
+ + + +## Truncated Stacktraces + +
+# Lux.disable_stacktrace_truncation!Function. + + + +```julia +disable_stacktrace_truncation!(; disable::Bool=true) +``` + +An easy way to update `TruncatedStacktraces.VERBOSE` without having to load it manually. + +Effectively does `TruncatedStacktraces.VERBOSE[] = disable` + + +source
+ +
+
diff --git a/previews/PR474/api/Testing_Functionality/LuxTestUtils.md b/previews/PR474/api/Testing_Functionality/LuxTestUtils.md new file mode 100644 index 000000000..e978fcdd4 --- /dev/null +++ b/previews/PR474/api/Testing_Functionality/LuxTestUtils.md @@ -0,0 +1,141 @@ + + + + + +# LuxTestUtils + + +:::warning + + +This is a testing package. Hence, we don't use features like weak dependencies to reduce load times. It is recommended that you exclusively use this package for testing and not add a dependency to it in your main package Project.toml. + + +::: + + +Implements utilities for testing **gradient correctness** and **dynamic dispatch** of Lux.jl models. + + + + +## Index + +- [`LuxTestUtils.@jet`](#LuxTestUtils.@jet) +- [`LuxTestUtils.@test_gradients`](#LuxTestUtils.@test_gradients) + + + + +## Testing using JET.jl + +
+# LuxTestUtils.@jetMacro. + + + +```julia +@jet f(args...) call_broken=false opt_broken=false +``` + +Run JET tests on the function `f` with the arguments `args...`. If `JET` fails to compile or julia version is < 1.7, then the macro will be a no-op. + +**Keyword Arguments** + + * `call_broken`: Marks the test_call as broken. + * `opt_broken`: Marks the test_opt as broken. + +All additional arguments will be forwarded to `@JET.test_call` and `@JET.test_opt`. + +:::tip + +Instead of specifying `target_modules` with every call, you can set preferences for `target_modules` using `Preferences.jl`. For example, to set `target_modules` to `(Lux, LuxLib)` we can run: + +```julia +using Preferences + +set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"), + "target_modules" => ["Lux", "LuxLib"]) +``` + +::: + +**Example** + +```julia +using LuxTestUtils + +@testset "Showcase JET Testing" begin + @jet sum([1, 2, 3]) target_modules=(Base, Core) + + @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true +end +``` + +
+
+ + + +## Gradient Correctness + +
+# LuxTestUtils.@test_gradientsMacro. + + + +```julia +@test_gradients f args... [kwargs...] +``` + +Compare the gradients computed by Zygote.jl (Reverse Mode AD) against: + + * Tracker.jl (Reverse Mode AD) + * ReverseDiff.jl (Reverse Mode AD) + * ForwardDiff.jl (Forward Mode AD) + * FiniteDifferences.jl (Finite Differences) + +:::tip + +This function is completely compatible with Test.jl + +::: + +**Arguments** + + * `f`: The function to test. + * `args...`: Inputs to `f` wrt which the gradients are computed. + +**Keyword Arguments** + + * `gpu_testing`: Disables ForwardDiff, ReverseDiff and FiniteDifferences tests. (Default: `false`) + * `soft_fail`: If `true`, the test will not fail if any of the gradients are incorrect, instead it will show up as broken. (Default: `false`) + * `skip_(tracker|reverse_diff|forward_diff|finite_differences)`: Skip the corresponding gradient computation and check. (Default: `false`) + * `large_arrays_skip_(forward_diff|finite_differences)`: Skip the corresponding gradient computation and check for large arrays. (Forward Mode and Finite Differences are not efficient for large arrays.) (Default: `true`) + * `large_array_length`: The length of the array above which the gradient computation is considered large. (Default: 25) + * `max_total_array_size`: Treat as large array if the total size of all arrays is greater than this value. (Default: 100) + * `(tracker|reverse_diff|forward_diff|finite_differences)_broken`: Mark the corresponding gradient test as broken. (Default: `false`) + +**Keyword Arguments for `check_approx`** + + * `atol`: Absolute tolerance for gradient comparisons. (Default: `0.0`) + * `rtol`: Relative tolerance for gradient comparisons. (Default: `atol > 0 ? 0.0 : √eps(typeof(atol))`) + * `nans`: Whether or not NaNs are considered equal. (Default: `false`) + +**Example** + +```julia +using LuxTestUtils + +x = randn(10) + +@testset "Showcase Gradient Testing" begin + @test_gradients sum abs2 x + + @test_gradients prod x +end +``` + +
+
diff --git a/previews/PR474/api/index.md b/previews/PR474/api/index.md new file mode 100644 index 000000000..49d33d429 --- /dev/null +++ b/previews/PR474/api/index.md @@ -0,0 +1,76 @@ + + + +# API Reference + +- [LuxAMDGPU](Accelerator_Support/LuxAMDGPU#LuxAMDGPU) + - [Index](Accelerator_Support/LuxAMDGPU#Index) + - [API](Accelerator_Support/LuxAMDGPU#API) +- [LuxCUDA](Accelerator_Support/LuxCUDA#LuxCUDA) + - [Index](Accelerator_Support/LuxCUDA#Index) + - [API Reference](Accelerator_Support/LuxCUDA#API-Reference) +- [LuxDeviceUtils](Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils) + - [Index](Accelerator_Support/LuxDeviceUtils#Index) + - [Preferences](Accelerator_Support/LuxDeviceUtils#Preferences) + - [Data Transfer](Accelerator_Support/LuxDeviceUtils#Data-Transfer) + - [Miscellaneous](Accelerator_Support/LuxDeviceUtils#Miscellaneous) +- [LuxCore](Building_Blocks/LuxCore#LuxCore) + - [Index](Building_Blocks/LuxCore#Index) + - [Abstract Types](Building_Blocks/LuxCore#Abstract-Types) + - [General](Building_Blocks/LuxCore#General) + - [Parameters](Building_Blocks/LuxCore#Parameters) + - [States](Building_Blocks/LuxCore#States) +- [LuxLib](Building_Blocks/LuxLib#LuxLib) + - [Index](Building_Blocks/LuxLib#Index) + - [Dropout](Building_Blocks/LuxLib#Dropout) + - [Normalization](Building_Blocks/LuxLib#Normalization) +- [WeightInitializers](Building_Blocks/WeightInitializers#WeightInitializers) + - [Index](Building_Blocks/WeightInitializers#Index) + - [API Reference](Building_Blocks/WeightInitializers#API-Reference) + - [Main Functions](Building_Blocks/WeightInitializers#Main-Functions) + - [Commonly Used Wrappers](Building_Blocks/WeightInitializers#Commonly-Used-Wrappers) +- [Boltz](Domain_Specific_Modeling/Boltz#Boltz) + - [Index](Domain_Specific_Modeling/Boltz#Index) +- [Computer Vision Models](Domain_Specific_Modeling/Boltz#Computer-Vision-Models) + - [Classification Models: Native Lux Models](Domain_Specific_Modeling/Boltz#Classification-Models:-Native-Lux-Models) + - [Building Blocks](Domain_Specific_Modeling/Boltz#Building-Blocks) + - [Non-Public API](Domain_Specific_Modeling/Boltz#Non-Public-API) + - [Classification Models: Imported from Metalhead.jl](Domain_Specific_Modeling/Boltz#Classification-Models:-Imported-from-Metalhead.jl) + - [Preprocessing](Domain_Specific_Modeling/Boltz#Preprocessing) +- [Experimental Features](Lux/contrib#Experimental-Features) + - [Index](Lux/contrib#Index) + - [Training](Lux/contrib#Training) + - [Parameter Freezing](Lux/contrib#Parameter-Freezing) + - [Map over Layer](Lux/contrib#Map-over-Layer) + - [Debugging Functionality](Lux/contrib#Debugging-Functionality) + - [Tied Parameters](Lux/contrib#Tied-Parameters) + - [Stateful Layer](Lux/contrib#Stateful-Layer) + - [Compact Layer API](Lux/contrib#Compact-Layer-API) +- [Flux Models to Lux Models](Lux/flux_to_lux#Flux-Models-to-Lux-Models) + - [Index](Lux/flux_to_lux#Index) + - [Functions](Lux/flux_to_lux#Functions) + - [Layers](Lux/flux_to_lux#Layers) +- [Built-In Layers](Lux/layers#Built-In-Layers) + - [Index](Lux/layers#Index) + - [Containers](Lux/layers#Containers) + - [Convolutional Layers](Lux/layers#Convolutional-Layers) + - [Dropout Layers](Lux/layers#Dropout-Layers) + - [Pooling Layers](Lux/layers#Pooling-Layers) + - [Recurrent Layers](Lux/layers#Recurrent-Layers) + - [Linear Layers](Lux/layers#Linear-Layers) + - [Misc. Helper Layers](Lux/layers#Misc.-Helper-Layers) + - [Normalization Layers](Lux/layers#Normalization-Layers) + - [Upsampling](Lux/layers#Upsampling) +- [Utilities](Lux/utilities#Utilities) + - [Index](Lux/utilities#Index) + - [Device Management / Data Transfer](Lux/utilities#Device-Management-/-Data-Transfer) + - [Weight Initialization](Lux/utilities#Weight-Initialization) + - [Miscellaneous Utilities](Lux/utilities#Miscellaneous-Utilities) + - [Updating Floating Point Precision](Lux/utilities#Updating-Floating-Point-Precision) + - [Truncated Stacktraces](Lux/utilities#Truncated-Stacktraces) +- [LuxTestUtils](Testing_Functionality/LuxTestUtils#LuxTestUtils) + - [Index](Testing_Functionality/LuxTestUtils#Index) + - [Testing using JET.jl](Testing_Functionality/LuxTestUtils#Testing-using-JET.jl) + - [Gradient Correctness](Testing_Functionality/LuxTestUtils#Gradient-Correctness) +- [API Reference](index#API-Reference) + diff --git a/previews/PR474/ecosystem.md b/previews/PR474/ecosystem.md new file mode 100644 index 000000000..18a33d5cf --- /dev/null +++ b/previews/PR474/ecosystem.md @@ -0,0 +1,422 @@ +--- +layout: page +--- + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/previews/PR474/introduction/citation.md b/previews/PR474/introduction/citation.md new file mode 100644 index 000000000..2a8af0df9 --- /dev/null +++ b/previews/PR474/introduction/citation.md @@ -0,0 +1,33 @@ + + + +# Citation + + +If you found this library to be useful in academic work, then please cite: + + +```bibtex +@software{pal2023lux, + author = {Pal, Avik}, + title = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}}, + month = {April}, + year = 2023, + note = {If you use this software, please cite it as below.}, + publisher = {Zenodo}, + version = {v0.5.0}, + doi = {10.5281/zenodo.7808904}, + url = {https://doi.org/10.5281/zenodo.7808904} +} +``` + + +```bibtex +@thesis{pal2023efficient, + title = {{On Efficient Training \& Inference of Neural Differential Equations}}, + author = {Pal, Avik}, + year = {2023}, + school = {Massachusetts Institute of Technology} +} +``` + diff --git a/previews/PR474/introduction/index.md b/previews/PR474/introduction/index.md new file mode 100644 index 000000000..435be48d7 --- /dev/null +++ b/previews/PR474/introduction/index.md @@ -0,0 +1,240 @@ + + + +# Getting Started + + + + +## Installation + + +Install [Julia v1.6 or above](https://julialang.org/downloads/). Lux.jl is available through the Julia package manager. You can enter it by pressing `]` in the REPL and then typing + + +```julia +pkg> add Lux +``` + + +Alternatively, you can also do + + +```julia +import Pkg; Pkg.add("Lux") +``` + + +:::tip + + +The Julia Compiler is always improving. As such, we recommend using the latest stable version of Julia instead of the LTS. + + +::: + + + + +## Quickstart + + +:::tip PRE-REQUISITES + + +You need to install `Optimisers` and `Zygote` if not done already. `Pkg.add(["Optimisers", "Zygote"])` + + +::: + + +```julia +using Lux, Random, Optimisers, Zygote +# using LuxCUDA, LuxAMDGPU, Metal # Optional packages for GPU support +``` + + +We take randomness very seriously + + +```julia +# Seeding +rng = Random.default_rng() +Random.seed!(rng, 0) +``` + + +``` +Random.TaskLocalRNG() +``` + + +Build the model + + +```julia +# Construct the layer +model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10))) +``` + + +``` +Chain( + layer_1 = Dense(128 => 256, tanh_fast), # 33_024 parameters + layer_2 = Dense(256 => 1, tanh_fast), # 257 parameters + layer_3 = Dense(1 => 10), # 20 parameters +) # Total: 33_301 parameters, + # plus 0 states. +``` + + +Models don't hold parameters and states so initialize them. From there on, we just use our standard AD and Optimisers API. + + +```julia +# Get the device determined by Lux +device = gpu_device() + +# Parameter and State Variables +ps, st = Lux.setup(rng, model) .|> device + +# Dummy Input +x = rand(rng, Float32, 128, 2) |> device + +# Run the model +y, st = Lux.apply(model, x, ps, st) + +# Gradients +## Pullback API to capture change in state +(l, st_), pb = pullback(p -> Lux.apply(model, x, p, st), ps) +gs = pb((one.(l), nothing))[1] + +# Optimization +st_opt = Optimisers.setup(Adam(0.0001f0), ps) +st_opt, ps = Optimisers.update(st_opt, ps, gs) +``` + + +``` +((layer_1 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00313608 0.00806096 … 0.00476192 0.00732118; -0.00447309 -0.0119719 … -0.00822211 -0.0110335; … ; -0.00294453 -0.00749935 … -0.00426221 -0.00678769; 0.000750543 0.00195163 … 0.00120731 0.00178011], Float32[9.83485f-7 6.49782f-6 … 2.26756f-6 5.3599f-6; 2.00083f-6 1.43324f-5 … 6.76022f-6 1.21738f-5; … ; 8.67016f-7 5.62395f-6 … 1.81662f-6 4.60721f-6; 5.63307f-8 3.80882f-7 … 1.45758f-7 3.16876f-7], (0.81, 0.998001))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.00954525; -0.0146331; … ; -0.00881351; 0.00233261;;], Float32[9.11106f-6; 2.14125f-5; … ; 7.76769f-6; 5.44098f-7;;], (0.81, 0.998001)))), layer_2 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[-0.0104967 0.0714637 … -0.0224641 0.108277], Float32[1.10179f-5 0.000510699 … 5.04627f-5 0.00117238], (0.81, 0.998001))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.178909;;], Float32[0.0032008;;], (0.81, 0.998001)))), layer_3 = (weight = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[-0.105128; -0.105128; … ; -0.105128; -0.105128;;], Float32[0.00110518; 0.00110518; … ; 0.00110518; 0.00110518;;], (0.81, 0.998001))), bias = Leaf(Adam(0.0001, (0.9, 0.999), 1.0e-8), (Float32[0.2; 0.2; … ; 0.2; 0.2;;], Float32[0.00399995; 0.00399995; … ; 0.00399995; 0.00399995;;], (0.81, 0.998001))))), (layer_1 = (weight = Float32[-0.11044693 0.10963185 … 0.097855344 -0.009167461; -0.0110904 0.07588978 … -0.03180492 0.088967875; … ; 0.01864451 -0.034903362 … -0.016194405 0.019176451; -0.09216565 -0.047490627 … -0.08869007 0.009417342], bias = Float32[-9.999999f-5; 9.999998f-5; … ; 9.999999f-5; -9.9999954f-5;;]), layer_2 = (weight = Float32[0.05391791 -0.103956826 … -0.050862882 0.020512676], bias = Float32[-0.0001;;]), layer_3 = (weight = Float32[-0.6546853; 0.6101978; … ; 0.41120994; 0.5494141;;], bias = Float32[-0.0001; -0.0001; … ; -0.0001; -0.0001;;]))) +``` + + + + +## Defining Custom Layers + + +```julia +using Lux, Random, Optimisers, Zygote +# using LuxCUDA, LuxAMDGPU, Metal # Optional packages for GPU support +import Lux.Experimental: @compact +``` + + +We will define a custom MLP using the `@compact` macro. The macro takes in a list of parameters, layers and states, and a function defining the forward pass of the neural network. + + +```julia +n_in = 1 +n_out = 1 +nlayers = 3 + +model = @compact(w1=Dense(n_in, 128), + w2=[Dense(128, 128) for i in 1:nlayers], + w3=Dense(128, n_out), + act=relu) do x + embed = act(w1(x)) + for w in w2 + embed = act(w(embed)) + end + out = w3(embed) + return out +end +``` + + +``` +@compact( + w1 = Dense(1 => 128), # 256 parameters + w2 = NamedTuple( + 1 = Dense(128 => 128), # 16_512 parameters + 2 = Dense(128 => 128), # 16_512 parameters + 3 = Dense(128 => 128), # 16_512 parameters + ), + w3 = Dense(128 => 1), # 129 parameters + act = relu, +) do x + embed = act(w1(x)) + for w = w2 + embed = act(w(embed)) + end + out = w3(embed) + return out +end # Total: 49_921 parameters, + # plus 1 states. +``` + + +We can initialize the model and train it with the same code as before! + + +```julia +ps, st = Lux.setup(Xoshiro(0), model) + +model(randn(n_in, 32), ps, st) # 1×32 Matrix as output. + +x_data = collect(-2.0f0:0.1f0:2.0f0)' +y_data = 2 .* x_data .- x_data .^ 3 +st_opt = Optimisers.setup(Adam(), ps) + +for epoch in 1:1000 + global st # Put this in a function in real use-cases + (loss, st), pb = Zygote.pullback(ps) do p + y, st_ = model(x_data, p, st) + return sum(abs2, y .- y_data), st_ + end + gs = only(pb((one(loss), nothing))) + epoch % 100 == 1 && println("Epoch: $(epoch) | Loss: $(loss)") + Optimisers.update!(st_opt, ps, gs) +end +``` + + +``` +Epoch: 1 | Loss: 84.32512 +Epoch: 101 | Loss: 0.08861052 +Epoch: 201 | Loss: 0.007037298 +Epoch: 301 | Loss: 0.005391656 +Epoch: 401 | Loss: 0.014058021 +Epoch: 501 | Loss: 0.0022117028 +Epoch: 601 | Loss: 0.0015865607 +Epoch: 701 | Loss: 0.21984956 +Epoch: 801 | Loss: 0.00019668281 +Epoch: 901 | Loss: 0.0018975141 +``` + + + + +## Additional Packages + + +`LuxDL` hosts various packages that provide additional functionality for Lux.jl. All packages mentioned in this documentation are available via the Julia General Registry. + + +You can install all those packages via `import Pkg; Pkg.add()`. + + + + +## GPU Support + + +GPU Support for Lux.jl requires loading additional packages: + + + * [`LuxCUDA.jl`](https://github.com/LuxDL/LuxCUDA.jl) for CUDA support. + * [`LuxAMDGPU.jl`](https://github.com/LuxDL/LuxAMDGPU.jl) for AMDGPU support. + * [`Metal.jl`](https://github.com/JuliaGPU/Metal.jl) for Apple Metal support. + diff --git a/previews/PR474/introduction/overview.md b/previews/PR474/introduction/overview.md new file mode 100644 index 000000000..e8d92dbeb --- /dev/null +++ b/previews/PR474/introduction/overview.md @@ -0,0 +1,41 @@ + + + +# Why we wrote Lux? + + +Julia already has quite a few well established Neural Network Frameworks – [Flux](https://fluxml.ai/) & [KNet](https://denizyuret.github.io/Knet.jl/latest/). However, certain design elements – **Coupled Model and Parameters** & **Internal Mutations** – associated with these frameworks make them less compiler and user friendly. Making changes to address these problems in the respective frameworks would be too disruptive for users. Here comes in `Lux`: a neural network framework built completely using pure functions to make it both compiler and autodiff friendly. + + + + +## Design Principles + + + * **Layers must be immutable** – cannot store any parameter/state but rather store the information to construct them + * **Layers are pure functions** + * **Layers return a Tuple containing the result and the updated state** + * **Given same inputs the outputs must be same** – yes this must hold true even for stochastic functions. Randomness must be controlled using `rng`s passed in the state. + * **Easily extensible** + + + + +## Why use Lux over Flux? + + + * **Neural Networks for SciML**: For SciML Applications (Neural ODEs, Deep Equilibrium Models) solvers typically expect a monolithic parameter vector. Flux enables this via its `destructure` mechanism, but `destructure` comes with various [edge cases and limitations](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.destructure). Lux forces users to make an explicit distinction between state variables and parameter variables to avoid these issues. Also, it comes battery-included for distributed training using [FluxMPI.jl](https://github.com/avik-pal/FluxMPI.jl) *(I know :P the naming)* + * **Sensible display of Custom Layers** – Ever wanted to see Pytorch like Network printouts or wondered how to extend the pretty printing of Flux's layers? Lux handles all of that by default. + * **Truly immutable models** - No *unexpected internal mutations* since all layers are implemented as pure functions. All layers are also *deterministic* given the parameters and state: if a layer is supposed to be stochastic (say `Dropout`), the state must contain a seed which is then updated after the function call. + * **Easy Parameter Manipulation** – By separating parameter data and layer structures, Lux makes implementing `WeightNorm`, `SpectralNorm`, etc. downright trivial. Without this separation, it is much harder to pass such parameters around without mutations which AD systems don't like. + + + + +## Why not use Lux? + + + * **Small Neural Networks on CPU** – Lux is developed for training large neural networks. For smaller architectures, we recommend using [SimpleChains.jl](https://github.com/PumasAI/SimpleChains.jl). + * **Lux won't magically speed up your code (yet)** – Lux shares the same backend with Flux and so if your primary desire to shift is driven by performance, you will be disappointed. + * **XLA Support** – Lux doesn't compile to XLA which means no TPU support unfortunately. + diff --git a/previews/PR474/introduction/resources.md b/previews/PR474/introduction/resources.md new file mode 100644 index 000000000..e202904e5 --- /dev/null +++ b/previews/PR474/introduction/resources.md @@ -0,0 +1,19 @@ + + + +# Resources to Get Started + + + * Go through the [Quickstart Example](index). + * Read the introductory tutorials on [Julia](https://jump.dev/JuMP.jl/stable/tutorials/getting_started/getting_started_with_julia/#Getting-started-with-Julia) and [Lux](../tutorials/). + * Go through the examples sorted based on their complexity in the documentation. + + +:::warning HAVE MORE QUESTIONS? + + +For usage related questions, please use [Github Discussions](https://github.com/avik-pal/Lux.jl/discussions) or [JuliaLang Discourse (machine learning domain)](https://discourse.julialang.org/c/domain/ml/) which allows questions and answers to be indexed. To report bugs use [github issues](https://github.com/LuxDL/Lux.jl/issues) or even better send in a [pull request](https://github.com/LuxDL/Lux.jl/pulls). + + +::: + diff --git a/previews/PR474/manual/debugging.md b/previews/PR474/manual/debugging.md new file mode 100644 index 000000000..8e378333a --- /dev/null +++ b/previews/PR474/manual/debugging.md @@ -0,0 +1,339 @@ + + + +# Debugging Lux Models + + +Debugging DNNs can be very painful. Especially with the gigantic stacktraces for Lux, it is even harder to pin-point to which particular layer errored out. This page describes some useful tools that ship with Lux, that can help you debug your models. + + +:::tip TL;DR + + +Simply wrap your model with `Lux.Experimental.@debug`!! + + +::: + + +:::warning DON'T FORGET + + +Remember to use the non Debug mode model after you finish debugging. Debug mode models are way slower. + + +::: + + +Let us construct a model which has an obviously incorrect dimension. In this example, you will see how easy it is to pin-point the problematic layer. + + + + +## Incorrect Model Specification: Dimension Mismatch Problems + + +```julia +using Lux, Random + +model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), + BatchNorm(1); disable_optimizations=true) + +model_debug = Lux.Experimental.@debug_mode model +``` + + +``` +Chain( + layer_1 = DebugLayer( + layer = Dense(1 => 16, relu), # 32 parameters + ), + layer_2 = Chain( + layer_1 = DebugLayer( + layer = Dense(16 => 3), # 51 parameters + ), + layer_2 = DebugLayer( + layer = Dense(1 => 1), # 2 parameters + ), + ), + layer_3 = DebugLayer( + layer = BatchNorm(1, affine=true, track_stats=true), # 2 parameters, plus 3 + ), +) # Total: 87 parameters, + # plus 3 states. +``` + + +Note that we can use the parameters and states for `model` itself in `model_debug`, no need to make any changes. If you ran the original model this is the kind of error you would see: + + +```julia +rng = Xoshiro(0) + +ps, st = Lux.setup(rng, model) +x = randn(rng, Float32, 1, 1) + +try + model(x, ps, st) +catch e + println(e) +end +``` + + +``` +DimensionMismatch("A has dimensions (1,1) but B has dimensions (3,1)") +``` + + +Ofcourse, this error will come with a detailed stacktrace, but it is still not very useful. Now let's try using the debug mode model: + + +```julia +try + model_debug(x, ps, st) +catch e + println(e) +end +``` + + +``` +[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 1) +[ Info: Running Layer: Dense(1 => 16, relu) at location model.layers.layer_1! +[ Info: Output Type: Matrix{Float32} | Output Structure: (16, 1) +[ Info: Input Type: Matrix{Float32} | Input Structure: (16, 1) +[ Info: Running Layer: Dense(16 => 3) at location model.layers.layer_2.layers.layer_1! +[ Info: Output Type: Matrix{Float32} | Output Structure: (3, 1) +[ Info: Input Type: Matrix{Float32} | Input Structure: (3, 1) +[ Info: Running Layer: Dense(1 => 1) at location model.layers.layer_2.layers.layer_2! +┌ Error: Layer Dense(1 => 1) failed!! This layer is present at location model.layers.layer_2.layers.layer_2 +└ @ Lux.Experimental /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/src/contrib/debug.jl:113 +DimensionMismatch("A has dimensions (1,1) but B has dimensions (3,1)") +``` + + +See now we know that `model.layers.layer_2.layers.layer_2` is the problematic layer. Let us fix that layer and see what happens: + + +```julia +model = Chain(Dense(1 => 16, relu), + Chain(Dense(16 => 3), // [!code --] + Chain(Dense(16 => 1), // [!code ++] + Dense(1 => 1)), + BatchNorm(1); disable_optimizations=true) +``` + + +```julia +model_fixed = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), + BatchNorm(1); disable_optimizations=true) + +ps, st = Lux.setup(rng, model_fixed) + +model_fixed(x, ps, st) +``` + + +``` +(Float32[0.0;;], (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = (running_mean = Float32[-0.01397949], running_var = Float32[NaN], training = Val{true}()))) +``` + + +Voila!! We have tracked down and fixed the problem. + + + + +## Tracking down NaNs + + +Have you encountered those pesky little NaNs in your training? They are very hard to track down. We will create an artificially simulate NaNs in our model and see how we can track the offending layer. + + +We can set `nan_check` to `:forward`, `:backward` or `:both` to check for NaNs in the debug model. (or even disable it by setting it to `:none`) + + +```julia +model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), + BatchNorm(1); disable_optimizations=true) + +ps, st = Lux.setup(rng, model) + +model_debug = Lux.Experimental.@debug_mode model nan_check=:both +``` + + +``` +Chain( + layer_1 = DebugLayer( + layer = Dense(1 => 16, relu), # 32 parameters + ), + layer_2 = Chain( + layer_1 = DebugLayer( + layer = Dense(16 => 1), # 17 parameters + ), + layer_2 = DebugLayer( + layer = Dense(1 => 1), # 2 parameters + ), + ), + layer_3 = DebugLayer( + layer = BatchNorm(1, affine=true, track_stats=true), # 2 parameters, plus 3 + ), +) # Total: 53 parameters, + # plus 3 states. +``` + + +Let us set a value in the parameter to `NaN`: + + +```julia +ps.layer_2.layer_2.weight[1, 1] = NaN +``` + + +``` +NaN +``` + + +Now let us run the model + + +```julia +model(x, ps, st) +``` + + +``` +(Float32[NaN;;], (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = (running_mean = Float32[NaN], running_var = Float32[NaN], training = Val{true}()))) +``` + + +Ah as expected our output is `NaN`. But is is not very clear how to track where the first `NaN` occurred. Let's run the debug model and check: + + +```julia +try + model_debug(x, ps, st) +catch e + println(e) +end +``` + + +``` +[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 1) +[ Info: Running Layer: Dense(1 => 16, relu) at location model.layers.layer_1! +[ Info: Output Type: Matrix{Float32} | Output Structure: (16, 1) +[ Info: Input Type: Matrix{Float32} | Input Structure: (16, 1) +[ Info: Running Layer: Dense(16 => 1) at location model.layers.layer_2.layers.layer_1! +[ Info: Output Type: Matrix{Float32} | Output Structure: (1, 1) +[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 1) +[ Info: Running Layer: Dense(1 => 1) at location model.layers.layer_2.layers.layer_2! +DomainError((weight = Float32[NaN;;], bias = Float32[0.0;;]), "NaNs detected in parameters of layer Dense(1 => 1) at location model.layers.layer_2.layers.layer_2") +``` + + +And we have figured it out! The first `NaN` occurred in the parameters of `model.layers.layer_2.layers.layer_2`! But what if NaN occurs in the reverse pass! Let us define a custom layer and introduce a fake NaN in the backward pass. + + +```julia +using ChainRulesCore, Zygote + +const CRC = ChainRulesCore + +offending_layer(x) = 2 .* x +``` + + +``` +offending_layer (generic function with 1 method) +``` + + +```julia +model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), + BatchNorm(1); disable_optimizations=true) + +ps, st = Lux.setup(rng, model) + +model(x, ps, st) +``` + + +``` +(Float32[0.0;;], (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = (running_mean = Float32[-0.092828535], running_var = Float32[NaN], training = Val{true}()))) +``` + + +Let us define a custom backward pass to introduce some NaNs: + + +```julia +function CRC.rrule(::typeof(offending_layer), x) + y = offending_layer(x) + function ∇offending_layer(Δ) + Δ[1] = NaN + return NoTangent(), Δ + end + return y, ∇offending_layer +end +``` + + +Let us compute the gradient of the layer now: + + +```julia +Zygote.gradient(ps -> sum(first(model(x, ps, st))), ps) +``` + + +``` +((layer_1 = (weight = Float32[0.0; NaN; … ; NaN; 0.0;;], bias = Float32[0.0; NaN; … ; NaN; 0.0;;]), layer_2 = (layer_1 = (weight = Float32[NaN NaN … NaN NaN], bias = Float32[NaN;;]), layer_2 = nothing), layer_3 = (scale = Float32[0.0], bias = Fill(1.0f0, 1))),) +``` + + +Oh no!! A `NaN` is present in the gradient of `ps`. Let us run the debug model and see where the `NaN` occurred: + + +```julia +model_debug = Lux.Experimental.@debug_mode model nan_check=:both + +try + Zygote.gradient(ps -> sum(first(model_debug(x, ps, st))), ps) +catch e + println(e) +end +``` + + +``` +[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 1) +[ Info: Running Layer: Dense(1 => 16, relu) at location model.layers.layer_1! +[ Info: Output Type: Matrix{Float32} | Output Structure: (16, 1) +[ Info: Input Type: Matrix{Float32} | Input Structure: (16, 1) +[ Info: Running Layer: Dense(16 => 1) at location model.layers.layer_2.layers.layer_1! +[ Info: Output Type: Matrix{Float32} | Output Structure: (1, 1) +[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 1) +[ Info: Running Layer: WrappedFunction(offending_layer) at location model.layers.layer_2.layers.layer_2! +[ Info: Output Type: Matrix{Float32} | Output Structure: (1, 1) +[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 1) +[ Info: Running Layer: BatchNorm(1, affine=true, track_stats=true) at location model.layers.layer_3! +[ Info: Output Type: Matrix{Float32} | Output Structure: (1, 1) +DomainError(Float32[NaN;;], "NaNs detected in pullback output for WrappedFunction(offending_layer) at location model.layers.layer_2.layers.layer_2!") +``` + + +And there you go our debug layer prints that the problem is in `WrappedFunction(offending_layer) at location model.layers.layer_2.layers.layer_2`! Once we fix the pullback of the layer, we will fix the NaNs. + + + + +## Conclusion + + +In this manual section, we have discussed tracking down errors in Lux models. We have covered tracking incorrect model specifications and NaNs in forward and backward passes. However, remember that this is an **Experimental** feature, and there might be edge cases that don't work correctly. If you find any such cases, please open an issue on GitHub! + diff --git a/previews/PR474/manual/dispatch_custom_input.md b/previews/PR474/manual/dispatch_custom_input.md new file mode 100644 index 000000000..96e6c2dcb --- /dev/null +++ b/previews/PR474/manual/dispatch_custom_input.md @@ -0,0 +1,162 @@ + + + +# Dispatching on Custom Input Types + + + + +## Which function should participate in dispatch? + + + * Defining a dispatch on `(::Layer)(x::MyInputType, ps, st::NamedTuple)` is inconvenient, since it requires the user to define a new method for every layer type. + * `(::AbstractExplicitLayer)(x::MyInputType, ps, st::NamedTuple)` doesn't work. + * Instead, we need to define the dispatch on `Lux.apply(::AbstractExplicitLayer, x::MyInputType, ps, st::NamedTuple)`. + + + + +## Concrete Example + + +Consider [Neural ODEs](https://implicit-layers-tutorial.org/neural_odes/). In these models, often time we want to every iteration of the neural network to take the current time as input. Here, we won't go through implementing an entire Neural ODE model. Instead we will define a time dependent version of [`Chain`](../api/Lux/layers#Lux.Chain). + + + + +### Time-Dependent Chain Implementation + + +```julia +using Lux, Random + +struct TDChain{L <: NamedTuple} <: Lux.AbstractExplicitContainerLayer{(:layers,)} + layers::L +end + +function (l::TDChain)((x, t)::Tuple, ps, st::NamedTuple) + # Concatenate along the 2nd last dimension + sz = ntuple(i -> i == ndims(x) - 1 ? 1 : size(x, i), ndims(x)) + t_ = ones(eltype(x), sz) .* t # Needs to be modified for GPU + for name in keys(l.layers) + x, st_ = Lux.apply(getfield(l.layers, name), cat(x, t_; dims=ndims(x) - 1), + getfield(ps, name), getfield(st, name)) + st = merge(st, NamedTuple{(name,)}((st_,))) + end + return x, st +end + +model = Chain(Dense(3, 4), TDChain((; d1=Dense(5, 4), d2=Dense(5, 4))), Dense(4, 1)) +``` + + +``` +Chain( + layer_1 = Dense(3 => 4), # 16 parameters + layer_2 = TDChain( + layers = NamedTuple( + d1 = Dense(5 => 4), # 24 parameters + d2 = Dense(5 => 4), # 24 parameters + ), + ), + layer_3 = Dense(4 => 1), # 5 parameters +) # Total: 69 parameters, + # plus 0 states. +``` + + + + +### Running the TDChain + + +```julia +rng = MersenneTwister(0) +ps, st = Lux.setup(rng, model) +x = randn(rng, Float32, 3, 2) + +try + model(x, ps, st) +catch e + Base.showerror(stdout, e) +end +``` + + +``` +MethodError: no method matching (::Main.TDChain{@NamedTuple{d1::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, d2::Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}})(::Matrix{Float32}, ::@NamedTuple{d1::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}, d2::@NamedTuple{weight::Matrix{Float32}, bias::Matrix{Float32}}}, ::@NamedTuple{d1::@NamedTuple{}, d2::@NamedTuple{}}) + +Closest candidates are: + (::Main.TDChain)(!Matched::Tuple, ::Any, ::NamedTuple) + @ Main dispatch_custom_input.md:29 +``` + + + + +### Writing the Correct Dispatch Rules + + + * Create a Custom Layer storing the time. + + +```julia +struct ArrayAndTime{A <: AbstractArray, T <: Real} + array::A + time::T +end +``` + + + * Define the dispatch on `Lux.apply(::AbstractExplicitLayer, x::ArrayAndTime, ps, st::NamedTuple)`. + + +```julia +function Lux.apply(layer::Lux.AbstractExplicitLayer, x::ArrayAndTime, ps, st::NamedTuple) + y, st = layer(x.array, ps, st) + return ArrayAndTime(y, x.time), st +end + +function Lux.apply(layer::TDChain, x::ArrayAndTime, ps, st::NamedTuple) + y, st = layer((x.array, x.time), ps, st) + return ArrayAndTime(y, x.time), st +end +``` + + + * Run the model. + + +```julia +xt = ArrayAndTime(x, 10.0f0) + +model(xt, ps, st)[1] +``` + + +``` +Main.ArrayAndTime{Matrix{Float32}, Float32}(Float32[4.8016562 5.174927], 10.0f0) +``` + + + + +### Using the Same Input for Non-TD Models + + +Writing proper dispatch means we can simply replace the `TDChain` with a `Chain` (of course with dimension corrections) and the pipeline still works. + + +```julia +model = Chain(Dense(3, 4), Chain((; d1=Dense(4, 4), d2=Dense(4, 4))), Dense(4, 1)) + +ps, st = Lux.setup(rng, model) + +model(xt, ps, st)[1] +``` + + +``` +Main.ArrayAndTime{Matrix{Float32}, Float32}(Float32[-0.08124366 -1.1121564], 10.0f0) +``` + diff --git a/previews/PR474/manual/freezing_model_parameters.md b/previews/PR474/manual/freezing_model_parameters.md new file mode 100644 index 000000000..8f870c49b --- /dev/null +++ b/previews/PR474/manual/freezing_model_parameters.md @@ -0,0 +1,163 @@ + + + +# Freezing Model Parameters + + +::: warning + + +API for freezing parameters should be considered experimental at this point. + + +::: + + +In this manual entry, we will go over how to freeze certain parameters in a model. + + + + +## Freezing Layers of a Particular Kind + + +To freeze a particular kind of layer, let's say [`Dense`](../api/Lux/layers#Lux.Dense) in the following example. We can use [`Lux.Experimental.@layer_map`](../api/Lux/contrib#Lux.Experimental.@layer_map) and freeze layers if they are of type `Dense`. + + +```julia +using Lux, Random + +rng = Random.default_rng() +Random.seed!(rng, 0) + +model = Chain(Dense(3, 4), Chain(Dense(4, 4), Dropout(0.5f0), BatchNorm(4)), + Dense(4, 1); disable_optimizations=true) + +ps, st = Lux.setup(rng, model) + +x = randn(rng, Float32, 3, 2) + +model(x, ps, st) + +function freeze_dense(d::Lux.Dense, ps, st, ::String) + return Lux.freeze(d, ps, st, (:weight, :bias)) +end +freeze_dense(l, ps, st, name) = (l, ps, st) + +model_frozen, ps_frozen, st_frozen = Lux.Experimental.@layer_map freeze_dense model ps st + +model_frozen(x, ps_frozen, st_frozen) +``` + + +``` +(Float32[1.7641534 -1.7641534], (layer_1 = (frozen_params = (weight = Float32[-0.026350189 -0.5554656 -0.35653266; -0.17461072 0.6705545 0.29924855; -0.8935247 -0.42453378 -0.3020351; -0.7988979 -0.7666331 -0.7104237], bias = Float32[0.0; 0.0; 0.0; 0.0;;]), states = NamedTuple()), layer_2 = (layer_1 = (frozen_params = (weight = Float32[-0.47289538 -0.680748 0.1764085 0.34383082; 0.42747158 -0.13819042 -0.109261915 -0.6143286; -0.35790488 -0.20881107 0.70390546 0.48137343; 0.82561636 0.38187847 0.05779423 -0.35181466], bias = Float32[0.0; 0.0; 0.0; 0.0;;]), states = NamedTuple()), layer_2 = (rng = Random.Xoshiro(0x87711e5ce1a49ffe, 0xa210b60ecab6b8c5, 0x436c749552fc8172, 0x03e9c7d813a9f096, 0x22a21880af5dc689), training = Val{true}()), layer_3 = (running_mean = Float32[-0.04517859, 0.03484953, -0.004917746, 0.0074841487], running_var = Float32[0.94082206, 0.92428976, 0.90048367, 0.90112025], training = Val{true}())), layer_3 = (frozen_params = (weight = Float32[0.3981135 0.45468387 -0.07694905 0.8353388], bias = Float32[0.0;;]), states = NamedTuple()))) +``` + + + + +## Freezing by Layer Name + + +When the function in `layer_map` is called, the 4th argument is the name of the layer. For example, if you want to freeze the 1st layer inside the inner Chain. The name for this would be `.layer_2.layer_1`. + + +:::code-group + + +```julia [Freezing by Layer Name] + +function freeze_by_name(d, ps, st, name::String) + if name == "model.layer_2.layer_1" + return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) + else + return d, ps, st + end +end + +``` + + +```julia [Freezing by Layer Type] + +function freeze_dense(d::Dense, ps, st, ::String) + return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) +end +freeze_dense(l, ps, st, _) = (l, ps, st) + +``` + + +::: + + + + +## Freezing Part of the Parameters + + +Instead of freezing all the parameters, we can simply specify `(:weight,)` to freeze only the `weight` parameter while training the `bias` parameter. + + +::: code-group + + +```julia [Freezing Some Parameters of a Layer] + +function freeze_by_name(d, ps, st, name::String) + if name == "model.layer_2.layer_1" + return Lux.freeze(d, ps, st, (:weight,)) + else + return d, ps, st + end +end + +``` + + +```julia [Freezing All Parameters of a Layer] + +function freeze_by_name(d, ps, st, name::String) + if name == "model.layer_2.layer_1" + return Lux.freeze(d, ps, st, (:weight, :bias)) + else + return d, ps, st + end +end + +``` + + +::: + + + + +## Freezing Part of a Chain + + +Starting `v0.4.22`, we can directly index into a `Chain`. So freezing a part of a `Chain`, is extremely easy. + + +```julia +using Lux, Random + +rng = Random.default_rng() +Random.seed!(rng, 0) + +model = Chain(Dense(3, 4), Dense(4, 4), Dropout(0.5f0), BatchNorm(4), Dense(4, 1)) + +model_frozen = Chain(model[1:2], Lux.freeze(model[3:4]), model[5]) +ps, st = Lux.setup(rng, model_frozen) + +x = randn(rng, Float32, 3, 2) + +model_frozen(x, ps, st) +``` + + +``` +(Float32[1.7641534 -1.7641534], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (frozen_params = (layer_3 = NamedTuple(), layer_4 = (scale = Float32[1.0, 1.0, 1.0, 1.0], bias = Float32[0.0, 0.0, 0.0, 0.0])), states = (layer_3 = (rng = Random.Xoshiro(0x87711e5ce1a49ffe, 0xa210b60ecab6b8c5, 0x436c749552fc8172, 0x03e9c7d813a9f096, 0x22a21880af5dc689), training = Val{true}()), layer_4 = (running_mean = Float32[-0.04517859, 0.03484953, -0.004917746, 0.0074841487], running_var = Float32[0.94082206, 0.92428976, 0.90048367, 0.90112025], training = Val{true}()))), layer_4 = NamedTuple())) +``` + diff --git a/previews/PR474/manual/gpu_management.md b/previews/PR474/manual/gpu_management.md new file mode 100644 index 000000000..244029eb5 --- /dev/null +++ b/previews/PR474/manual/gpu_management.md @@ -0,0 +1,139 @@ + + + +# GPU Management + + +::: info + + +Starting from `v0.5`, Lux has transitioned to a new GPU management system. The old system using `cpu` and `gpu` functions is still in place but will be removed in `v0.6`. Using the old functions might lead to performance regressions if used inside performance critical code. + + +::: + + +`Lux.jl` can handle multiple GPU backends. Currently, the following backends are supported: + + +```julia +using Lux, LuxCUDA, LuxAMDGPU # Important to load trigger packages + +supported_gpu_backends() +``` + + +``` +("CUDA", "AMDGPU", "Metal") +``` + + +::: danger Metal Support + + +Support for Metal GPUs should be considered extremely experimental at this point. + + +::: + + + + +## Automatic Backend Management (Recommended Approach) + + +Automatic Backend Management is done by two simple functions: `cpu_device` and `gpu_device`. + + +1. [`cpu_device`](../api/Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.cpu_device): This is a simple function and just returns a `LuxCPUDevice` object. + + +```julia +cdev = cpu_device() +``` + + +``` +(::LuxCPUDevice) (generic function with 5 methods) +``` + + +```julia +x_cpu = randn(Float32, 3, 2) +``` + + +``` +3×2 Matrix{Float32}: + 0.433884 0.229779 + -0.459193 -1.95972 + -0.541064 -1.40102 +``` + + +2. [`gpu_device`](../api/Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.gpu_device): This function performs automatic GPU device selection and returns an object. + + 1. If no GPU is available, it returns a `LuxCPUDevice` object. + 2. If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use `Lux.gpu_backend!()`. + + 1. If the trigger package corresponding to the device is not loaded, then a warning is displayed. + 2. If no LocalPreferences file is present, then the first working GPU with loaded trigger package is used. + + +```julia +gdev = gpu_device() + +x_gpu = x_cpu |> gdev +``` + + +``` +3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}: + 0.433884 0.229779 + -0.459193 -1.95972 + -0.541064 -1.40102 +``` + + +```julia +(x_gpu |> cdev) ≈ x_cpu +``` + + +``` +true +``` + + + + +## Manual Backend Management + + +Automatic Device Selection can be circumvented by directly using `LuxCPUDevice` and `AbstractLuxGPUDevice` objects. + + +```julia +cdev = LuxCPUDevice() + +x_cpu = randn(Float32, 3, 2) + +if LuxCUDA.functional() + gdev = LuxCUDADevice() + x_gpu = x_cpu |> gdev +elseif LuxAMDGPU.functional() + gdev = LuxAMDGPUDevice() + x_gpu = x_cpu |> gdev +else + @info "No GPU is available. Using CPU." + x_gpu = x_cpu +end + +(x_gpu |> cdev) ≈ x_cpu +``` + + +``` +true +``` + diff --git a/previews/PR474/manual/interface.md b/previews/PR474/manual/interface.md new file mode 100644 index 000000000..8648568bd --- /dev/null +++ b/previews/PR474/manual/interface.md @@ -0,0 +1,297 @@ + + + +# Lux Interface + + +:::tip + + +If you just want to define compatibility with Lux without actually using any of the other functionality provided by Lux (like layers), it is recommended to depend on `LuxCore.jl` instead of `Lux.jl`. `LuxCore.jl` is a significantly lighter dependency. + + +::: + + +First let's set the expectations straight. + + + * Do you **have to** follow the interface? *No*. + * **Should you** follow it? *Probably yes*. + * **Why?** It provides the ability for frameworks built on top of Lux to be cross compatible. Additionally, any new functionality built into Lux, will just work for your framework. + + +::: warning + + +The interface is optional for frameworks being developed independent of Lux. All functionality in the core library (and officially supported ones) **must** adhere to the interface + + +::: + + + + +## Layer Interface + + + + +### Singular Layer + + +If the layer doesn't contain any other Lux layer, then it is a `Singular Layer`. This means it should optionally subtype `Lux.AbstractExplicitLayer` but mandatorily define all the necessary functions mentioned in the docstrings. Consider a simplified version of [`Dense`](../api/Lux/layers#Lux.Dense) called `Linear`. + + +First, setup the architectural details for this layer. Note, that the architecture doesn't contain any mutable structure like arrays. When in doubt, remember, once constructed a model architecture cannot change. + + +::: tip + + +For people coming from Flux.jl background this might be weird. We recommend checking out [the Flux to Lux migration guide](migrate_from_flux) first before proceeding. + + +::: + + +```julia +using Lux, Random + +struct Linear{F1, F2} <: Lux.AbstractExplicitLayer + in_dims::Int + out_dims::Int + init_weight::F1 + init_bias::F2 +end + +function Linear(in_dims::Int, out_dims::Int; init_weight=Lux.glorot_uniform, + init_bias=Lux.zeros32) + return Linear{typeof(init_weight), typeof(init_bias)}(in_dims, out_dims, init_weight, + init_bias) +end + +l = Linear(2, 4) +``` + + +``` +Linear() +``` + + +Next, we need to implement functions which return the parameters and states for the layer. In case of `Linear`, the parameters are `weight` and `bias` while the states are empty. States become important when defining layers like [`BatchNorm`](../api/Lux/layers#Lux.BatchNorm), [`WeightNorm`](../api/Lux/layers#Lux.WeightNorm), etc. The recommended data structure for returning parameters is a NamedTuple, though anything satisfying the [Parameter Interface](#parameter-interface) is valid. + + +```julia +function Lux.initialparameters(rng::AbstractRNG, l::Linear) + return (weight=l.init_weight(rng, l.out_dims, l.in_dims), + bias=l.init_bias(rng, l.out_dims, 1)) +end + +Lux.initialstates(::AbstractRNG, ::Linear) = NamedTuple() +``` + + +You could also implement `Lux.parameterlength` and `Lux.statelength` to prevent wasteful reconstruction of the parameters and states. + + +```julia +# This works +println("Parameter Length: ", Lux.parameterlength(l), "; State Length: ", + Lux.statelength(l)) + +# But still recommened to define these +Lux.parameterlength(l::Linear) = l.out_dims * l.in_dims + l.out_dims + +Lux.statelength(::Linear) = 0 +``` + + +``` +Parameter Length: 12; State Length: 0 +``` + + +::: tip + + +You might notice that we don't pass in a `PRNG` for these functions. If your parameter length and/or state length depend on a random number generator, you should think **really hard** about what you are trying to do and why. + + +::: + + +Now, we need to define how the layer works. For this you make your layer a function with exactly 3 arguments – `x` the input, `ps` the parameters, and `st` the states. This function must return two things – `y` the output, and `st_new` the updated state. + + +```julia +function (l::Linear)(x::AbstractMatrix, ps, st::NamedTuple) + y = ps.weight * x .+ ps.bias + return y, st +end +``` + + +Finally, let's run this layer. If you have made this far into the documentation, we don't feel you need a refresher on that. + + +```julia +rng = Random.default_rng() +Random.seed!(rng, 0) + +ps, st = Lux.setup(rng, l) + +println("Parameter Length: ", Lux.parameterlength(l), "; State Length: ", + Lux.statelength(l)) + +x = randn(rng, Float32, 2, 1) + +Lux.apply(l, x, ps, st) # or `l(x, ps, st)` +``` + + +``` +(Float32[-0.15276335; 0.45325348; 1.0207279; 0.78226817;;], NamedTuple()) +``` + + + + +### Container Layer + + +If your layer comprises of other Lux layers, then it is a `Container Layer`. Note that you could treat it as a [`Singular Layer`](#singular-layer), and it is still fine. FWIW, if you cannot subtype your layer with `Lux.AbstractExplicitContainerLayer` then you should go down the [`Singular Layer`](#singular-layer) route. But subtyping allows us to bypass some of these common definitions. Let us now define a layer, which is basically a composition of two linear layers. + + +```julia +struct ComposedLinear{L1, L2} <: Lux.AbstractExplicitContainerLayer{(:linear_1, :linear_2)} + linear_1::L1 + linear_2::L2 +end + +function (cl::ComposedLinear)(x::AbstractMatrix, ps, st::NamedTuple) + # To access the parameters and states for `linear_1` we do `ps.linear_1` and + # `st.linear_1`. Similarly for `linear_2` + y, st_l1 = cl.linear_1(x, ps.linear_1, st.linear_1) + y, st_l2 = cl.linear_2(y, ps.linear_2, st.linear_2) + # Finally, we need to return the new state which has the exact structure as `st` + return y, (linear_1 = st_l1, linear_2 = st_l2) +end +``` + + +Here, you will notice we have passed `(:linear_1, :linear_2)` to the supertype. It essentially informs the type that, `.linear_1` and `.linear_2` are Lux layers and we need to construct parameters and states for those. Let's construct these and see: + + +```julia +model = ComposedLinear(Linear(2, 4), Linear(4, 2)) +display(model) + +ps, st = Lux.setup(rng, model) + +println("Parameters: ", ps) +println("States: ", st) + +println("Parameter Length: ", Lux.parameterlength(model), "; State Length: ", + Lux.statelength(model)) + +x = randn(rng, Float32, 2, 1) + +Lux.apply(model, x, ps, st) # or `model(x, ps, st)` +``` + + +``` +(Float32[1.3410565; 0.78000563;;], (linear_1 = NamedTuple(), linear_2 = NamedTuple())) +``` + + + + +## Parameter Interface + + +We accept any parameter type as long as we can fetch the parameters using `getproperty(obj, :parameter_name)`. This allows us to simultaneously support `NamedTuple`s and `ComponentArray`s. Let us go through a concrete example of what it means. Consider [`Dense`](../api/Lux/layers#Lux.Dense) which expects two parameters named `weight` and `bias`. + + +::: info + + +If you are defining your own parameter type, it is your responsibility to make sure that it works with the AutoDiff System you are using. + + +::: + + +```julia +using Lux, Random + +d = Dense(2, 3) +rng = Random.default_rng() +Random.seed!(rng, 0) + +ps_default, st = Lux.setup(rng, d) + +x = randn(rng, Float32, 2, 1) + +println("Result with `NamedTuple` parameters: ", first(d(x, ps_default, st))) +``` + + +``` +Result with `NamedTuple` parameters: Float32[1.135916; 0.7668784; -1.0876652;;] +``` + + +Let, us define a custom parameter type with fields `myweight` and `mybias` but if we try to access `weight` we get back `myweight`, similar for `bias`. + + +::: warning + + +This is for demonstrative purposes, don't try this at home! + + +::: + + +```julia +struct DenseLayerParameters{W, B} + myweight::W + mybias::B +end + +function Base.getproperty(ps::DenseLayerParameters, x::Symbol) + if x == :weight + return getfield(ps, :myweight) + elseif x == :bias + return getfield(ps, :mybias) + end + return getfield(ps, x) +end + +ps = DenseLayerParameters(ps_default.weight, ps_default.bias) + +println("Result with `DenseLayerParameters` parameters: ", first(d(x, ps, st))) +``` + + +``` +Result with `DenseLayerParameters` parameters: Float32[1.135916; 0.7668784; -1.0876652;;] +``` + + +The takeaway from this shouldn't be – *lets define weird parameter types*. Simply because you can do weird things like this doesn't mean you should, since it only leads to bugs. + + +Instead this shows the flexibility you have for how your parameters can be structured. + + + + +## State Interface + + +States are always type constrained to be `NamedTuple`. The structure of the input state **must** match that of the output state, i.e. `keys(st_in) == keys(st_out)`. This doesn't imply that types of the input and output state match. To generate efficient code, we often do dispatch on the state, for example, [`Dropout`](../api/Lux/layers#Lux.Dropout), [`BatchNorm`](../api/Lux/layers#Lux.BatchNorm), etc. + diff --git a/previews/PR474/manual/migrate_from_flux.md b/previews/PR474/manual/migrate_from_flux.md new file mode 100644 index 000000000..dfdf56efa --- /dev/null +++ b/previews/PR474/manual/migrate_from_flux.md @@ -0,0 +1,176 @@ + + + +# Migrating from Flux to Lux + + +For the core library layers like [`Dense`](../api/Lux/layers#Lux.Dense), [`Conv`](../api/Lux/layers#Lux.Conv), etc. we have intentionally kept the API very similar to Flux. In most cases, replacing `using Flux` with `using Lux` should be enough to get you started. We cover the additional changes that you will have to make in the following example. + + +:::code-group + + +```julia{1,7,9,11} [Lux] +using Lux, Random, NNlib, Zygote + +model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2)) +rng = Random.default_rng() +x = randn(rng, Float32, 2, 4) + +ps, st = Lux.setup(rng, model) + +model(x, ps, st) + +gradient(ps -> sum(first(model(x, ps, st))), ps) +``` + + +```julia [Flux] +using Flux, Random, NNlib, Zygote + +model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2)) +rng = Random.default_rng() +x = randn(rng, Float32, 2, 4) + + + +model(x) + +gradient(model -> sum(model(x)), model) +``` + + +::: + + + + +## Implementing Custom Layers + + +Flux and Lux operate under extremely different design philosophies regarding how layers should be implemented. A summary of the differences would be: + + + * Flux stores everything in a single struct and relies on `Functors.@functor` and `Flux.trainable` to distinguish between trainable and non-trainable parameters. + * Lux relies on the user to define `Lux.initialparameters` and `Lux.initialstates` to distinguish between trainable parameters (called "parameters") and non-trainable parameters (called "states"). Additionally, Lux layers define the model architecture, hence device transfer utilities like [`gpu_device`](../api/Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.gpu_device), [`cpu_device`](../api/Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.cpu_device), etc. cannot be applied on Lux layers, instead they need to be applied on the parameters and states. + + +Let's work through a concrete example to demonstrate this. We will implement a very simple layer that computes $A \times B \times x$ where $A$ is not trainable and $B$ is trainable. + + +:::code-group + + +```julia [Lux] +using Lux, Random, NNlib, Zygote + +struct LuxLinear <: Lux.AbstractExplicitLayer + init_A + init_B +end + +function LuxLinear(A::AbstractArray, B::AbstractArray) + # Storing Arrays or any mutable structure inside a Lux Layer is not recommended + # instead we will convert this to a function to perform lazy initialization + return LuxLinear(() -> copy(A), () -> copy(B)) +end + +# `B` is a parameter +Lux.initialparameters(::AbstractRNG, layer::LuxLinear) = (B=layer.init_B(),) + +# `A` is a state +Lux.initialstates(::AbstractRNG, layer::LuxLinear) = (A=layer.init_A(),) + +(l::LuxLinear)(x, ps, st) = st.A * ps.B * x, st +``` + + +```julia [Flux] +using Flux, Random, NNlib, Zygote, Optimisers + +struct FluxLinear + A + B +end + + + + + + + +# `A` is not trainable +Optimisers.trainable(f::FluxLinear) = (B=f.B,) + +# Needed so that both `A` and `B` can be transfered between devices +Flux.@functor FluxLinear + +(l::FluxLinear)(x) = l.A * l.B * x +``` + + +::: + + +Now let us run the model. + + +:::code-group + + +```julia{2,5,7,9} [Lux] +rng = Random.default_rng() +model = LuxLinear(randn(rng, 2, 4), randn(rng, 4, 2)) +x = randn(rng, 2, 1) + +ps, st = Lux.setup(rng, model) + +model(x, ps, st) + +gradient(ps -> sum(first(model(x, ps, st))), ps) +``` + + +```julia [Flux] +rng = Random.default_rng() +model = FluxLinear(randn(rng, 2, 4), randn(rng, 4, 2)) +x = randn(rng, 2, 1) + + + +model(x) + +gradient(model -> sum(model(x)), model) +``` + + +::: + + +To reiterate some important points: + + + * Don't store mutables like Arrays inside a Lux Layer. + * Parameters and States should be constructured inside the respective `initial*` functions. + + + + +## Certain Important Implementation Details + + + + +### Training/Inference Mode + + +Flux supports a mode called `:auto` which automatically decides if the user is training the model or running inference. This is the default mode for `Flux.BatchNorm`, `Flux.GroupNorm`, `Flux.Dropout`, etc. Lux doesn't support this mode (specifically to keep code simple and do exactly what the user wants), hence our default mode is `training`. This can be changed using `Lux.testmode`. + + + + +## Can we still use Flux Layers? + + +If you have `Flux` loaded in your code, you can use the function [`Lux.transform`](../api/Lux/flux_to_lux#Lux.transform) to automatically convert your model to `Lux`. Note that in case a native Lux counterpart isn't available, we fallback to using `Optimisers.destructure`. + diff --git a/previews/PR474/manual/weight_initializers.md b/previews/PR474/manual/weight_initializers.md new file mode 100644 index 000000000..15f7fdf05 --- /dev/null +++ b/previews/PR474/manual/weight_initializers.md @@ -0,0 +1,142 @@ + + + +# Initializing Weights + + +`WeightInitializers.jl` provides common weight initialization schemes for deep learning models. + + +```julia +using WeightInitializers, Random + +# Fixing rng +rng = Random.MersenneTwister(42) +``` + + +``` +Random.MersenneTwister(42) +``` + + +```julia +# Explicit rng call +weights = kaiming_normal(rng, 2, 5) +``` + + +``` +2×5 Matrix{Float32}: + -0.351662 0.0171745 1.12442 -0.296372 -1.67094 + -0.281053 -0.18941 -0.724099 0.0987538 0.634549 +``` + + +```julia +# Default rng call +weights = kaiming_normal(2, 5) +``` + + +``` +2×5 Matrix{Float32}: + -0.227513 -0.265372 0.265788 1.29955 -0.192836 + 0.687611 0.454679 -0.433656 0.20548 0.292002 +``` + + +```julia +# Passing kwargs (if needed) with explicit rng call +weights_cl = kaiming_normal(rng; gain=1.0) +weights = weights_cl(2, 5) +``` + + +``` +2×5 Matrix{Float64}: + 0.484056 0.231723 0.164379 0.306147 0.18365 + 0.0836414 0.666965 -0.396323 -0.711329 -0.382971 +``` + + +```julia +# Passing kwargs (if needed) with default rng call +weights_cl = kaiming_normal(; gain=1.0) +weights = weights_cl(2, 5) +``` + + +``` +2×5 Matrix{Float64}: + -0.160876 -0.187646 0.18794 0.918918 -0.136356 + 0.486214 0.321506 -0.306641 0.145296 0.206476 +``` + + +To generate weights directly on GPU, pass in a `CUDA.RNG`. (Note that this is currently implemented only for NVIDIA GPUs) + + +```julia +using LuxCUDA + +weights = kaiming_normal(CUDA.default_rng(), 2, 5) +``` + + +``` +2×5 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}: + -0.631371 -0.373283 0.474398 -0.14462 0.790993 + 0.758608 -1.03315 -0.194163 -0.273136 0.885854 +``` + + +You can also generate Complex Numbers: + + +```julia +weights = kaiming_normal(CUDA.default_rng(), ComplexF32, 2, 5) +``` + + +``` +2×5 CuArray{ComplexF32, 2, CUDA.Mem.DeviceBuffer}: + -0.0365544-0.551206im -0.0220487+0.0649364im … 0.0018996+0.282734im + 0.0246539+0.203026im -0.682422-0.0257335im -0.686471-0.470935im +``` + + + + +## Quick examples + + +The package is meant to be working with deep learning libraries such as (F)Lux. All the methods take as input the chosen `rng` type and the dimension for the array. + + +```julia +weights = init(rng, dims...) +``` + + +The `rng` is optional, if not specified a default one will be used. + + +```julia +weights = init(dims...) +``` + + +If there is the need to use keyword arguments the methods can be called with just the `rng` (optionally) and the keywords to get in return a function behaving like the two examples above. + + +```julia +weights_init = init(rng; kwargs...) +weights = weights_init(rng, dims...) + +# Or + +weights_init = init(; kwargs...) +weights = weights_init(dims...) +``` + diff --git a/previews/PR474/siteinfo.js b/previews/PR474/siteinfo.js new file mode 100644 index 000000000..9b4476e9c --- /dev/null +++ b/previews/PR474/siteinfo.js @@ -0,0 +1 @@ +var DOCUMENTER_CURRENT_VERSION = "previews/PR474"; diff --git a/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-24.png b/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-24.png new file mode 100644 index 000000000..48e43e6dd Binary files /dev/null and b/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-24.png differ diff --git a/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-35.png b/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-35.png new file mode 100644 index 000000000..ab2b34728 Binary files /dev/null and b/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-35.png differ diff --git a/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-48.png b/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-48.png new file mode 100644 index 000000000..0fbe34ac6 Binary files /dev/null and b/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-48.png differ diff --git a/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-50.png b/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-50.png new file mode 100644 index 000000000..931076e60 Binary files /dev/null and b/previews/PR474/tutorials/advanced/1_GravitationalWaveForm-50.png differ diff --git a/previews/PR474/tutorials/advanced/1_GravitationalWaveForm.md b/previews/PR474/tutorials/advanced/1_GravitationalWaveForm.md new file mode 100644 index 000000000..30773a7b4 --- /dev/null +++ b/previews/PR474/tutorials/advanced/1_GravitationalWaveForm.md @@ -0,0 +1,547 @@ + + + + + +# Training a Neural ODE to Model Gravitational Waveforms + + +This code is adapted from [Astroinformatics/ScientificMachineLearning](https://github.com/Astroinformatics/ScientificMachineLearning/blob/c93aac3a460d70b4cce98836b677fd9b732e94b7/neuralode_gw.ipynb) + + +The code has been minimally adapted from [Keith et. al. 2021](https://arxiv.org/abs/2102.12695) which originally used Flux.jl + + + + +## Package Imports + + +```julia +using Lux, ComponentArrays, LineSearches, LuxAMDGPU, LuxCUDA, OrdinaryDiffEq, + Optimization, OptimizationOptimJL, Random, SciMLSensitivity +using CairoMakie, MakiePublication +CUDA.allowscalar(false) +``` + + + + +## Define some Utility Functions + + +::: tip + + +This section can be skipped. It defines functions to simulate the model, however, from a scientific machine learning perspective, isn't super relevant. + + +::: + + +We need a very crude 2-body path. Assume the 1-body motion is a newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3) + + +```julia +function one2two(path, m₁, m₂) + M = m₁ + m₂ + r₁ = m₂ / M .* path + r₂ = -m₁ / M .* path + return r₁, r₂ +end +``` + + +``` +one2two (generic function with 1 method) +``` + + +Next we define a function to perform the change of variables: $(\chi(t),\phi(t)) \mapsto (x(t),y(t))$ + + +```julia +@views function soln2orbit(soln, model_params=nothing) + @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4" + + if size(soln, 1) == 2 + χ = soln[1, :] + ϕ = soln[2, :] + + @assert length(model_params)==3 "model_params must have length 3 when size(soln,2) = 2" + p, M, e = model_params + else + χ = soln[1, :] + ϕ = soln[2, :] + p = soln[3, :] + e = soln[4, :] + end + + r = p ./ (1 .+ e .* cos.(χ)) + x = r .* cos.(ϕ) + y = r .* sin.(ϕ) + + orbit = vcat(x', y') + return orbit +end +``` + + +``` +soln2orbit (generic function with 2 methods) +``` + + +This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0 + + +```julia +function d_dt(v::AbstractVector, dt) + a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3] + b = (v[3:end] .- v[1:(end - 2)]) / 2 + c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2] + return [a; b; c] / dt +end +``` + + +``` +d_dt (generic function with 1 method) +``` + + +This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0 + + +```julia +function d2_dt2(v::AbstractVector, dt) + a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4] + b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end] + c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3] + return [a; b; c] / (dt^2) +end +``` + + +``` +d2_dt2 (generic function with 1 method) +``` + + +Now we define a function to compute the trace-free moment tensor from the orbit + + +```julia +function orbit2tensor(orbit, component, mass=1) + x = orbit[1, :] + y = orbit[2, :] + + Ixx = x .^ 2 + Iyy = y .^ 2 + Ixy = x .* y + trace = Ixx .+ Iyy + + if component[1] == 1 && component[2] == 1 + tmp = Ixx .- trace ./ 3 + elseif component[1] == 2 && component[2] == 2 + tmp = Iyy .- trace ./ 3 + else + tmp = Ixy + end + + return mass .* tmp +end + +function h_22_quadrupole_components(dt, orbit, component, mass=1) + mtensor = orbit2tensor(orbit, component, mass) + mtensor_ddot = d2_dt2(mtensor, dt) + return 2 * mtensor_ddot +end + +function h_22_quadrupole(dt, orbit, mass=1) + h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass) + h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass) + h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass) + return h11, h12, h22 +end + +function h_22_strain_one_body(dt::T, orbit) where {T} + h11, h12, h22 = h_22_quadrupole(dt, orbit) + + h₊ = h11 - h22 + hₓ = T(2) * h12 + + scaling_const = √(T(π) / 5) + return scaling_const * h₊, -scaling_const * hₓ +end + +function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2) + h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1) + h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2) + h11 = h11_1 + h11_2 + h12 = h12_1 + h12_2 + h22 = h22_1 + h22_2 + return h11, h12, h22 +end + +function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T} + # compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2 + + @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity" + + h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2) + + h₊ = h11 - h22 + hₓ = T(2) * h12 + + scaling_const = √(T(π) / 5) + return scaling_const * h₊, -scaling_const * hₓ +end + +function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T} + @assert mass_ratio≤1 "mass_ratio must be <= 1" + @assert mass_ratio≥0 "mass_ratio must be non-negative" + + orbit = soln2orbit(soln, model_params) + if mass_ratio > 0 + m₂ = inv(T(1) + mass_ratio) + m₁ = mass_ratio * m₂ + + orbit₁, orbit₂ = one2two(orbit, m₁, m₂) + waveform = h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2) + else + waveform = h_22_strain_one_body(dt, orbit) + end + return waveform +end +``` + + +``` +compute_waveform (generic function with 2 methods) +``` + + + + +## Simulating the True Model + + +`RelativisticOrbitModel` defines system of odes which describes motion of point like particle in schwarzschild background, uses + + +$$ +u[1] = \chi +$$ + + +$$ +u[2] = \phi +$$ + + +where, $p$, $M$, and $e$ are constants + + +```julia +function RelativisticOrbitModel(u, (p, M, e), t) + χ, ϕ = u + + numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2 + denom = sqrt((p - 2)^2 - 4 * e^2) + + χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom) + ϕ̇ = numer / (M * (p^(3 / 2)) * denom) + + return [χ̇, ϕ̇] +end + +mass_ratio = 0.0 # test particle +u0 = Float64[π, 0.0] # initial conditions +datasize = 250 +tspan = (0.0f0, 6.0f4) # timespace for GW waveform +tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep +dt_data = tsteps[2] - tsteps[1] +dt = 100.0 +const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e +``` + + +Let's simulate the true model and plot the results using `OrdinaryDiffEq.jl` + + +```julia +prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params) +soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false)) +waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params)) + +fig = with_theme(theme_web()) do + fig = Figure() + ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform") + + l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75) + s = scatter!(ax, tsteps, waveform; markershape=:circle, markeralpha=0.25, alpha=0.5) + + axislegend(ax, [[l, s]], ["Waveform Data"]) + + return fig +end +``` + + +![](1_GravitationalWaveForm-24.png) + + + + +## Defiing a Neural Network Model + + +Next, we define the neural network model that takes 1 input (time) and has two outputs. We'll make a function `ODE_model` that takes the initial conditions, neural network parameters and a time as inputs and returns the derivatives. + + +It is typically never recommended to use globals but incase you do use them, make sure to mark them as `const`. + + +We will deviate from the standard Neural Network initialization and use `WeightInitializers.jl`, + + +```julia +const nn = Chain(Base.Fix1(broadcast, cos), + Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4)), + Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4)), + Dense(32 => 2; init_weight=truncated_normal(; std=1e-4))) +ps, st = Lux.setup(MersenneTwister(), nn) +``` + + +``` +((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-3.7534275f-5; 9.587293f-5; -0.00020307681; 2.7310172f-5; 4.4802062f-5; -9.994351f-6; -2.9701485f-5; -8.724159f-5; -0.00010881431; -9.6940414f-5; -4.375306f-6; -0.000110954315; 7.387934f-5; -2.761209f-5; 0.00011043136; -0.00015784007; -1.1113592f-5; 6.388433f-5; -0.00014325912; -4.5682064f-5; 7.1291244f-5; -5.9531292f-5; 9.476896f-5; 0.0001966393; 5.541161f-6; -0.0001697429; -0.00022925164; -4.5089957f-5; -3.81126f-5; 1.2938472f-5; -3.3968197f-5; -7.0052585f-5;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = Float32[6.761105f-5 -0.000113901915 8.964591f-6 -0.00016116614 0.00013880015 -0.00013879138 -3.554107f-5 6.059696f-5 -0.00024709766 -0.00014070104 -8.979313f-5 -1.4413713f-5 -9.936666f-5 -9.672333f-5 -7.263863f-5 5.8567486f-5 -2.8290946f-5 -6.729768f-5 -8.4595085f-5 -3.252451f-5 1.4338686f-5 -1.1622614f-5 2.6112843f-5 -0.00012211224 -3.0702406f-5 -0.00019466369 -8.263866f-5 3.46534f-5 -7.399809f-6 -0.00013036455 0.00013129906 -0.00010237028; -0.00019548097 2.6739317f-5 0.00011322644 1.7645018f-5 6.908265f-5 -0.0002837229 8.1247075f-5 0.00011636492 9.217035f-5 -3.6968945f-6 -5.6754816f-5 8.0914666f-5 -9.703333f-6 8.010135f-5 2.8915832f-5 4.7849626f-5 7.918131f-5 -0.00017794414 0.00012016945 0.000110826164 0.0001733935 -7.307795f-6 -4.5871067f-5 0.00021210038 0.00013693089 1.3263791f-5 0.00014216152 3.7456666f-5 -7.8651516f-5 -5.6846176f-5 1.9444027f-5 9.6039425f-5; -2.5985977f-5 0.00010558587 2.654133f-6 -3.118091f-5 8.02511f-5 -3.631959f-5 0.00014321286 2.0998217f-5 -2.6606238f-5 5.4152366f-5 -4.246607f-6 1.0811776f-5 -8.9289614f-7 -9.6984455f-5 0.00010148475 8.050631f-5 -0.00011599769 1.6067202f-7 -8.368298f-5 7.586512f-5 -0.00021594488 6.960428f-5 -7.644235f-5 4.5867848f-5 4.6067846f-5 -0.00015519273 4.200408f-5 -0.00015001697 -5.4830394f-5 -9.761585f-5 5.1612f-5 -4.7108184f-5; -4.9468103f-6 6.100778f-5 0.0002248902 2.2481365f-5 -5.9164282f-5 -0.00016892838 2.9950816f-5 -5.500141f-5 -6.84253f-5 -5.7105f-5 -1.5854755f-5 8.887545f-5 -0.00014754037 -0.00017234501 -2.125293f-5 3.7674336f-6 -5.5900364f-5 0.00016250162 -6.581971f-5 3.491841f-5 5.611013f-5 -1.3222975f-5 -5.9341968f-5 -0.00012054232 4.1396073f-5 7.07468f-5 0.00019461343 0.00013953108 -6.8784415f-5 -0.00014103648 6.211121f-5 -1.0979025f-5; -0.00016269046 -4.038683f-5 5.1514868f-5 -8.924401f-5 3.9883653f-5 -3.9835548f-5 -6.908117f-5 2.3048007f-5 -3.6763555f-5 -1.3746815f-5 -7.5990734f-5 3.230138f-5 -6.4971544f-5 -0.00015814423 -0.0001851258 -1.599811f-5 9.555079f-5 -6.453685f-5 -1.2443973f-5 7.117298f-5 -3.4086265f-6 -0.00019522739 2.5452942f-5 -5.498493f-7 -2.3053162f-5 0.00014901232 5.5567467f-5 -9.562596f-5 -1.8108049f-5 0.0001726859 2.7772296f-6 0.00010719969; 1.2146686f-5 -2.4243906f-5 -8.9340254f-5 9.877082f-5 -1.631962f-5 0.000116152594 0.00017797515 9.5130046f-5 0.00010179499 0.00010881183 0.00011092269 -2.900193f-5 2.9529894f-5 -0.000108132146 0.00013373278 -0.00012725154 -9.695451f-7 -9.763457f-5 0.00014091612 3.0294985f-5 -9.552999f-5 8.482511f-5 5.660853f-5 0.00012428859 -6.127698f-5 -6.0485017f-5 8.647213f-5 -7.488917f-5 -1.0679927f-5 -4.7259382f-5 6.438435f-5 -0.00018757928; -6.566568f-5 -3.3646113f-5 8.671624f-5 -4.065746f-5 -7.4712134f-6 8.047148f-5 9.7479475f-5 -9.7461554f-5 2.1854203f-5 -5.3928775f-6 1.7218212f-5 -1.1766306f-5 4.188885f-6 -3.3229382f-5 3.140321f-5 0.00018610239 1.4785526f-6 5.615193f-5 9.095312f-5 1.2546787f-5 8.373901f-5 1.3567898f-5 5.2952953f-5 0.00016013715 5.1444626f-5 -0.00013369217 -0.00011386128 -1.4321387f-5 -3.3755292f-5 0.00011034097 3.9902552f-5 -9.5185685f-5; -0.00016736936 -0.00014818972 2.6946713f-5 8.3882805f-5 4.4185267f-6 -1.32613795f-5 -1.4832713f-5 -4.6952686f-5 -6.563721f-5 -2.8683447f-5 -3.1700787f-5 4.560241f-5 -6.560982f-5 0.00014710317 0.000106385 -0.0001424362 7.224775f-6 -4.8524853f-5 -9.551184f-5 0.00019952007 -4.1362448f-5 -0.0001566128 -0.00017194882 0.0001511447 -4.503181f-5 -4.3834312f-5 5.597769f-5 -0.00016817503 -0.00020149168 0.00011917359 0.00012282288 8.879565f-5; 9.005764f-5 -0.000104279636 -0.00020861726 0.00011654795 -0.0001878115 9.170598f-5 -2.2800568f-5 7.3588126f-5 6.131765f-5 8.612518f-5 0.00011414634 -0.00015244156 8.77575f-5 0.0001032412 -7.488193f-5 9.047691f-5 -3.3409837f-5 3.9945815f-5 0.00012330696 -5.4873184f-5 5.871341f-5 -4.309193f-5 -5.435596f-5 1.6422406f-5 -8.2792885f-6 -6.490122f-5 3.1177082f-5 0.00011712389 9.975942f-6 -3.468329f-5 0.00012418012 9.5352916f-5; 1.8307346f-5 -7.4760974f-6 -9.509076f-6 0.00010720775 0.00022891539 -1.7801009f-5 0.0001134608 -4.8712274f-5 -8.340814f-5 2.086001f-5 -0.00015277331 -0.00010062263 -7.459868f-5 -4.6025452f-5 -0.00016344074 -2.9167175f-5 4.6709905f-5 8.269951f-5 -9.479973f-5 -9.0533125f-5 -6.2209045f-5 -2.27565f-5 -1.470693f-6 2.0270943f-5 -1.768672f-5 0.00014186426 -4.9610902f-5 -4.6614947f-5 6.381325f-5 4.5036508f-5 8.218586f-5 8.606068f-6; -2.3976118f-5 7.4354284f-5 -7.8912344f-5 -2.2845867f-5 -0.00019194542 7.125854f-5 -8.285065f-5 -2.3212544f-5 0.00017206058 -1.9278854f-5 -0.00011775113 9.305179f-5 -0.00017833042 -2.1000875f-5 -6.481427f-5 -5.560593f-5 3.16148f-5 0.00011925995 9.085574f-5 -2.7252207f-5 0.00013033982 -8.462668f-5 5.3033484f-5 7.3393144f-6 8.881129f-6 3.311339f-5 -1.4659923f-5 -5.8569614f-5 3.501024f-5 0.00013967518 0.0002375264 -4.2071126f-5; 2.1844135f-5 -0.0001447924 3.0929736f-5 3.3578483f-5 -5.007242f-5 -1.8200426f-5 9.270273f-5 8.062867f-5 5.490854f-6 -1.2040791f-5 -6.637152f-6 0.00017624113 -8.4803054f-5 9.32618f-5 0.00010831205 -0.00020070799 -9.2799666f-5 -0.00022796886 4.608795f-5 -7.291324f-5 2.0374477f-5 0.00010940203 -9.017371f-5 -0.000115712806 -9.2113565f-5 -1.2122845f-5 4.679737f-5 -9.038651f-5 -7.778076f-5 -3.6846304f-5 -4.9037953f-5 -9.655655f-5; -8.5796347f-7 -0.00020660971 -0.00011996289 -0.00012564834 2.5726813f-5 0.00018287414 0.00020730491 -7.0170813f-6 -1.36177805f-5 -5.657367f-5 -1.5690106f-5 -2.353855f-5 4.8210826f-5 1.9343943f-5 0.00013850816 0.000117500465 -0.00015189555 5.523898f-5 -8.703245f-5 -2.4861638f-5 2.641342f-5 -4.6409652f-5 7.861565f-5 -6.354364f-5 8.94254f-6 2.8756533f-5 -0.00010554956 2.4856276f-5 6.8546986f-5 -0.000104392624 -8.586132f-5 8.1873535f-5; 0.00021355918 0.00013151007 -0.00018185088 0.000100912715 -4.2591953f-5 8.0937134f-5 -0.00013910097 0.00012739946 -4.298582f-5 0.00023429835 0.00012029626 -0.00014725365 9.190304f-5 5.502032f-5 0.00010128233 -5.751056f-5 4.921677f-5 8.696627f-5 0.0001405872 -9.268556f-5 -2.125892f-5 -6.541491f-5 4.4591292f-5 6.439612f-5 -2.1727277f-5 -5.35682f-5 9.923696f-5 -0.00016687224 -0.00012193904 2.6427591f-5 -7.4114646f-6 1.6011429f-5; 0.00020759272 -8.2139f-5 4.552265f-5 -7.2354105f-6 0.00016132413 1.07424985f-5 2.4565574f-5 -0.00011407443 -0.000112021735 -6.37444f-5 3.0787458f-5 4.0570947f-5 1.1581369f-5 0.00015548791 -2.2667089f-5 -7.496505f-5 6.0816263f-5 8.652482f-5 3.8184073f-5 -0.00013766196 -1.2482308f-5 -0.00026049276 -1.1788181f-5 0.00010245156 -2.3329272f-5 -0.00011171854 2.6835365f-5 7.772137f-5 -0.000118710464 -0.000101394326 0.0001917172 -5.8030328f-5; -3.5241617f-5 -3.7997113f-5 -0.00015438788 -7.031436f-5 -7.369307f-6 3.530067f-5 7.216243f-5 -4.750082f-5 -0.00011452152 0.00024547402 7.06018f-5 -0.00015834396 -0.000105093444 -1.9506488f-5 0.00010290264 4.567291f-5 -0.00017570857 -1.9434161f-5 7.900747f-5 7.070915f-5 -0.00019769961 0.00012592082 -1.4219325f-5 -4.7706137f-5 -5.0051727f-5 3.8403174f-5 -4.3218497f-5 0.00017298516 9.7437005f-5 -1.1763959f-5 -0.00011181385 -0.0002501864; -5.525854f-5 1.8087301f-5 -9.965503f-5 -0.000109753404 -7.624324f-5 0.00013782918 -3.9840983f-5 3.073326f-5 -8.786233f-5 -1.8344175f-5 -4.20142f-5 0.0001769575 0.0001972421 6.1934355f-5 -4.823372f-5 4.8301703f-5 -0.00011174749 7.0448f-5 -5.9433896f-5 -6.2935724f-5 -7.4250083f-6 -3.0982716f-5 6.89731f-5 -2.1363889f-5 8.189096f-5 -7.059768f-6 0.00014192617 -0.00013615322 0.00010379173 4.200875f-5 0.00011425142 2.156763f-6; -9.368468f-5 -0.00014656328 -0.00013627374 0.00011150491 0.00014804116 1.9151714f-5 7.6833385f-5 4.686646f-5 9.218986f-5 -0.00015922933 -1.187716f-5 -3.6512647f-6 -9.1516406f-5 -9.945131f-6 -6.3246815f-5 9.124299f-5 6.3962696f-5 -6.716385f-5 3.68331f-5 5.637649f-5 6.986443f-5 8.9714056f-5 7.4751088f-6 0.000116717434 0.00014105352 -0.00011666501 0.00010081233 1.9929017f-5 -2.9267432f-5 8.425493f-5 -0.00014090218 -0.00012716456; 6.2265186f-5 -0.00015502388 0.00015362691 -7.906807f-5 -7.572171f-5 -0.00010579486 -5.746414f-6 0.00010010981 0.00018997275 1.9668566f-6 0.00010910515 0.00024385356 1.9836181f-5 -0.00021522121 -0.0001690883 -9.892282f-5 -0.00011787428 -6.1388455f-5 0.00015061039 4.0925847f-6 -0.00012566289 3.1887073f-5 -3.9325565f-5 -5.642694f-5 4.3136824f-5 2.1605616f-5 0.00013621447 -3.8678358f-5 0.0001232448 0.00031467344 3.186947f-5 6.352093f-5; 5.8051828f-5 0.00011755239 -8.630124f-5 1.7596247f-5 -9.446116f-5 -3.5806017f-5 -2.9850791f-5 6.1862585f-5 1.7284208f-5 8.436351f-5 -2.8148258f-7 -7.1910625f-5 1.1572703f-6 3.5754536f-5 -0.00011033168 3.3378274f-5 -5.1591793f-5 2.3716677f-5 -5.0667397f-5 -0.00017250435 -0.00011599581 3.9717244f-5 9.23878f-6 7.617037f-5 0.00032626803 0.00014796156 6.440196f-5 8.487586f-5 0.00010801957 -0.00014962886 -8.753361f-5 4.6541543f-5; -3.4864992f-5 0.00014489937 0.00017197175 0.00016956887 3.615395f-5 -3.607434f-5 -0.0001745673 -0.0002024635 0.00010612504 1.6606882f-5 -0.0001632898 -6.497571f-5 -2.9604376f-5 4.569699f-5 -0.00015154689 0.00016237854 1.2504684f-5 -6.640947f-5 8.5253996f-5 6.167072f-5 0.00012972928 -0.00016208076 3.521787f-5 -0.00021100706 4.670387f-6 7.0341564f-5 6.497756f-5 -3.695794f-5 4.6165223f-5 -8.646171f-5 2.5158322f-5 1.3180173f-5; 0.00010783719 -8.6098163f-7 -5.5659653f-5 0.0001018247 -6.748683f-5 -0.00013413666 6.288587f-5 3.7961687f-5 4.7855658f-5 6.634427f-6 -7.5017975f-5 6.090577f-5 6.986077f-5 -0.00019676608 -1.8029443f-5 9.899767f-5 -2.642988f-5 -4.6305846f-5 -3.6691756f-5 4.0621373f-5 5.4369924f-5 -4.23593f-6 -7.384837f-5 -7.061963f-5 -0.00012377794 0.00016422497 1.1223372f-5 5.21631f-5 -2.6631067f-5 0.00016381251 1.4793771f-5 8.166553f-5; 7.2092895f-5 -0.00018327968 -6.4941436f-5 0.00026116535 0.00013604014 2.7541879f-5 -3.083775f-5 -9.2774164f-5 0.00011816725 -5.3335218f-5 8.6334025f-5 0.00013644819 1.873067f-5 4.0324947f-5 -5.828331f-5 2.2703083f-5 -0.00021211569 0.00019917052 9.597362f-5 7.825745f-6 -6.0460334f-5 -0.00010037813 -0.00015556904 0.00011028232 -2.8457815f-5 -7.170699f-6 -4.365136f-5 6.142104f-5 -3.320914f-5 4.3353175f-5 -2.4388624f-5 0.00011741803; 6.885784f-5 -6.577162f-5 -0.00016516853 0.00015954807 -0.00020923551 -7.437655f-5 -1.8963085f-5 -7.688618f-5 0.00012938863 0.00018441291 4.5995825f-5 0.00013424407 -7.607619f-5 9.690385f-5 4.8989194f-5 0.00015271863 -5.1646348f-5 -0.00011597087 1.3516338f-6 8.9734254f-5 -1.4386297f-5 5.6206092f-5 -2.8031967f-5 9.230711f-5 -9.703975f-5 -5.702648f-5 -1.4909304f-5 0.00014520748 3.7420257f-5 -0.00015659408 3.7704387f-5 6.270015f-5; -0.00010970163 -3.1929749f-6 1.8567363f-5 8.471964f-5 9.443171f-5 -3.1292482f-5 -4.110973f-5 -6.364768f-7 9.27982f-5 -1.860573f-5 -0.00015082447 4.751226f-5 0.00022276596 -7.8337354f-5 -0.00015602796 2.79549f-5 -1.5144451f-5 7.955865f-6 -4.231437f-5 3.3795848f-5 -0.00015319066 -6.3052f-5 7.164251f-5 2.675313f-5 -5.048239f-5 -0.0002325602 -2.4063773f-5 4.3692744f-5 -2.6701227f-5 4.475248f-5 -3.3108496f-5 2.3644827f-5; 0.00012448835 -0.00015694596 -8.1813305f-5 7.419979f-5 3.8067716f-5 -9.7702214f-5 -0.00019920115 -0.00015588086 -4.4741184f-5 3.4373068f-5 -4.900018f-5 -0.0002527356 -9.57983f-5 1.4170487f-5 -4.2961175f-5 4.1811014f-5 0.00010371536 0.00010284704 0.00017538131 -0.00011807678 0.0001001662 -0.00013130253 0.00012288256 0.00012931979 0.00025033974 -9.51475f-6 -0.00021311734 -7.137152f-5 -9.0671056f-5 -5.295804f-5 -9.285775f-6 -2.0475134f-5; 4.218741f-5 -2.687855f-5 -0.00012609824 -9.970099f-5 -3.627047f-5 0.000105723666 -1.2999311f-5 0.0001025369 0.00016081029 -6.87301f-5 -0.00012459588 0.00017570004 -0.00014162355 5.8432124f-6 9.426718f-5 -3.3713f-5 -0.00010652919 -0.00012488446 0.00011061982 0.00011289351 9.947804f-6 0.00012597135 9.078777f-5 1.1399733f-7 2.1645259f-5 0.00021435291 -0.00012574777 6.754362f-5 3.791349f-5 -0.00013065062 7.3103415f-6 -1.1499429f-5; -6.103377f-5 -9.1203445f-5 6.045847f-5 7.3517396f-5 -6.528617f-5 2.2441374f-5 9.033684f-5 -5.904216f-5 -5.765813f-5 9.106579f-5 -3.1181382f-5 -0.00018200254 -6.733254f-5 0.00014207218 4.8103553f-5 -8.362824f-5 1.1691237f-5 0.00019550722 2.7018545f-5 -7.682146f-5 -0.00013286094 -0.00018943097 -9.899181f-5 0.00010191948 -6.756122f-5 6.154821f-5 6.147019f-5 0.00015826142 3.8505103f-5 -3.350976f-5 0.00020528957 0.00021367623; 5.3222087f-5 0.00018807554 -2.4660878f-6 1.9308689f-5 2.1195687f-5 -6.249333f-5 -2.5295285f-5 0.0001111647 -4.100128f-5 1.4142901f-5 -9.86955f-6 -0.00016014921 -0.0001420917 6.437262f-5 7.957116f-5 -7.50171f-5 9.5053634f-5 0.00023381195 -7.245417f-6 1.5909922f-5 -4.081567f-5 -4.8053607f-5 0.0001450196 0.000107618 -5.3550437f-5 8.168464f-5 -8.7340915f-5 -7.210665f-5 -5.252501f-5 -0.0001519233 -6.6706525f-5 -5.0764076f-5; -0.00010892246 7.490791f-5 0.00010133239 -2.7112463f-5 -2.7261603f-5 0.000107227184 -5.403155f-5 6.6421686f-5 -0.00010644121 -4.2165495f-5 0.00021430063 -1.4637f-6 9.771257f-5 0.00012267297 -3.076957f-5 -1.1607873f-5 -8.360679f-5 1.152256f-5 4.2852993f-5 -4.9386294f-5 -4.0691608f-5 -6.368005f-5 -1.1976172f-5 -0.00010869059 3.5530757f-5 5.0454222f-5 0.00010188413 6.841393f-5 -5.0241244f-5 0.00015245397 -7.895698f-5 6.0219583f-5; -5.303627f-5 -1.8745694f-5 1.6348693f-5 5.7746787f-5 -7.224715f-6 3.8287644f-6 7.0626804f-5 -2.2848812f-5 -4.1804033f-5 -8.662874f-6 -0.00015293852 2.7515593f-5 0.00012639737 4.146572f-5 -2.8078339f-5 -0.00015498533 0.00014600993 -4.7832706f-5 0.00017865714 -4.1875068f-5 5.7105175f-5 0.00022340933 0.00010140867 -3.0475547f-5 0.00013225879 5.2447023f-5 0.00012626311 -1.7756929f-5 -5.7917423f-5 -0.00010622354 -5.0128692f-5 -0.00011080242; 5.464636f-5 5.7637066f-5 0.00012274504 0.0002755762 -6.648314f-5 -0.00014265781 -7.2401606f-5 8.115587f-6 9.974197f-5 7.923513f-5 0.00019312838 6.184099f-5 8.9193956f-5 7.953824f-5 -5.2709638f-5 -1.6961545f-5 5.664859f-5 0.00022280957 6.961478f-5 0.00018504179 -2.9028608f-5 -5.85627f-5 0.000118810516 -9.372176f-5 9.829584f-6 0.00019844779 -7.7063116f-5 0.00011331276 2.7304155f-5 2.1979633f-5 -5.3504955f-5 2.5904916f-5], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = Float32[3.262928f-5 5.0363935f-5 7.980772f-5 -1.8525381f-5 -3.1385637f-6 -0.00017683936 0.00016484373 4.6893074f-5 8.1241466f-5 0.00010770842 -0.0001578596 0.00012844763 5.7723384f-5 0.00017318777 -0.00015215481 -0.00010982934 -2.3794748f-5 1.9830482f-5 3.390214f-5 3.7392983f-5 8.496735f-5 -2.1927228f-5 -3.2255768f-5 -6.152622f-6 -0.00019298044 -1.5268086f-5 2.7910408f-5 -3.7722322f-5 6.3804064f-6 0.00012652401 1.818955f-5 -3.9074104f-5; -3.0793883f-5 -7.38439f-5 5.234796f-5 -8.966429f-5 0.00010205379 -6.106228f-5 -0.00010570496 -2.4574814f-5 -7.383016f-5 0.00016824267 5.7393412f-5 -7.1197464f-5 3.4921013f-5 7.726707f-5 -5.264235f-6 -7.681448f-5 -8.816292f-5 2.2420016f-5 7.506003f-5 -0.00017724569 -6.6819106f-5 -0.00023924508 -0.00013460488 8.047367f-5 4.374219f-5 -1.9749243f-5 0.00010067088 0.00016900497 6.784794f-5 -0.000101395926 -7.199668f-5 -7.28125f-7], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple())) +``` + + +Similar to most DL frameworks, Lux defaults to using `Float32`, however, in this case we need Float64 + + +```julia +const params = ComponentArray{Float64}(ps) +``` + + +``` +ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-3.7534275179496035e-5; 9.58729287958704e-5; -0.00020307680824771523; 2.7310172299621627e-5; 4.480206189327873e-5; -9.99435087578604e-6; -2.9701484891120344e-5; -8.724159124540165e-5; -0.00010881431080633774; -9.69404136412777e-5; -4.375306161819026e-6; -0.00011095431545982137; 7.387933874269947e-5; -2.761208997981157e-5; 0.00011043136328225955; -0.00015784006973262876; -1.111359233618714e-5; 6.388432666426525e-5; -0.00014325912343338132; -4.568206350086257e-5; 7.129124423954636e-5; -5.9531292208703235e-5; 9.476896229898557e-5; 0.00019663930288515985; 5.541161044675391e-6; -0.0001697428961051628; -0.0002292516437591985; -4.508995698415674e-5; -3.81126010324806e-5; 1.2938471627421677e-5; -3.396819738554768e-5; -7.005258521530777e-5;;], bias = [0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_3 = (weight = [6.761105032637715e-5 -0.00011390191502869129 8.964590961113572e-6 -0.00016116614278871566 0.0001388001546729356 -0.0001387913798680529 -3.554107024683617e-5 6.05969617026858e-5 -0.0002470976614858955 -0.0001407010422553867 -8.979313133750111e-5 -1.4413712960958946e-5 -9.936666174326092e-5 -9.672332816990092e-5 -7.263862789841369e-5 5.8567486121319234e-5 -2.8290945920161903e-5 -6.729768210789189e-5 -8.459508535452187e-5 -3.252450915169902e-5 1.4338686014525592e-5 -1.1622613783401903e-5 2.6112842533621006e-5 -0.00012211223656777292 -3.0702405638294294e-5 -0.00019466369121801108 -8.26386603876017e-5 3.4653399779926986e-5 -7.399808964692056e-6 -0.00013036455493420362 0.00013129906437825412 -0.00010237027890980244; -0.00019548097043298185 2.6739317036117427e-5 0.00011322643695166335 1.7645017578615807e-5 6.908264913363382e-5 -0.00028372291126288474 8.124707528622821e-5 0.00011636492126854137 9.217034676112235e-5 -3.6968945096305106e-6 -5.6754815886961296e-5 8.091466588666663e-5 -9.703332580102142e-6 8.010135206859559e-5 2.891583244490903e-5 4.784962584380992e-5 7.91813072282821e-5 -0.00017794413724914193 0.00012016944674542174 0.00011082616401836276 0.00017339350597467273 -7.307794930966338e-6 -4.5871067413827404e-5 0.00021210037812124938 0.00013693088840227574 1.326379060628824e-5 0.0001421615161234513 3.745666617760435e-5 -7.865151565056294e-5 -5.684617644874379e-5 1.9444027202553116e-5 9.603942453395575e-5; -2.59859771176707e-5 0.00010558586654951796 2.654132913448848e-6 -3.118090899079107e-5 8.02510985522531e-5 -3.631959043559618e-5 0.00014321286289487034 2.0998217223677784e-5 -2.660623795236461e-5 5.41523659194354e-5 -4.246607204549946e-6 1.0811775609909091e-5 -8.92896139248478e-7 -9.698445501271635e-5 0.00010148475121241063 8.05063100415282e-5 -0.0001159976891358383 1.6067201613623183e-7 -8.368297858396545e-5 7.586512219859287e-5 -0.00021594487770926207 6.960427708690986e-5 -7.64423530199565e-5 4.586784780258313e-5 4.6067845687502995e-5 -0.00015519272710662335 4.2004081478808075e-5 -0.00015001697465777397 -5.4830394219607115e-5 -9.761584806255996e-5 5.1612001698231325e-5 -4.7108183935051784e-5; -4.9468103497929405e-6 6.1007780459476635e-5 0.0002248902019346133 2.2481364794657566e-5 -5.91642819927074e-5 -0.00016892838175408542 2.995081558765378e-5 -5.500140832737088e-5 -6.842530274298042e-5 -5.710500045097433e-5 -1.585475547472015e-5 8.887545118341222e-5 -0.000147540369653143 -0.00017234501137863845 -2.125293030985631e-5 3.7674335544579662e-6 -5.5900363804539666e-5 0.00016250161570496857 -6.58197095617652e-5 3.491840834612958e-5 5.6110129662556574e-5 -1.3222975212556776e-5 -5.9341968153603375e-5 -0.0001205423177452758 4.139607335673645e-5 7.074679888319224e-5 0.0001946134289028123 0.00013953108282294124 -6.878441490698606e-5 -0.0001410364784533158 6.211121217347682e-5 -1.097902531910222e-5; -0.0001626904559088871 -4.0386828914051875e-5 5.1514867664081976e-5 -8.924400754040107e-5 3.9883652789285406e-5 -3.983554779551923e-5 -6.908117211423814e-5 2.3048007278703153e-5 -3.676355481729843e-5 -1.3746815056947526e-5 -7.599073433084413e-5 3.2301380997523665e-5 -6.497154390672222e-5 -0.0001581442338647321 -0.00018512579845264554 -1.59981100296136e-5 9.555079304846004e-5 -6.453684909502044e-5 -1.2443972991604824e-5 7.117298082448542e-5 -3.408626525924774e-6 -0.00019522738875821233 2.5452942281845026e-5 -5.498492896549578e-7 -2.3053162294672802e-5 0.00014901232498232275 5.556746691581793e-5 -9.562596096657217e-5 -1.810804860724602e-5 0.00017268590454477817 2.777229610728682e-6 0.00010719968850025907; 1.2146685548941605e-5 -2.4243905500043184e-5 -8.934025390772149e-5 9.877081902232021e-5 -1.6319620044669136e-5 0.00011615259427344427 0.00017797514738049358 9.513004624750465e-5 0.0001017949907691218 0.00010881182970479131 0.00011092268687207252 -2.9001930670347065e-5 2.9529894163715653e-5 -0.00010813214612426236 0.0001337327848887071 -0.00012725153646897525 -9.695451126390253e-7 -9.763456910150126e-5 0.00014091611956246197 3.0294984753709286e-5 -9.552999108564109e-5 8.482510747853667e-5 5.660852912114933e-5 0.00012428859190549701 -6.1276979977265e-5 -6.0485017456812784e-5 8.64721296238713e-5 -7.488916889997199e-5 -1.0679927072487772e-5 -4.725938197225332e-5 6.438435229938477e-5 -0.00018757928046397865; -6.566567753907293e-5 -3.3646112569840625e-5 8.671623800182715e-5 -4.0657461795490235e-5 -7.471213393728249e-6 8.04714800324291e-5 9.747947478899732e-5 -9.746155410539359e-5 2.1854202714166604e-5 -5.392877483245684e-6 1.721821172395721e-5 -1.176630576082971e-5 4.188885213807225e-6 -3.322938209748827e-5 3.1403211323777214e-5 0.00018610239203553647 1.4785525763727492e-6 5.615193003905006e-5 9.095312270801514e-5 1.2546786820166744e-5 8.373901073355228e-5 1.3567898349720053e-5 5.2952953410567716e-5 0.00016013714775908738 5.144462556927465e-5 -0.00013369217049330473 -0.00011386127880541608 -1.4321387425297871e-5 -3.3755291951820254e-5 0.00011034096678486094 3.990255208918825e-5 -9.518568549538031e-5; -0.00016736936231609434 -0.0001481897197663784 2.6946712750941515e-5 8.388280548388138e-5 4.418526714289328e-6 -1.326137953583384e-5 -1.4832712622592226e-5 -4.695268580690026e-5 -6.563720671692863e-5 -2.8683447453659028e-5 -3.1700787076260895e-5 4.5602409954881296e-5 -6.560982001246884e-5 0.00014710317191202193 0.00010638499952619895 -0.00014243619807530195 7.224774890346453e-6 -4.8524852900300175e-5 -9.55118375713937e-5 0.0001995200727833435 -4.13624475186225e-5 -0.000156612804858014 -0.00017194882093463093 0.00015114470443222672 -4.5031811168882996e-5 -4.3834312236867845e-5 5.59776890440844e-5 -0.00016817502910271287 -0.00020149168267380446 0.00011917358642676845 0.00012282287934795022 8.879564848029986e-5; 9.005764150060713e-5 -0.00010427963570691645 -0.00020861726079601794 0.00011654794798232615 -0.00018781149992719293 9.170598059426993e-5 -2.2800568331149407e-5 7.35881258151494e-5 6.131765258032829e-5 8.612518286099657e-5 0.00011414634354878217 -0.00015244155656546354 8.775750029599294e-5 0.00010324119648430496 -7.488192932214588e-5 9.047691128216684e-5 -3.340983676025644e-5 3.994581493316218e-5 0.00012330696335993707 -5.487318412633613e-5 5.871340908925049e-5 -4.309193172957748e-5 -5.435596176539548e-5 1.6422405678895302e-5 -8.279288522317074e-6 -6.490122177638113e-5 3.117708183708601e-5 0.00011712389095919207 9.975941793527454e-6 -3.468328941380605e-5 0.00012418012192938477 9.53529161051847e-5; 1.830734618124552e-5 -7.476097380276769e-6 -9.509076335234568e-6 0.00010720775026129559 0.00022891539265401661 -1.780100865289569e-5 0.00011346080282237381 -4.8712274292483926e-5 -8.340813656104729e-5 2.086001040879637e-5 -0.00015277331112883985 -0.00010062263027066365 -7.459868356818333e-5 -4.6025452320463955e-5 -0.00016344073810614645 -2.9167174943722785e-5 4.6709905291209e-5 8.269950922112912e-5 -9.479973232373595e-5 -9.053312533069402e-5 -6.220904469955713e-5 -2.27564996748697e-5 -1.470692950533703e-6 2.027094342338387e-5 -1.76867197296815e-5 0.00014186426415108144 -4.961090235156007e-5 -4.661494676838629e-5 6.381324783433229e-5 4.503650779952295e-5 8.218586299335584e-5 8.606068149674684e-6; -2.397611751803197e-5 7.435428415192291e-5 -7.891234417911619e-5 -2.284586662426591e-5 -0.0001919454225571826 7.12585388100706e-5 -8.285065268864855e-5 -2.3212543965200894e-5 0.0001720605796435848 -1.9278853869764134e-5 -0.00011775112943723798 9.30517926462926e-5 -0.00017833041783887893 -2.1000874767196365e-5 -6.481426680693403e-5 -5.560593126574531e-5 3.1614799809176475e-5 0.00011925995204364881 9.085574129130691e-5 -2.725220656429883e-5 0.0001303398166783154 -8.462667756248266e-5 5.30334837094415e-5 7.339314379350981e-6 8.881129360815976e-6 3.311338878120296e-5 -1.4659923181170598e-5 -5.856961433892138e-5 3.501023820717819e-5 0.00013967517588753253 0.00023752640117891133 -4.207112579024397e-5; 2.1844134607817978e-5 -0.00014479240053333342 3.092973565799184e-5 3.357848254381679e-5 -5.007241998100653e-5 -1.8200425984105095e-5 9.270272857975215e-5 8.062866982072592e-5 5.490854164236225e-6 -1.204079126182478e-5 -6.637152182520367e-6 0.00017624112660996616 -8.480305405100808e-5 9.326179861091077e-5 0.00010831205145223066 -0.00020070798927918077 -9.279966616304591e-5 -0.0002279688633279875 4.6087949158390984e-5 -7.29132370906882e-5 2.0374476662254892e-5 0.00010940203355858102 -9.017370757646859e-5 -0.00011571280629141256 -9.211356518790126e-5 -1.2122844964324031e-5 4.6797369577689096e-5 -9.038650750881061e-5 -7.778075814712793e-5 -3.6846304283244535e-5 -4.903795343125239e-5 -9.655654866946861e-5; -8.579634709349193e-7 -0.0002066097076749429 -0.00011996288958471268 -0.00012564833741635084 2.5726812964421697e-5 0.00018287413695361465 0.0002073049108730629 -7.017081316007534e-6 -1.3617780496133491e-5 -5.657367000821978e-5 -1.569010601087939e-5 -2.3538550522061996e-5 4.8210826207650825e-5 1.9343942767591216e-5 0.0001385081559419632 0.00011750046542147174 -0.00015189555415418 5.523897925741039e-5 -8.703245111973956e-5 -2.4861637939466164e-5 2.6413419618620537e-5 -4.64096519863233e-5 7.861565245548263e-5 -6.354363722493872e-5 8.942540262069087e-6 2.8756532628904097e-5 -0.00010554955952102318 2.485627555870451e-5 6.854698585812002e-5 -0.0001043926240527071 -8.586132025811821e-5 8.18735352368094e-5; 0.0002135591785190627 0.00013151006714906543 -0.00018185087537858635 0.00010091271542478353 -4.2591953388182446e-5 8.093713404377922e-5 -0.0001391009718645364 0.0001273994566872716 -4.298581916373223e-5 0.0002342983498238027 0.00012029625941067934 -0.00014725365326739848 9.190304263029248e-5 5.5020318541210145e-5 0.00010128232679562643 -5.751055869041011e-5 4.921677100355737e-5 8.696626900928095e-5 0.0001405872026225552 -9.268555731978267e-5 -2.1258920241962187e-5 -6.54149116599001e-5 4.459129195311107e-5 6.439611752284691e-5 -2.172727727156598e-5 -5.356820111046545e-5 9.923696052283049e-5 -0.00016687223978806287 -0.0001219390396727249 2.6427591365063563e-5 -7.411464594042627e-6 1.6011428670026362e-5; 0.00020759271865244955 -8.213899855036288e-5 4.552265090751462e-5 -7.235410521388985e-6 0.00016132413293235004 1.0742498488980345e-5 2.456557376717683e-5 -0.00011407442798372358 -0.00011202173482161015 -6.374440272338688e-5 3.078745794482529e-5 4.057094702147879e-5 1.1581369108171202e-5 0.00015548791270703077 -2.2667089069727808e-5 -7.496504986193031e-5 6.08162627031561e-5 8.652482210891321e-5 3.818407276412472e-5 -0.00013766196207143366 -1.2482308193284553e-5 -0.0002604927576612681 -1.1788180927396752e-5 0.00010245155863231048 -2.3329272153205238e-5 -0.00011171853839186952 2.683536513359286e-5 7.772137178108096e-5 -0.00011871046444866806 -0.00010139432561118156 0.00019171720487065613 -5.803032763651572e-5; -3.524161729728803e-5 -3.799711339524947e-5 -0.0001543878752272576 -7.031435961835086e-5 -7.369306786131347e-6 3.530066896928474e-5 7.216243102448061e-5 -4.750082007376477e-5 -0.00011452152102719992 0.00024547401699237525 7.06018035998568e-5 -0.00015834395890124142 -0.00010509344429010525 -1.9506487660692073e-5 0.0001029026389005594 4.5672910346183926e-5 -0.00017570857016835362 -1.9434161004028283e-5 7.900747004896402e-5 7.070915307849646e-5 -0.00019769961363635957 0.000125920822029002 -1.4219324839359615e-5 -4.7706136683700606e-5 -5.0051727157551795e-5 3.840317367576063e-5 -4.32184970122762e-5 0.00017298516468144953 9.743700502440333e-5 -1.1763959264499135e-5 -0.00011181385343661532 -0.00025018639280460775; -5.5258540669456124e-5 1.8087301214109175e-5 -9.965502977138385e-5 -0.00010975340410368517 -7.624323916388676e-5 0.00013782917812932283 -3.984098293585703e-5 3.073325933655724e-5 -8.78623322932981e-5 -1.8344175259699114e-5 -4.20141986978706e-5 0.00017695750284474343 0.00019724210142157972 6.193435547174886e-5 -4.8233720008283854e-5 4.8301702918251976e-5 -0.00011174748942721635 7.044799713185057e-5 -5.9433896240079775e-5 -6.293572369031608e-5 -7.425008334394079e-6 -3.098271554335952e-5 6.897310231579468e-5 -2.1363888663472608e-5 8.189096115529537e-5 -7.0597679950878955e-6 0.00014192616799846292 -0.0001361532195005566 0.00010379173181718215 4.2008749005617574e-5 0.00011425142292864621 2.156763002858497e-6; -9.368467726744711e-5 -0.0001465632813051343 -0.00013627373846247792 0.00011150490900035948 0.000148041159263812 1.9151713786413893e-5 7.683338481001556e-5 4.6866458433214575e-5 9.218986087944359e-5 -0.00015922932652756572 -1.1877160432050005e-5 -3.651264705695212e-6 -9.151640551863238e-5 -9.945130841515493e-6 -6.324681453406811e-5 9.12429895834066e-5 6.396269600372761e-5 -6.71638481435366e-5 3.683309842017479e-5 5.637649155687541e-5 6.986442895140499e-5 8.971405623015016e-5 7.475108759535942e-6 0.00011671743413899094 0.0001410535187460482 -0.00011666501086438075 0.00010081232903758064 1.9929017071262933e-5 -2.926743218267802e-5 8.425492706010118e-5 -0.0001409021788276732 -0.0001271645596716553; 6.226518598850816e-5 -0.0001550238812342286 0.00015362691192422062 -7.906807149993256e-5 -7.572171307401732e-5 -0.00010579486115602776 -5.746413989982102e-6 0.00010010981350205839 0.00018997275037690997 1.9668566437758273e-6 0.0001091051526600495 0.00024385355936829 1.9836181309074163e-5 -0.0002152212109649554 -0.00016908829275052994 -9.892282105283812e-5 -0.00011787428229581565 -6.13884549238719e-5 0.00015061038720887154 4.092584731552051e-6 -0.0001256628893315792 3.188707341905683e-5 -3.9325565012404695e-5 -5.642693940899335e-5 4.3136824388056993e-5 2.160561598429922e-5 0.00013621446851175278 -3.867835766868666e-5 0.0001232447975780815 0.0003146734379697591 3.1869469239609316e-5 6.352092896122485e-5; 5.8051828091265634e-5 0.00011755238665500656 -8.630123920738697e-5 1.7596246834727935e-5 -9.446116018807516e-5 -3.5806016967399046e-5 -2.9850791179342195e-5 6.186258542584255e-5 1.728420829749666e-5 8.436351345153525e-5 -2.8148258479632204e-7 -7.191062468336895e-5 1.1572702760531683e-6 3.575453592929989e-5 -0.00011033168266294524 3.337827365612611e-5 -5.159179272595793e-5 2.3716676878393628e-5 -5.0667396862991154e-5 -0.00017250435485038906 -0.00011599581193877384 3.971724436269142e-5 9.238779966835864e-6 7.617037044838071e-5 0.0003262680256739259 0.00014796156028751284 6.44019601168111e-5 8.48758572828956e-5 0.00010801957250805572 -0.0001496288605267182 -8.75336118042469e-5 4.65415432699956e-5; -3.4864991903305054e-5 0.00014489937166217715 0.00017197175475303084 0.00016956886975094676 3.615394962253049e-5 -3.60743397322949e-5 -0.0001745673071127385 -0.00020246350322850049 0.0001061250368366018 1.6606882127234712e-5 -0.00016328980564139783 -6.497571303043514e-5 -2.9604376322822645e-5 4.5696990127908066e-5 -0.00015154689026530832 0.00016237853560596704 1.250468358193757e-5 -6.640946958214045e-5 8.525399607606232e-5 6.167071842355654e-5 0.00012972927652299404 -0.00016208075976464897 3.5217868571635336e-5 -0.00021100706362631172 4.6703871703357436e-6 7.03415644238703e-5 6.497756112366915e-5 -3.695793930091895e-5 4.616522346623242e-5 -8.646171045256779e-5 2.5158322387142107e-5 1.3180172572901938e-5; 0.0001078371933544986 -8.609816291027528e-7 -5.5659653298789635e-5 0.00010182469850406051 -6.74868279020302e-5 -0.0001341366587439552 6.28858688287437e-5 3.796168675762601e-5 4.7855657612672076e-5 6.634426881646505e-6 -7.501797517761588e-5 6.090577153372578e-5 6.986076914472505e-5 -0.00019676607917062938 -1.8029442799161188e-5 9.899766882881522e-5 -2.6429879653733224e-5 -4.630584589904174e-5 -3.6691755667561665e-5 4.0621373045723885e-5 5.4369924328057095e-5 -4.235930191498483e-6 -7.384837226709351e-5 -7.061962969601154e-5 -0.00012377793609630316 0.0001642249699216336 1.1223371984669939e-5 5.216310091782361e-5 -2.663106715772301e-5 0.00016381251043640077 1.4793770787946414e-5 8.166553016053513e-5; 7.209289469756186e-5 -0.000183279684279114 -6.494143599411473e-5 0.00026116534718312323 0.00013604013656731695 2.7541878807824105e-5 -3.083774936385453e-5 -9.27741639316082e-5 0.00011816724872915074 -5.3335217671701685e-5 8.633402467239648e-5 0.00013644818682223558 1.873067049018573e-5 4.0324946894543245e-5 -5.828330904478207e-5 2.2703083232045174e-5 -0.00021211568673606962 0.0001991705212276429 9.597362077329308e-5 7.825745342415757e-6 -6.046033377060667e-5 -0.00010037812899099663 -0.00015556903963442892 0.00011028232256649062 -2.845781455107499e-5 -7.17069906386314e-6 -4.36513582826592e-5 6.142103666206822e-5 -3.320914038340561e-5 4.335317498771474e-5 -2.4388624296989292e-5 0.00011741802882170305; 6.885783659527078e-5 -6.577162275789306e-5 -0.00016516853065695614 0.00015954807167872787 -0.00020923551346641034 -7.437654858222231e-5 -1.8963084585266188e-5 -7.688617915846407e-5 0.0001293886307394132 0.00018441291467752308 4.5995824621059e-5 0.00013424406643025577 -7.607619045302272e-5 9.690385195426643e-5 4.898919360130094e-5 0.00015271862503141165 -5.164634785614908e-5 -0.00011597086995607242 1.351633841295552e-6 8.973425428848714e-5 -1.4386297152668703e-5 5.620609226753004e-5 -2.803196730383206e-5 9.230711293639615e-5 -9.703975229058415e-5 -5.702647831640206e-5 -1.490930389991263e-5 0.00014520747936330736 3.7420257285702974e-5 -0.00015659407654311508 3.7704387068515643e-5 6.270015001064166e-5; -0.00010970162838930264 -3.1929748729453422e-6 1.8567363440524787e-5 8.471964247291908e-5 9.443170711165294e-5 -3.129248216282576e-5 -4.110973168280907e-5 -6.364767841660068e-7 9.279819641960785e-5 -1.8605729565024376e-5 -0.00015082447498571128 4.7512261517113075e-5 0.00022276595700532198 -7.833735435269773e-5 -0.00015602796338498592 2.795490036078263e-5 -1.5144451026571915e-5 7.955864930409007e-6 -4.2314368329243734e-5 3.379584813956171e-5 -0.0001531906600575894 -6.305200076894835e-5 7.16425129212439e-5 2.6753130441647954e-5 -5.048239108873531e-5 -0.00023256019630935043 -2.406377279839944e-5 4.369274392956868e-5 -2.6701227398007177e-5 4.475247988011688e-5 -3.310849569970742e-5 2.3644826796953566e-5; 0.0001244883460458368 -0.00015694595640525222 -8.181330485967919e-5 7.419978646794334e-5 3.80677156499587e-5 -9.770221367944032e-5 -0.0001992011530091986 -0.00015588085807394236 -4.474118395592086e-5 3.437306804698892e-5 -4.900018029729836e-5 -0.00025273559731431305 -9.579829929862171e-5 1.4170486792863812e-5 -4.296117549529299e-5 4.181101394351572e-5 0.00010371536336606368 0.00010284704330842942 0.00017538131214678288 -0.00011807677947217599 0.00010016620217356831 -0.00013130252773407847 0.00012288255675230175 0.0001293197856284678 0.0002503397408872843 -9.514749763184227e-6 -0.0002131173387169838 -7.137151987990364e-5 -9.06710556591861e-5 -5.295803930494003e-5 -9.285775377065875e-6 -2.047513407887891e-5; 4.218741014483385e-5 -2.687854976102244e-5 -0.00012609823897946626 -9.970099199563265e-5 -3.627047044574283e-5 0.00010572366591077298 -1.2999311366002075e-5 0.00010253689833916724 0.00016081029025372118 -6.873009988339618e-5 -0.00012459588469937444 0.0001757000427460298 -0.0001416235463693738 5.843212420586497e-6 9.426718315808102e-5 -3.371300044818781e-5 -0.00010652918717823923 -0.00012488446373026818 0.00011061981786042452 0.00011289351095911115 9.947803846444003e-6 0.0001259713462786749 9.078776929527521e-5 1.1399733068628848e-7 2.16452590393601e-5 0.00021435291273519397 -0.00012574777065310627 6.75436167512089e-5 3.791349081438966e-5 -0.00013065061648376286 7.310341516131302e-6 -1.1499429092509672e-5; -6.10337701800745e-5 -9.120344475377351e-5 6.0458471125457436e-5 7.351739623118192e-5 -6.528617086587474e-5 2.244137431262061e-5 9.03368418221362e-5 -5.904215868213214e-5 -5.7658129662740976e-5 9.106579091167077e-5 -3.118138192803599e-5 -0.0001820025354390964 -6.733254122082144e-5 0.00014207218191586435 4.810355312656611e-5 -8.362824155483395e-5 1.1691236977640074e-5 0.00019550722208805382 2.701854464248754e-5 -7.682145951548591e-5 -0.00013286093599162996 -0.00018943096802104264 -9.89918116829358e-5 0.00010191948240390047 -6.756121729267761e-5 6.154821312520653e-5 6.147019303170964e-5 0.00015826142043806612 3.850510256597772e-5 -3.350976112415083e-5 0.0002052895724773407 0.00021367623412515968; 5.3222087444737554e-5 0.00018807554442901164 -2.466087835273356e-6 1.9308688933961093e-5 2.1195686713326722e-5 -6.24933309154585e-5 -2.5295285013271496e-5 0.00011116469977423549 -4.100127989659086e-5 1.4142900909064338e-5 -9.869550012808759e-6 -0.0001601492112968117 -0.0001420916960341856 6.437262345571071e-5 7.957116031320766e-5 -7.501710206270218e-5 9.50536341406405e-5 0.00023381195205729455 -7.245416782097891e-6 1.590992178535089e-5 -4.081567021785304e-5 -4.805360731552355e-5 0.0001450195995857939 0.0001076179978554137 -5.355043685995042e-5 8.168463682522997e-5 -8.734091534279287e-5 -7.210665353341028e-5 -5.25250106875319e-5 -0.00015192330465652049 -6.670652510365471e-5 -5.0764076149789616e-5; -0.00010892246064031497 7.490791176678613e-5 0.00010133239265996963 -2.711246270337142e-5 -2.726160346355755e-5 0.00010722718434408307 -5.403154864325188e-5 6.642168591497466e-5 -0.00010644121357472613 -4.216549496049993e-5 0.00021430062770377845 -1.4636999594586086e-6 9.77125673671253e-5 0.00012267296551726758 -3.076957000303082e-5 -1.1607872693275567e-5 -8.360679203178734e-5 1.1522560271259863e-5 4.2852992919506505e-5 -4.938629354001023e-5 -4.06916078645736e-5 -6.368004687828943e-5 -1.1976171663263813e-5 -0.00010869059042306617 3.5530756576918066e-5 5.0454222218832e-5 0.00010188412852585316 6.841393042122945e-5 -5.024124402552843e-5 0.00015245396934915334 -7.89569821790792e-5 6.021958324708976e-5; -5.303627040120773e-5 -1.8745693523669615e-5 1.634869295230601e-5 5.77467872062698e-5 -7.224714863696136e-6 3.828764420177322e-6 7.062680379021913e-5 -2.284881156811025e-5 -4.180403266218491e-5 -8.662874279252719e-6 -0.0001529385190224275 2.7515592591953464e-5 0.00012639736814890057 4.1465718823019415e-5 -2.8078338800696656e-5 -0.00015498533321078867 0.00014600993017666042 -4.7832705604378134e-5 0.00017865713743958622 -4.187506783637218e-5 5.710517507395707e-5 0.00022340932628139853 0.00010140866652363911 -3.0475546736852266e-5 0.00013225879229139537 5.244702333584428e-5 0.00012626311217900366 -1.7756929082679562e-5 -5.791742296423763e-5 -0.00010622353875078261 -5.012869223719463e-5 -0.00011080242256866768; 5.4646359785692766e-5 5.7637065765447915e-5 0.00012274504115339369 0.0002755761961452663 -6.648313865298405e-5 -0.00014265780919231474 -7.240160630317405e-5 8.115586751955561e-6 9.974197018891573e-5 7.923512748675421e-5 0.000193128376849927 6.184099038364366e-5 8.91939562279731e-5 7.953823660500348e-5 -5.270963811199181e-5 -1.6961545043159276e-5 5.6648590543773025e-5 0.0002228095690952614 6.961478356970474e-5 0.00018504179024603218 -2.9028607968939468e-5 -5.85626985412091e-5 0.0001188105161418207 -9.372176282340661e-5 9.829584087128751e-6 0.00019844778580591083 -7.706311589572579e-5 0.00011331275891279802 2.7304155082674697e-5 2.1979632947477512e-5 -5.350495484890416e-5 2.5904915673891082e-5], bias = [0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_4 = (weight = [3.2629279303364456e-5 5.03639348607976e-5 7.98077235231176e-5 -1.852538116509095e-5 -3.1385636702907505e-6 -0.00017683935584500432 0.000164843731909059 4.689307388616726e-5 8.124146552290767e-5 0.00010770842345664278 -0.00015785959840286523 0.00012844763114117086 5.772338408860378e-5 0.00017318777099717408 -0.00015215481107588857 -0.0001098293432733044 -2.3794747903593816e-5 1.9830482415272854e-5 3.390213896636851e-5 3.739298335858621e-5 8.496735244989395e-5 -2.1927227862761356e-5 -3.225576801924035e-5 -6.152622063382296e-6 -0.00019298044207971543 -1.5268085917341523e-5 2.79104078799719e-5 -3.7722322304034606e-5 6.3804063756833784e-6 0.00012652401346713305 1.8189550246461295e-5 -3.907410427927971e-5; -3.0793882615398616e-5 -7.384389755316079e-5 5.234796117292717e-5 -8.966428868006915e-5 0.00010205378930550069 -6.106228101998568e-5 -0.00010570495942374691 -2.4574814233346842e-5 -7.383016054518521e-5 0.00016824266640469432 5.739341213484295e-5 -7.119746442185715e-5 3.492101313895546e-5 7.726706826360896e-5 -5.264234914648114e-6 -7.681448187213391e-5 -8.816291665425524e-5 2.2420015739044175e-5 7.506003021262586e-5 -0.000177245688973926 -6.681910599581897e-5 -0.00023924508423078805 -0.00013460488116834313 8.047367009567097e-5 4.374219133751467e-5 -1.974924271053169e-5 0.0001006708771456033 0.00016900496848393232 6.784794095437974e-5 -0.00010139592632185668 -7.19966774340719e-5 -7.281249736479367e-7], bias = [0.0; 0.0;;])) +``` + + +Now we define a system of odes which describes motion of point like particle with Newtonian physics, uses + + +$$ +u[1] = \chi +$$ + + +$$ +u[2] = \phi +$$ + + +where, $p$, $M$, and $e$ are constants + + +```julia +function ODE_model(u, nn_params, t) + χ, ϕ = u + p, M, e = ode_model_params + + # In this example we know that `st` is am empty NamedTuple hence we can safely ignore + # it, however, in general, we should use `st` to store the state of the neural network. + y = 1 .+ first(nn([first(u)], nn_params, st)) + + numer = (1 + e * cos(χ))^2 + denom = M * (p^(3 / 2)) + + χ̇ = (numer / denom) * y[1] + ϕ̇ = (numer / denom) * y[2] + + return [χ̇, ϕ̇] +end +``` + + +``` +ODE_model (generic function with 1 method) +``` + + +Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model. + + +```julia +prob_nn = ODEProblem(ODE_model, u0, tspan, params) +soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false)) +waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)) + +fig = with_theme(theme_web()) do + fig = Figure() + ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform") + + l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75) + s1 = scatter!(ax, tsteps, waveform; markershape=:circle, markeralpha=0.25, alpha=0.5) + + l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75) + s2 = scatter!(ax, tsteps, waveform_nn; markershape=:circle, markeralpha=0.25, alpha=0.5) + + axislegend(ax, [[l1, s1], [l2, s2]], + ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb) + + return fig +end +``` + + +![](1_GravitationalWaveForm-35.png) + + + + +## Setting Up for Training the Neural Network + + +Next, we define the objective (loss) function to be minimized when training the neural differential equations. + + +```julia +function loss(θ) + pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)) + pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params)) + loss = sum(abs2, waveform .- pred_waveform) + return loss, pred_waveform +end +``` + + +``` +loss (generic function with 1 method) +``` + + +Warmup the loss function + + +```julia +loss(params) +``` + + +``` +(0.18041723123998618, [-0.024251308334199348, -0.023467243913447414, -0.022683179492695528, -0.02135833998251774, -0.019464969394766607, -0.01696277285707412, -0.013798835135721017, -0.009904645956411521, -0.005198353185178126, 0.00041560795841743105, 0.00703732551850854, 0.014750451614037544, 0.02357449400598052, 0.033336000811801256, 0.043389001770050645, 0.0519534598875645, 0.054715447077622555, 0.04254942007089733, 0.0018961384696719126, -0.06640517778620761, -0.11020393789897714, -0.07613172240005946, -0.006997405938716619, 0.038432596668994665, 0.05392000026564186, 0.052736500633410476, 0.04477940248318577, 0.03488416504615969, 0.025084605535768252, 0.016145078440774514, 0.008287260499085087, 0.0015164617042157462, -0.004242815498325918, -0.009087554169270103, -0.013112844630807165, -0.016400454801511786, -0.0190201110769075, -0.02102629761775136, -0.022460328810305304, -0.02335115895221063, -0.023716293669543233, -0.023562091222168693, -0.022884952563688794, -0.021669025921950084, -0.019888141651521984, -0.01750348603683818, -0.014462711954016258, -0.010699255077318637, -0.0061313829860049915, -0.0006635000720833607, 0.0058066502226664434, 0.01337240944853092, 0.022069950929458494, 0.03177383069650867, 0.041940091156632775, 0.051018954956663955, 0.05517450089712274, 0.04605317017625843, 0.010162291636163471, -0.05616164974824417, -0.10836009557595651, -0.08508119224191282, -0.016449033236001314, 0.03367316761423824, 0.05275456655955605, 0.05334774660918139, 0.046101450494774994, 0.03641257093584373, 0.026598042668796014, 0.01755430825179534, 0.00955547939385153, 0.0026399063764761996, -0.003266719516305491, -0.00824748793215388, -0.012405482987147918, -0.01581527846806932, -0.018554517574579713, -0.020672337997386502, -0.022216746672874727, -0.02321376077798935, -0.023684821277387406, -0.023636327025123735, -0.023065055817171175, -0.021958120136587524, -0.020290405020548626, -0.01802126209469952, -0.015105383202458012, -0.0114702712865944, -0.007042858150846819, -0.0017195200613367558, 0.004596151297372128, 0.012011857117084709, 0.02057332520002495, 0.030202879133621303, 0.04044124426929896, 0.04995117916474034, 0.05533147062988771, 0.048980423519990844, 0.017747236318357555, -0.045646036160823585, -0.1047285682151961, -0.09301214809083987, -0.026344034939178227, 0.02824388067168037, 0.05118197302027952, 0.05376579469720643, 0.04734347910122554, 0.03791495264426201, 0.0281140679777711, 0.018973293767488103, 0.01084551493419812, 0.003780268599600791, -0.0022651121331484377, -0.00738872693396313, -0.01167332082884245, -0.015211156786496016, -0.01806513089806744, -0.020298952409817703, -0.021950692844659843, -0.02305613776758748, -0.02363208084478266, -0.023688990576161237, -0.02322491549871647, -0.022225724361567212, -0.02067024466468399, -0.018517886500263016, -0.015724727338229046, -0.012219537634478656, -0.007930597702707043, -0.002753781703736446, 0.003407890416298586, 0.010668556744113714, 0.019088757458449422, 0.02862523831161672, 0.038902609444006406, 0.04876685325505109, 0.055214760237833144, 0.05137353296351461, 0.02461589279686688, -0.03507465063399791, -0.09943768882782983, -0.09968005369026373, -0.03654944341181989, 0.022126461309756793, 0.049167732025344545, 0.053965762630863355, 0.048493459141359815, 0.03938734702008381, 0.029623146319916808, 0.020408189076175733, 0.012149072334452766, 0.004944660099856395, -0.0012482380630758976, -0.006502499329086837, -0.010922593529001812, -0.014582874928145499, -0.017558195377909237, -0.019900961326879452, -0.021665120272800868, -0.022876845808170253, -0.023558717892646942, -0.02372077464720352, -0.02336289751006526, -0.02247279567439569, -0.02102639187367114, -0.018994603705340552, -0.01632074944510913, -0.012945039128173427, -0.00879612265269591, -0.003766511207118636, 0.002242801566375155, 0.009345895681782205, 0.017615212004269746, 0.02704608633829339, 0.037331162102399576, 0.04748068168697271, 0.05485669419002649, 0.05326776456989467, 0.030767036760328666, -0.024662636735052743, -0.09264632100660598, -0.10487494193008341, -0.04690376337778372, 0.015322799198816609, 0.04666705323423059, 0.05392561111397336, 0.04953758012683116, 0.040818929770180847, 0.031131900175917115, 0.021847159576810705, 0.013470304119951378, 0.006124750159806839, -0.0002021891256584283, -0.005600484135876164, -0.010149314494301684, -0.013935534342949866, -0.017024303750955053, -0.01948562984370063, -0.021357857566344115, -0.022677047381385575, -0.023463526112302544, -0.02373179362949799, -0.023479745729240723, -0.022696639974348443, -0.021365167810663048, -0.019445766031924303, -0.016895145999914755, -0.013648595713338142, -0.009637392654918923, -0.004754811476356115, 0.0010993908915005557, 0.00804196514343179, 0.01615819437094472, 0.02546895434034182, 0.0357328591894149, 0.04610750106879312, 0.054281408925949966, 0.0547063810537505, 0.036200083564212206, -0.0145713842349482, -0.08457264455375564, -0.10842554110907839, -0.05720353037668435, 0.00783329150543054, 0.043650383041343256, 0.05361447037957761, 0.05046103129383658, 0.04220651126257712, 0.03262670060244742, 0.023295056492654457, 0.014806231850490936, 0.007325980302374394, 0.0008598695636054873, -0.0046751617803424215, -0.009354473869114461, -0.01326528091980349, -0.016472512576163244, -0.019047533278846347, -0.02102962958944012, -0.022456013461469734, -0.023347954947073847, -0.023721481146828496, -0.02357538315247028, -0.02290025486803118, -0.02167972131894394, -0.019876043958227597, -0.017446871355511706, -0.014329138443509325, -0.010455304032226253, -0.006581469620942961]) +``` + + +Now let us define a callback function to store the loss over time + + +```julia +const losses = Float64[] + +function callback(θ, l, pred_waveform) + push!(losses, l) + println("Training || Iteration: $(length(losses)) || Loss: $(l)") + return false +end +``` + + +``` +callback (generic function with 1 method) +``` + + + + +## Training the Neural Network + + +Training uses the BFGS optimizers. This seems to give good results because the Newtonian model seems to give a very good initial guess + + +```julia +adtype = Optimization.AutoZygote() +optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype) +optprob = Optimization.OptimizationProblem(optf, params) +res = Optimization.solve(optprob, + BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking()); + callback, maxiters=1000) +``` + + +``` +retcode: Success +u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-3.75342751794858e-5; 9.58729287957472e-5; -0.00020307680824770662; 2.7310172299552448e-5; 4.48020618933588e-5; -9.99435087578435e-6; -2.9701484891115445e-5; -8.724159124532952e-5; -0.00010881431080603352; -9.694041364114735e-5; -4.375306161810088e-6; -0.00011095431545972445; 7.387933874257552e-5; -2.761208997977944e-5; 0.0001104313632822525; -0.00015784006973260328; -1.111359233618317e-5; 6.388432666417436e-5; -0.00014325912343306248; -4.568206350083117e-5; 7.129124423951195e-5; -5.953129220868656e-5; 9.476896229893399e-5; 0.00019663930288491105; 5.5411610446706605e-6; -0.00016974289610484404; -0.00022925164375891554; -4.508995698411197e-5; -3.811260103246339e-5; 1.2938471627411963e-5; -3.396819738555646e-5; -7.005258521522323e-5;;], bias = [-1.1338775861532203e-17; 1.4131118722035616e-16; -1.1374646816717522e-17; 7.98149251598764e-17; -9.169965872600527e-17; -1.885606954833115e-18; -6.079734191504744e-18; -8.290972850704031e-17; -3.5040293003017496e-16; -1.4835009257456544e-16; -1.025322615062851e-17; -1.1305770358056804e-16; 1.428344963184489e-16; -3.6724281228006806e-17; 7.15555110148462e-18; -2.9039273537639157e-17; -4.579780198128353e-18; 1.0416311518146772e-16; -3.653522997483967e-16; -3.629655494541203e-17; 3.99511297633713e-17; -1.9152419388547524e-17; 5.890555190327904e-17; 2.8651602823447427e-16; 5.464956546443027e-18; -3.68349539170966e-16; -3.261630788933324e-16; -5.1162075554914514e-17; -1.9773085046150974e-17; 1.1168240250058263e-17; 9.983899636539334e-18; -9.737290777491694e-17;;]), layer_3 = (weight = [6.760441129780004e-5 -0.00011390855405725154 8.957951932622805e-6 -0.0001611727818172942 0.0001387935156443598 -0.00013879801889663286 -3.554770927541442e-5 6.059032267412215e-5 -0.00024710430051445 -0.00014070768128394647 -8.979977036608121e-5 -1.4420351989512418e-5 -9.937330077182926e-5 -9.672996719847943e-5 -7.264526692696745e-5 5.856084709279309e-5 -2.82975849487418e-5 -6.73043211364632e-5 -8.460172438305754e-5 -3.2531148180274655e-5 1.4332046985956447e-5 -1.1629252811974378e-5 2.6106203505060318e-5 -0.0001221188755962693 -3.070904466687438e-5 -0.0001946703302465288 -8.264529941606794e-5 3.464676075135124e-5 -7.406447993269066e-6 -0.00013037119396278345 0.00013129242534967644 -0.00010237691793837197; -0.0001954767904156397 2.6743497053449e-5 0.00011323061696895125 1.764919759595888e-5 6.908682915097518e-5 -0.0002837187312455407 8.125125530357109e-5 0.00011636910128587507 9.217452677845032e-5 -3.6927144922992227e-6 -5.6750635869617245e-5 8.091884590399395e-5 -9.699152562765487e-6 8.010553208593863e-5 2.8920012462236505e-5 4.7853805861120074e-5 7.918548724562602e-5 -0.00017793995723180338 0.00012017362676273788 0.000110830344035704 0.00017339768599200993 -7.30361491362708e-6 -4.586688739649554e-5 0.00021210455813854077 0.0001369350684196198 1.3267970623593085e-5 0.00014216569614072377 3.746084619494567e-5 -7.864733563322082e-5 -5.684199643139995e-5 1.944820721989563e-5 9.604360455129314e-5; -2.598720436200216e-5 0.00010558463930518961 2.6529056691333625e-6 -3.118213623512279e-5 8.024987130792189e-5 -3.632081767992817e-5 0.00014321163565053866 2.0996989979348813e-5 -2.6607465196691875e-5 5.415113867510715e-5 -4.247834448881961e-6 1.0810548365582004e-5 -8.941233835783121e-7 -9.698568225704804e-5 0.00010148352396808352 8.050508279720616e-5 -0.00011599891638017026 1.5944477180584557e-7 -8.368420582828923e-5 7.58638949542617e-5 -0.0002159461049535921 6.960304984257926e-5 -7.644358026428493e-5 4.58666205582666e-5 4.6066618443170985e-5 -0.0001551939543509438 4.200285423449712e-5 -0.00015001820190210517 -5.4831621463938556e-5 -9.76170753068919e-5 5.161077445389976e-5 -4.710941117938184e-5; -4.947288896727665e-6 6.100730191254312e-5 0.00022488972338768475 2.248088624772274e-5 -5.916476053964203e-5 -0.00016892886030102036 2.9950337040718972e-5 -5.500188687430464e-5 -6.84257812899135e-5 -5.710547899790781e-5 -1.5855234021655092e-5 8.887497263647923e-5 -0.0001475408482000771 -0.00017234548992557327 -2.1253408856789348e-5 3.7669550075269147e-6 -5.590084235147458e-5 0.00016250113715803423 -6.582018810869693e-5 3.491792979919497e-5 5.610965111562243e-5 -1.3223453759491167e-5 -5.9342446700536914e-5 -0.00012054279629220471 4.1395594809801517e-5 7.07463203362618e-5 0.00019461295035588553 0.0001395306042760066 -6.878489345392078e-5 -0.00014103695700025073 6.211073362654206e-5 -1.0979503866036394e-5; -0.00016269271204271565 -4.038908504787466e-5 5.1512611530282926e-5 -8.924626367423013e-5 3.9881396655457286e-5 -3.983780392934876e-5 -6.90834282480671e-5 2.3045751144879187e-5 -3.676581095111925e-5 -1.3749071190770167e-5 -7.599299046467374e-5 3.229912486370318e-5 -6.497380004054777e-5 -0.00015814648999856114 -0.00018512805458646608 -1.6000366163424738e-5 9.554853691463051e-5 -6.453910522884703e-5 -1.2446229125419222e-5 7.117072469065735e-5 -3.4108826597506123e-6 -0.00019522964489203932 2.5450686148022076e-5 -5.521054234559111e-7 -2.3055418428502386e-5 0.00014901006884851447 5.556521078202727e-5 -9.562821710040027e-5 -1.8110304741074546e-5 0.0001726836484109487 2.7749734768999323e-6 0.00010719743236643311; 1.2148911542952005e-5 -2.4241679506038487e-5 -8.933802791374024e-5 9.877304501633109e-5 -1.6317394050659178e-5 0.00011615482026745564 0.0001779773733745044 9.513227224151051e-5 0.00010179721676312458 0.00010881405569879587 0.00011092491286608395 -2.8999704676344652e-5 2.9532120157723084e-5 -0.0001081299201302515 0.00013373501088270962 -0.00012724931047498208 -9.673191186276827e-7 -9.763234310749282e-5 0.0001409183455564584 3.029721074771919e-5 -9.552776509163338e-5 8.48273334725455e-5 5.661075511515419e-5 0.00012429081789948012 -6.127475398325358e-5 -6.048279146282245e-5 8.647435561784423e-5 -7.488694290596206e-5 -1.0677701078477403e-5 -4.725715597824202e-5 6.438657829339535e-5 -0.00018757705446997082; -6.566344441734595e-5 -3.3643879448119265e-5 8.67184711235253e-5 -4.065522867376276e-5 -7.468980272001688e-6 8.047371315415704e-5 9.748170791072469e-5 -9.745932098367108e-5 2.185643583588605e-5 -5.39064436152447e-6 1.722044484568521e-5 -1.1764072639110607e-5 4.191118335531287e-6 -3.3227148975760804e-5 3.140544444549639e-5 0.00018610462515724645 1.480785698100672e-6 5.6154163160775124e-5 9.09553558297283e-5 1.2549019941893248e-5 8.374124385527661e-5 1.35701314714455e-5 5.295518653228923e-5 0.00016013938088078742 5.144685869100264e-5 -0.00013368993737159757 -0.0001138590456837261 -1.4319154303571332e-5 -3.375305883009329e-5 0.00011034319990658882 3.990478521091543e-5 -9.518345237365586e-5; -0.00016737165169571734 -0.00014819200914599555 2.694442337134827e-5 8.388051610425787e-5 4.416237334666774e-6 -1.3263668915457821e-5 -1.4835002002215623e-5 -4.6954975186518646e-5 -6.563949609654382e-5 -2.868573683327606e-5 -3.1703076455884936e-5 4.5600120575266436e-5 -6.561210939208883e-5 0.00014710088253239845 0.00010638271014658402 -0.00014243848745490738 7.2224855107224895e-6 -4.852714227992118e-5 -9.551412695100243e-5 0.000199517783403721 -4.1364736898242756e-5 -0.00015661509423763544 -0.00017195111031424824 0.00015114241505263156 -4.503410054850703e-5 -4.383660161647037e-5 5.5975399664499614e-5 -0.0001681773184823354 -0.00020149397205342743 0.00011917129704714453 0.00012282058996832703 8.879335910067948e-5; 9.00597297509778e-5 -0.00010427754745655108 -0.0002086151725456743 0.00011655003623269725 -0.0001878094116768227 9.170806884464148e-5 -2.2798480080778382e-5 7.359021406551585e-5 6.131974083069186e-5 8.612727111136182e-5 0.00011414843179915377 -0.00015243946831510028 8.775958854636084e-5 0.00010324328473467608 -7.487984107178256e-5 9.047899953252154e-5 -3.3407748509884886e-5 3.994790318353103e-5 0.00012330905161029473 -5.487109587596593e-5 5.871549733961866e-5 -4.308984347920828e-5 -5.435387351502995e-5 1.6424493929240652e-5 -8.277200271945472e-6 -6.489913352602908e-5 3.117917008742192e-5 0.00011712597920956231 9.97803004389809e-6 -3.468120116343453e-5 0.0001241822101797556 9.535500435555298e-5; 1.830675436186484e-5 -7.476689199655936e-6 -9.50966815460752e-6 0.00010720715844191477 0.00022891480083463608 -1.780160047227662e-5 0.00011346021100299302 -4.8712866111863406e-5 -8.340872838042596e-5 2.0859418589417246e-5 -0.0001527739029482208 -0.00010062322209004223 -7.459927538756323e-5 -4.602604413984475e-5 -0.00016344132992552503 -2.91677667630989e-5 4.670931347182806e-5 8.269891740174897e-5 -9.480032414311293e-5 -9.053371715007456e-5 -6.220963651893709e-5 -2.275709149424996e-5 -1.4712847699129106e-6 2.0270351604010414e-5 -1.768731154906244e-5 0.0001418636723317061 -4.961149417093083e-5 -4.661553858776685e-5 6.381265601495161e-5 4.503591598014203e-5 8.218527117397511e-5 8.605476330294684e-6; -2.3975341790094694e-5 7.435505987985818e-5 -7.891156845118913e-5 -2.284509089632846e-5 -0.0001919446468292455 7.125931453800822e-5 -8.284987696071115e-5 -2.321176823726521e-5 0.00017206135537151942 -1.9278078141828905e-5 -0.00011775035370930035 9.305256837422708e-5 -0.00017832964211094267 -2.1000099039258923e-5 -6.481349107899953e-5 -5.5605155537814054e-5 3.161557553711408e-5 0.0001192607277715854 9.08565170192393e-5 -2.7251430836361725e-5 0.00013034059240625174 -8.462590183454594e-5 5.3034259437376836e-5 7.340090107278716e-6 8.881905088753605e-6 3.311416450913322e-5 -1.4659147453246417e-5 -5.856883861098427e-5 3.501101393511546e-5 0.0001396759516154701 0.00023752717690684864 -4.207035006230759e-5; 2.1841082612125975e-5 -0.00014479545252901773 3.092668366233933e-5 3.357543054812414e-5 -5.007547197669794e-5 -1.8203477979798392e-5 9.269967658405961e-5 8.062561782504005e-5 5.4878021685545665e-6 -1.2043843257508862e-5 -6.6402041782137455e-6 0.000176238074614285 -8.480610604669606e-5 9.32587466152181e-5 0.00010830899945654932 -0.0002007110412748495 -9.28027181587392e-5 -0.00022797191532367684 4.608489716271794e-5 -7.291628908637953e-5 2.037142466656654e-5 0.00010939898156289115 -9.01767595721531e-5 -0.00011571585828706762 -9.211661718359463e-5 -1.2125896959988868e-5 4.67943175820478e-5 -9.038955950450198e-5 -7.778381014281986e-5 -3.684935627893775e-5 -4.904100542694463e-5 -9.655960066515713e-5; -8.585000698708431e-7 -0.0002066102442738775 -0.00011996342618364162 -0.0001256488740152869 2.5726276365485876e-5 0.0001828736003546785 0.0002073043742741269 -7.017617914942371e-6 -1.3618317095067583e-5 -5.65742066071543e-5 -1.5690642609815563e-5 -2.3539087120996008e-5 4.821028960871561e-5 1.9343406168655175e-5 0.00013850761934302913 0.00011749992882253997 -0.00015189609075311615 5.523844265847494e-5 -8.703298771867211e-5 -2.4862174538401974e-5 2.6412883019685255e-5 -4.6410188585258843e-5 7.861511585654805e-5 -6.35441738238681e-5 8.942003663132923e-6 2.875599602997298e-5 -0.00010555009611995014 2.485573895976869e-5 6.854644925918411e-5 -0.00010439316065164324 -8.586185685705418e-5 8.187299863787411e-5; 0.0002135618657421895 0.00013151275437218543 -0.0001818481881554944 0.00010091540264791092 -4.258926616505617e-5 8.093982126690717e-5 -0.00013909828464140915 0.00012740214391039295 -4.298313194061456e-5 0.0002343010370469225 0.00012029894663380736 -0.00014725096604428125 9.190572985341575e-5 5.502300576433752e-5 0.00010128501401874381 -5.7507871467303875e-5 4.9219458226685306e-5 8.696895623240542e-5 0.00014058988984566527 -9.268287009665645e-5 -2.1256233018838594e-5 -6.541222443677514e-5 4.459397917623126e-5 6.439880474594108e-5 -2.172459004843796e-5 -5.356551388736263e-5 9.92396477459125e-5 -0.00016686955256493664 -0.00012193635244959813 2.643027858819146e-5 -7.408777370915592e-6 1.601411589315011e-5; 0.00020759255415830715 -8.213916304450485e-5 4.552248641337436e-5 -7.235575015531434e-6 0.0001613239684382077 1.0742333994837863e-5 2.4565409273034387e-5 -0.00011407459247786564 -0.00011202189931575198 -6.374456721752884e-5 3.0787293450682796e-5 4.057078252733695e-5 1.1581204614029008e-5 0.00015548774821288834 -2.2667253563869633e-5 -7.496521435607145e-5 6.081609820901361e-5 8.652465761477096e-5 3.818390826998333e-5 -0.000137662126565576 -1.2482472687426769e-5 -0.0002604929221554104 -1.1788345421538753e-5 0.00010245139413817008 -2.3329436647347728e-5 -0.00011171870288601045 2.6835200639453212e-5 7.772120728693859e-5 -0.00011871062894281046 -0.00010139449010532404 0.00019171704037651373 -5.8030492130657954e-5; -3.524376010713706e-5 -3.799925620509302e-5 -0.00015439001803707866 -7.031650242820034e-5 -7.3714495959799475e-6 3.52985261594348e-5 7.216028821463122e-5 -4.750296288360944e-5 -0.00011452366383704163 0.0002454718741825318 7.059966079000678e-5 -0.00015834610171108277 -0.00010509558709995144 -1.950863047054155e-5 0.00010290049609071796 4.567076753635142e-5 -0.0001757107129782035 -1.9436303813875423e-5 7.900532723912841e-5 7.07070102686479e-5 -0.000197701756446206 0.00012591867921915454 -1.4221467649203316e-5 -4.770827949352344e-5 -5.0053869967401784e-5 3.8401030865930865e-5 -4.322063982208928e-5 0.00017298302187160097 9.743486221455434e-5 -1.176610207434903e-5 -0.00011181599624646451 -0.00025018853561445436; -5.525734967328555e-5 1.8088492210276725e-5 -9.965383877522877e-5 -0.00010975221310751436 -7.624204816771643e-5 0.00013783036912549392 -3.983979193968624e-5 3.07344503327254e-5 -8.786114129713159e-5 -1.834298426353165e-5 -4.201300770169948e-5 0.00017695869384090976 0.0001972432924177487 6.193554646791966e-5 -4.8232529012117466e-5 4.8302893914413404e-5 -0.00011174629843104531 7.044918812802008e-5 -5.9432705243916633e-5 -6.293453269414577e-5 -7.423817338224933e-6 -3.098152454718977e-5 6.89742933119623e-5 -2.1362697667316527e-5 8.189215215146646e-5 -7.058576998927984e-6 0.00014192735899461363 -0.00013615202850438628 0.0001037929228133527 4.200994000178864e-5 0.00011425261392481687 2.1579539990277103e-6; -9.368393533155394e-5 -0.00014656253936924302 -0.00013627299652659442 0.00011150565093625281 0.00014804190119970502 1.91524557223074e-5 7.683412674590886e-5 4.686720036910625e-5 9.219060281533424e-5 -0.00015922858459167447 -1.1876418496156486e-5 -3.650522769804675e-6 -9.151566358274018e-5 -9.944388905622153e-6 -6.324607259817755e-5 9.124373151929408e-5 6.39634379396211e-5 -6.716310620764407e-5 3.683384035606334e-5 5.637723349276843e-5 6.986517088729727e-5 8.971479816604282e-5 7.475850695427287e-6 0.0001167181760748751 0.0001410542606819417 -0.0001166642689284942 0.00010081307097346139 1.9929759007155966e-5 -2.9266690246784846e-5 8.425566899599465e-5 -0.00014090143689177997 -0.00012716381773576296; 6.226711739987321e-5 -0.00015502194982286845 0.00015362884333556048 -7.90661400885671e-5 -7.571978166265267e-5 -0.0001057929297446619 -5.744482578616715e-6 0.00010011174491341951 0.0001899746817882684 1.9687880551358273e-6 0.00010910708407141541 0.00024385549077964814 1.983811272043665e-5 -0.00021521927955358989 -0.00016908636133917173 -9.892088964148797e-5 -0.00011787235088444979 -6.138652351250856e-5 0.00015061231862022452 4.0945161429166775e-6 -0.00012566095792021647 3.188900483042055e-5 -3.932363360104443e-5 -5.642500799765186e-5 4.313875579942293e-5 2.160754739564693e-5 0.0001362163999230855 -3.8676426257321995e-5 0.00012324672898944656 0.0003146753693811249 3.187140065097453e-5 6.352286037258771e-5; 5.805310094543075e-5 0.00011755365950916846 -8.629996635323834e-5 1.759751968889332e-5 -9.445988733391029e-5 -3.58047441132334e-5 -2.9849518325176862e-5 6.186385828000512e-5 1.7285481151657462e-5 8.436478630569707e-5 -2.802097306306331e-7 -7.190935182920834e-5 1.1585431302166085e-6 3.5755808783465275e-5 -0.00011033040980878457 3.337954651028151e-5 -5.159051987179229e-5 2.3717949732557633e-5 -5.066612400883394e-5 -0.00017250308199622428 -0.00011599453908461022 3.9718517216855656e-5 9.240052820997847e-6 7.617164330253041e-5 0.0003262692985280916 0.00014796283314166662 6.440323297095505e-5 8.487713013706045e-5 0.00010802084536222082 -0.0001496275876725526 -8.75323389500817e-5 4.654281612415926e-5; -3.4865069285344795e-5 0.0001448992942801376 0.0001719716773709921 0.00016956879236890702 3.615387224049076e-5 -3.607441711433468e-5 -0.00017456738449477823 -0.00020246358061054006 0.00010612495945456232 1.660680474519518e-5 -0.0001632898830234376 -6.497579041247461e-5 -2.9604453704862274e-5 4.569691274586831e-5 -0.0001515469676473478 0.0001623784582239279 1.2504606199897807e-5 -6.64095469641801e-5 8.525391869402306e-5 6.167064104151682e-5 0.00012972919914095438 -0.00016208083714668866 3.5217791189595785e-5 -0.0002110071410083505 4.6703097882959765e-6 7.034148704183126e-5 6.49774837416307e-5 -3.695801668295868e-5 4.6165146084192685e-5 -8.646178783460754e-5 2.5158245005102376e-5 1.3180095190862294e-5; 0.00010783813510714447 -8.600398764592676e-7 -5.565871154615596e-5 0.00010182564025670659 -6.748588614938451e-5 -0.00013413571699130895 6.288681058138972e-5 3.796262851026997e-5 4.7856599365314755e-5 6.635368634289925e-6 -7.50170334249696e-5 6.090671328636832e-5 6.986171089736967e-5 -0.00019676513741798333 -1.802850104651862e-5 9.899861058145389e-5 -2.6428937901086966e-5 -4.630490414639669e-5 -3.6690813914921635e-5 4.062231479836954e-5 5.437086608070185e-5 -4.234988438853269e-6 -7.384743051444996e-5 -7.061868794337705e-5 -0.0001237769943436569 0.0001642259116742711 1.1224313737300152e-5 5.2164042670469294e-5 -2.6630125405077155e-5 0.000163813452189047 1.479471254059236e-5 8.166647191317992e-5; 7.209461085950997e-5 -0.00018327796811717025 -6.493971983218894e-5 0.0002611670633450717 0.00013604185272926472 2.754359496977294e-5 -3.083603320190614e-5 -9.277244776966358e-5 0.00011816896489109301 -5.333350150975806e-5 8.633574083434537e-5 0.00013644990298417763 1.873238665213156e-5 4.0326663056491704e-5 -5.828159288284002e-5 2.2704799393980113e-5 -0.00021211397057412082 0.00019917223738958955 9.597533693523048e-5 7.827461504363472e-6 -6.0458617608660634e-5 -0.0001003764128290497 -0.00015556732347248505 0.00011028403872841786 -2.8456098389126117e-5 -7.168982901930388e-6 -4.364964212073975e-5 6.142275282401599e-5 -3.3207424221457546e-5 4.335489114966352e-5 -2.438690813504105e-5 0.00011741974498364917; 6.885918439082065e-5 -6.577027496234664e-5 -0.00016516718286142392 0.00015954941947427794 -0.0002092341656708608 -7.437520078667188e-5 -1.8961736789716098e-5 -7.688483136291696e-5 0.00012938997853495843 0.00018441426247306936 4.5997172416609466e-5 0.00013424541422580078 -7.607484265747464e-5 9.690519974981658e-5 4.8990541396846025e-5 0.00015271997282695113 -5.164500006059866e-5 -0.00011596952216052372 1.3529816368369523e-6 8.97356020840367e-5 -1.438494935712047e-5 5.620744006307895e-5 -2.803061950828555e-5 9.230846073192955e-5 -9.703840449503369e-5 -5.702513052086432e-5 -1.49079561043854e-5 0.0001452088271588569 3.742160508125281e-5 -0.00015659272874756465 3.770573486406562e-5 6.270149780619e-5; -0.00010970321251510746 -3.19455899874605e-6 1.8565779314740855e-5 8.471805834711392e-5 9.443012298584846e-5 -3.1294066288631216e-5 -4.1111315808614116e-5 -6.38060909967542e-7 9.279661229380851e-5 -1.8607313690824977e-5 -0.00015082605911151682 4.7510677391314005e-5 0.00022276437287951933 -7.833893847850287e-5 -0.0001560295475107851 2.7953316234990154e-5 -1.5146035152377368e-5 7.954280804605626e-6 -4.23159524550385e-5 3.379426401375729e-5 -0.00015319224418339225 -6.305358489475204e-5 7.164092879544309e-5 2.6751546315862678e-5 -5.04839752145408e-5 -0.00023256178043514088 -2.4065356924177447e-5 4.369115980376423e-5 -2.670281152381194e-5 4.475089575431146e-5 -3.311007982551233e-5 2.3643242671150614e-5; 0.00012448589721240506 -0.0001569484052386777 -8.181575369307905e-5 7.419733763451106e-5 3.8065266816527425e-5 -9.770466251287312e-5 -0.00019920360184263076 -0.00015588330690736915 -4.4743632789344254e-5 3.437061921356358e-5 -4.900262913073123e-5 -0.0002527380461477361 -9.580074813205024e-5 1.4168037959431543e-5 -4.29636243287161e-5 4.180856511010282e-5 0.0001037129145326309 0.0001028445944749998 0.0001753788633133665 -0.00011807922830560721 0.00010016375334013949 -0.00013130497656750847 0.0001228801079188761 0.0001293173367950659 0.0002503372920538515 -9.517198596594021e-6 -0.00021311978755037453 -7.13739687133349e-5 -9.067350449261782e-5 -5.2960488138372775e-5 -9.288224210497834e-6 -2.0477582912307867e-5; 4.218874147041787e-5 -2.6877218435441806e-5 -0.00012609690765389962 -9.969966067004831e-5 -3.626913912015907e-5 0.00010572499723635758 -1.2997980040417812e-5 0.00010253822966474856 0.00016081162157930066 -6.872876855781563e-5 -0.0001245945533737898 0.0001757013740716091 -0.00014162221504379155 5.8445437461708124e-6 9.426851448366036e-5 -3.371166912261405e-5 -0.00010652785585265462 -0.00012488313240468534 0.0001106211491860002 0.0001128948422846949 9.949135172026435e-6 0.00012597267760425802 9.078910062085594e-5 1.1532865625405334e-7 2.1646590364944738e-5 0.000214354244060766 -0.00012574643932754457 6.754494807679267e-5 3.791482213997367e-5 -0.0001306492851581783 7.311672841715446e-6 -1.1498097766927164e-5; -6.103250401273264e-5 -9.120217858643489e-5 6.045973729278268e-5 7.351866239852408e-5 -6.528490469853314e-5 2.2442640479963027e-5 9.033810798947827e-5 -5.904089251479287e-5 -5.765686349540346e-5 9.106705707900931e-5 -3.118011576069355e-5 -0.00018200126927175905 -6.733127505348128e-5 0.0001420734480832065 4.810481929390348e-5 -8.362697538750186e-5 1.1692503144982481e-5 0.00019550848825539458 2.7019810809821437e-5 -7.682019334814434e-5 -0.0001328596698242896 -0.00018942970185370167 -9.899054551559709e-5 0.0001019207485712268 -6.755995112533516e-5 6.154947929253695e-5 6.14714591990302e-5 0.00015826268660540775 3.850636873331957e-5 -3.350849495680844e-5 0.0002052908386446827 0.00021367750029250012; 5.32224267564572e-5 0.00018807588374073042 -2.4657485235581463e-6 1.9309028245680806e-5 2.11960260250463e-5 -6.24929916037387e-5 -2.5294945701551797e-5 0.00011116503908595444 -4.1000940584872386e-5 1.4143240220783091e-5 -9.869210701088963e-6 -0.00016014887198509325 -0.0001420913567224664 6.437296276743044e-5 7.95714996249261e-5 -7.501676275098515e-5 9.505397345236029e-5 0.00023381229136901387 -7.2450774703803755e-6 1.5910261097070456e-5 -4.081533090613381e-5 -4.805326800380414e-5 0.00014501993889751272 0.00010761833716712923 -5.3550097548230636e-5 8.168497613694657e-5 -8.734057603107889e-5 -7.21063142216907e-5 -5.2524671375812265e-5 -0.0001519229653448007 -6.670618579193504e-5 -5.076373683807036e-5; -0.00010892068186833868 7.49096905387579e-5 0.00010133417143192293 -2.711068393139477e-5 -2.7259824691581627e-5 0.0001072289631160601 -5.402976987127531e-5 6.642346468694735e-5 -0.00010643943480275587 -4.216371618852827e-5 0.00021430240647575552 -1.4619211874886343e-6 9.771434613909923e-5 0.0001226747442892442 -3.076779123106078e-5 -1.1606093921312868e-5 -8.360501325981032e-5 1.1524339043234589e-5 4.285477169147173e-5 -4.938451476803436e-5 -4.068982909259945e-5 -6.36782681063144e-5 -1.1974392891291914e-5 -0.00010868881165111139 3.5532535348895136e-5 5.045600099079244e-5 0.00010188590729779988 6.841570919320538e-5 -5.023946525355219e-5 0.00015245574812113027 -7.895520340710277e-5 6.022136201906401e-5; -5.303452908255669e-5 -1.8743952205022994e-5 1.6350434270934423e-5 5.774852852492121e-5 -7.222973545045434e-6 3.830505738829106e-6 7.062854510887047e-5 -2.2847070249462738e-5 -4.1802291343539785e-5 -8.66113296060622e-6 -0.0001529367777037757 2.751733391059831e-5 0.0001263991094675493 4.146746014167083e-5 -2.8076597482051743e-5 -0.00015498359189215096 0.0001460116714953122 -4.783096428572861e-5 0.00017865887875822642 -4.187332651772153e-5 5.710691639260601e-5 0.00022341106760004835 0.00010141040784228586 -3.047380541822239e-5 0.00013226053361004716 5.2448764654479746e-5 0.00012626485349762564 -1.775518776402888e-5 -5.791568164558662e-5 -0.00010622179743213086 -5.0126950918543446e-5 -0.0001108006812500186; 5.4652674872077795e-5 5.7643380851816837e-5 0.0001227513562396963 0.0002755825112316526 -6.647682356660027e-5 -0.000142651494105927 -7.239529121678792e-5 8.121901838327756e-6 9.974828527527914e-5 7.924144257312272e-5 0.0001931346919363149 6.184730547000613e-5 8.920027131434976e-5 7.954455169138985e-5 -5.27033230256291e-5 -1.69552299568229e-5 5.665490563016071e-5 0.0002228158841816409 6.962109865605018e-5 0.00018504810533241585 -2.9022292882562042e-5 -5.8556383454828504e-5 0.00011881683122819006 -9.37154477370987e-5 9.835899173516636e-6 0.00019845410089223915 -7.705680080944664e-5 0.00011331907399918176 2.731047016905964e-5 2.198594803386512e-5 -5.349863976251861e-5 2.591123076026887e-5], bias = [-6.639028580158189e-9; 4.180017344085463e-9; -1.2272443320229317e-9; -4.785469349449767e-10; -2.2561338296038223e-9; 2.225994011433195e-9; 2.233121728012265e-9; -2.289379624054782e-9; 2.088250371623597e-9; -5.918193809492375e-10; 7.757279376376072e-10; -3.0519956933974755e-9; -5.365989361707419e-10; 2.6872231280446437e-9; -1.6449414248937268e-10; -2.1428098500105308e-9; 1.190996171121769e-9; 7.419358935259395e-10; 1.9314113659455934e-9; 1.2728541656967603e-9; -7.738203976736433e-11; 9.417526462982933e-10; 1.7161619488832803e-9; 1.3477955504792646e-9; -1.584125805519438e-9; -2.4488334328842434e-9; 1.331325584648675e-9; 1.2661673424599823e-9; 3.393117197984222e-10; 1.778771977081073e-9; 1.7413186518398889e-9; 6.31508638794759e-9;;]), layer_4 = (weight = [-0.0006698443046147509 -0.000652110171599187 -0.000622666690826289 -0.0007209998202020295 -0.0007056129051792856 -0.0008793136960560769 -0.0005376306192965661 -0.0006555812678504734 -0.0006212328947452045 -0.0005947660132730908 -0.0008603340293044003 -0.0005740266371898066 -0.0006447510538482616 -0.0005292865363897034 -0.0008546292540613204 -0.0008123036934005458 -0.0007262691634519201 -0.0006826439502848893 -0.0006685722309568772 -0.0006650814291719909 -0.0006175070909844167 -0.0007244016542202028 -0.0007347301535814814 -0.000708627029252182 -0.0008954548323318181 -0.0007177424104775897 -0.0006745640005112629 -0.0007401967330781717 -0.0006960940348816483 -0.0005759503707309255 -0.0006842848341656654 -0.0007415477483931231; 0.00021637547293231062 0.000173325641115639 0.0002995176077414295 0.00015750536654074505 0.00034922341034878003 0.0001861073395754659 0.00014146466501269237 0.00022259480688832265 0.00017333946707001908 0.00041541232081665177 0.0003045630645053873 0.00017597213096987087 0.0002820906679741262 0.00032443667734078046 0.00024190542168987065 0.00017035514219546657 0.0001590067303356151 0.0002695896687390872 0.00032222966121245913 6.992395695786546e-5 0.00018035055076600677 7.924566546932636e-6 0.00011256475531393427 0.0003276433141557217 0.0002909118294659588 0.00022742037239427933 0.0003478405216265833 0.00041617461380067544 0.00031501759695326146 0.00014577370967894148 0.00017517295864401266 0.00024644125168912936], bias = [-0.0007024744435487016; 0.00024716965680188186;;])) +``` + + + + +## Visualizing the Results + + +Let us now plot the loss over time + + +```julia +fig = with_theme(theme_web()) do + fig = Figure() + ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss") + + lines!(ax, losses; linewidth=2, alpha=0.75) + + return fig +end +``` + + +![](1_GravitationalWaveForm-48.png) + + +Finally let us visualize the results + + +```julia +prob_nn = ODEProblem(ODE_model, u0, tspan, res.u) +soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false)) +waveform_nn_trained = first(compute_waveform(dt_data, soln_nn, mass_ratio, + ode_model_params)) + +fig = with_theme(theme_web()) do + fig = Figure() + ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform") + + l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75) + s1 = scatter!(ax, tsteps, waveform; markershape=:circle, markeralpha=0.25, alpha=0.5) + + l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75) + s2 = scatter!(ax, tsteps, waveform_nn; markershape=:circle, markeralpha=0.25, alpha=0.5) + + l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75) + s3 = scatter!(ax, tsteps, waveform_nn_trained; markershape=:circle, markeralpha=0.25, + alpha=0.5) + + axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]], + ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"]; + position=:lb) + + return fig +end +``` + + +![](1_GravitationalWaveForm-50.png) + + +--- + + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/previews/PR474/tutorials/beginner/1_Basics.md b/previews/PR474/tutorials/beginner/1_Basics.md new file mode 100644 index 000000000..e13618c34 --- /dev/null +++ b/previews/PR474/tutorials/beginner/1_Basics.md @@ -0,0 +1,811 @@ + + + + + +# Julia & Lux for the Uninitiated + + +This is a quick intro to [Lux](https://github.com/avik-pal/Lux.jl) loosely based on: + + +1. [PyTorch's tutorial](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html). +2. [Flux's tutorial](https://fluxml.ai/Flux.jl/stable/tutorials/2020-09-15-deep-learning-flux/). +3. [Flax's tutorial](https://flax.readthedocs.io/en/latest/notebooks/jax_for_the_impatient.html). + + +It introduces basic Julia programming, as well `Zygote`, a source-to-source automatic differentiation (AD) framework in Julia. We'll use these tools to build a very simple neural network. Let's start with importing `Lux.jl` + + +```julia +using Lux, Random +``` + + +Now let us control the randomness in our code using proper Pseudo Random Number Generator (PRNG) + + +```julia +rng = Random.default_rng() +Random.seed!(rng, 0) +``` + + +``` +Random.TaskLocalRNG() +``` + + + + +## Arrays + + +The starting point for all of our models is the `Array` (sometimes referred to as a `Tensor` in other frameworks). This is really just a list of numbers, which might be arranged into a shape like a square. Let's write down an array with three elements. + + +```julia +x = [1, 2, 3] +``` + + +``` +3-element Vector{Int64}: + 1 + 2 + 3 +``` + + +Here's a matrix – a square array with four elements. + + +```julia +x = [1 2; 3 4] +``` + + +``` +2×2 Matrix{Int64}: + 1 2 + 3 4 +``` + + +We often work with arrays of thousands of elements, and don't usually write them down by hand. Here's how we can create an array of 5×3 = 15 elements, each a random number from zero to one. + + +```julia +x = rand(rng, 5, 3) +``` + + +``` +5×3 Matrix{Float64}: + 0.455238 0.746943 0.193291 + 0.547642 0.746801 0.116989 + 0.773354 0.97667 0.899766 + 0.940585 0.0869468 0.422918 + 0.0296477 0.351491 0.707534 +``` + + +There's a few functions like this; try replacing `rand` with `ones`, `zeros`, or `randn`. + + +By default, Julia works stores numbers is a high-precision format called `Float64`. In ML we often don't need all those digits, and can ask Julia to work with `Float32` instead. We can even ask for more digits using `BigFloat`. + + +```julia +x = rand(BigFloat, 5, 3) +``` + + +``` +5×3 Matrix{BigFloat}: + 0.981339 0.793159 0.459019 + 0.043883 0.624384 0.56055 + 0.164786 0.524008 0.0355555 + 0.414769 0.577181 0.621958 + 0.00823197 0.30215 0.655881 +``` + + +```julia +x = rand(Float32, 5, 3) +``` + + +``` +5×3 Matrix{Float32}: + 0.567794 0.369178 0.342539 + 0.0985227 0.201145 0.587206 + 0.776598 0.148248 0.0851708 + 0.723731 0.0770206 0.839303 + 0.404728 0.230954 0.679087 +``` + + +We can ask the array how many elements it has. + + +```julia +length(x) +``` + + +``` +15 +``` + + +Or, more specifically, what size it has. + + +```julia +size(x) +``` + + +``` +(5, 3) +``` + + +We sometimes want to see some elements of the array on their own. + + +```julia +x +``` + + +``` +5×3 Matrix{Float32}: + 0.567794 0.369178 0.342539 + 0.0985227 0.201145 0.587206 + 0.776598 0.148248 0.0851708 + 0.723731 0.0770206 0.839303 + 0.404728 0.230954 0.679087 +``` + + +```julia +x[2, 3] +``` + + +``` +0.58720636f0 +``` + + +This means get the second row and the third column. We can also get every row of the third column. + + +```julia +x[:, 3] +``` + + +``` +5-element Vector{Float32}: + 0.34253937 + 0.58720636 + 0.085170805 + 0.8393034 + 0.67908657 +``` + + +We can add arrays, and subtract them, which adds or subtracts each element of the array. + + +```julia +x + x +``` + + +``` +5×3 Matrix{Float32}: + 1.13559 0.738356 0.685079 + 0.197045 0.40229 1.17441 + 1.5532 0.296496 0.170342 + 1.44746 0.154041 1.67861 + 0.809456 0.461908 1.35817 +``` + + +```julia +x - x +``` + + +``` +5×3 Matrix{Float32}: + 0.0 0.0 0.0 + 0.0 0.0 0.0 + 0.0 0.0 0.0 + 0.0 0.0 0.0 + 0.0 0.0 0.0 +``` + + +Julia supports a feature called *broadcasting*, using the `.` syntax. This tiles small arrays (or single numbers) to fill bigger ones. + + +```julia +x .+ 1 +``` + + +``` +5×3 Matrix{Float32}: + 1.56779 1.36918 1.34254 + 1.09852 1.20114 1.58721 + 1.7766 1.14825 1.08517 + 1.72373 1.07702 1.8393 + 1.40473 1.23095 1.67909 +``` + + +We can see Julia tile the column vector `1:5` across all rows of the larger array. + + +```julia +zeros(5, 5) .+ (1:5) +``` + + +``` +5×5 Matrix{Float64}: + 1.0 1.0 1.0 1.0 1.0 + 2.0 2.0 2.0 2.0 2.0 + 3.0 3.0 3.0 3.0 3.0 + 4.0 4.0 4.0 4.0 4.0 + 5.0 5.0 5.0 5.0 5.0 +``` + + +The x' syntax is used to transpose a column `1:5` into an equivalent row, and Julia will tile that across columns. + + +```julia +zeros(5, 5) .+ (1:5)' +``` + + +``` +5×5 Matrix{Float64}: + 1.0 2.0 3.0 4.0 5.0 + 1.0 2.0 3.0 4.0 5.0 + 1.0 2.0 3.0 4.0 5.0 + 1.0 2.0 3.0 4.0 5.0 + 1.0 2.0 3.0 4.0 5.0 +``` + + +We can use this to make a times table. + + +```julia +(1:5) .* (1:5)' +``` + + +``` +5×5 Matrix{Int64}: + 1 2 3 4 5 + 2 4 6 8 10 + 3 6 9 12 15 + 4 8 12 16 20 + 5 10 15 20 25 +``` + + +Finally, and importantly for machine learning, we can conveniently do things like matrix multiply. + + +```julia +W = randn(5, 10) +x = rand(10) +W * x +``` + + +``` +5-element Vector{Float64}: + 1.2197981041108443 + -2.62625877100596 + -2.8573820474674845 + -2.4319346874291314 + 1.0108668577150213 +``` + + +Julia's arrays are very powerful, and you can learn more about what they can do [here](https://docs.julialang.org/en/v1/manual/arrays/). + + + + +### CUDA Arrays + + +CUDA functionality is provided separately by the [CUDA.jl package](https://github.com/JuliaGPU/CUDA.jl). If you have a GPU and LuxCUDA is installed, Lux will provide CUDA capabilities. For additional details on backends see the manual section. + + +You can manually add `CUDA`. Once CUDA is loaded you can move any array to the GPU with the `cu` function (or the `gpu` function exported by `Lux``), and it supports all of the above operations with the same syntax. + + +```julia +using LuxCUDA +if LuxCUDA.functional() + x_cu = cu(rand(5, 3)) + @show x_cu +end +``` + + +``` +5×3 CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}: + 0.857126 0.681728 0.73806 + 0.191956 0.506485 0.622865 + 0.857257 0.663036 0.239756 + 0.54452 0.503186 0.27993 + 0.833518 0.975649 0.967811 +``` + + + + +## (Im)mutability + + +Lux as you might have read is [Immutable by convention](http://lux.csail.mit.edu/dev/introduction/overview/#Design-Principles) which means that the core library is built without any form of mutation and all functions are pure. However, we don't enforce it in any form. We do **strongly recommend** that users extending this framework for their respective applications don't mutate their arrays. + + +```julia +x = reshape(1:8, 2, 4) +``` + + +``` +2×4 reshape(::UnitRange{Int64}, 2, 4) with eltype Int64: + 1 3 5 7 + 2 4 6 8 +``` + + +To update this array, we should first copy the array. + + +```julia +x_copy = copy(x) +view(x_copy, :, 1) .= 0 + +println("Original Array ", x) +println("Mutated Array ", x_copy) +``` + + +``` +Original Array [1 3 5 7; 2 4 6 8] +Mutated Array [0 3 5 7; 0 4 6 8] + +``` + + +Note that our current default AD engine (Zygote) is unable to differentiate through this mutation, however, for these specialized cases it is quite trivial to write custom backward passes. (This problem will be fixed once we move towards Enzyme.jl) + + + + +## Managing Randomness + + +We rely on the Julia StdLib `Random` for managing the randomness in our execution. First, we create an PRNG (pseudorandom number generator) and seed it. + + +```julia +rng = Random.default_rng() # Creates a Xoshiro PRNG +Random.seed!(rng, 0) +``` + + +``` +Random.TaskLocalRNG() +``` + + +If we call any function that relies on `rng` and uses it via `randn`, `rand`, etc. `rng` will be mutated. As we have already established we care a lot about immutability, hence we should use `Lux.replicate` on PRNGs before using them. + + +First, let us run a random number generator 3 times with the `replicate`d rng. + + +```julia +for i in 1:3 + println("Iteration $i ", rand(Lux.replicate(rng), 10)) +end +``` + + +``` +Iteration 1 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564] +Iteration 2 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564] +Iteration 3 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564] + +``` + + +As expected we get the same output. We can remove the `replicate` call and we will get different outputs. + + +```julia +for i in 1:3 + println("Iteration $i ", rand(rng, 10)) +end +``` + + +``` +Iteration 1 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564] +Iteration 2 [0.018743665453639813, 0.8601828553599953, 0.6556360448565952, 0.7746656838366666, 0.7817315740767116, 0.5553797706980106, 0.1261990389976131, 0.4488101521328277, 0.624383955429775, 0.05657739601024536] +Iteration 3 [0.19597391412112541, 0.6830945313415872, 0.6776220912718907, 0.6456416023530093, 0.6340362477836592, 0.5595843665394066, 0.5675557670686644, 0.34351700231383653, 0.7237308297251812, 0.3691778381831775] + +``` + + + + +## Automatic Differentiation + + +Julia has quite a few (maybe too many) AD tools. For the purpose of this tutorial, we will use: + + +1. [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) – For Jacobian-Vector Product (JVP) +2. [Zygote.jl](https://github.com/FluxML/Zygote.jl) – For Vector-Jacobian Product (VJP) + + +*Slight Detour*: We have had several questions regarding if we will be considering any other AD system for the reverse-diff backend. For now we will stick to Zygote.jl, however once we have tested Lux extensively with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl), we will make the switch. + + +Even though, theoretically, a VJP (Vector-Jacobian product - reverse autodiff) and a JVP (Jacobian-Vector product - forward-mode autodiff) are similar—they compute a product of a Jacobian and a vector—they differ by the computational complexity of the operation. In short, when you have a large number of parameters (hence a wide matrix), a JVP is less efficient computationally than a VJP, and, conversely, a JVP is more efficient when the Jacobian matrix is a tall matrix. + + +```julia +using ComponentArrays, ForwardDiff, Zygote +``` + + + + +### Gradients + + +For our first example, consider a simple function computing $f(x) = \frac{1}{2}x^T x$, where $\nabla f(x) = x$ + + +```julia +f(x) = x' * x / 2 +∇f(x) = x # `∇` can be typed as `\nabla` +v = randn(rng, Float32, 4) +``` + + +``` +4-element Vector{Float32}: + -0.4051151 + -0.4593922 + 0.92155594 + 1.1871622 +``` + + +Let's use AbstractDifferentiation and Zygote to compute the gradients. + + +```julia +println("Actual Gradient: ", ∇f(v)) +println("Computed Gradient via Reverse Mode AD (Zygote): ", only(Zygote.gradient(f, v))) +println("Computed Gradient via Forward Mode AD (ForwardDiff): ", ForwardDiff.gradient(f, v)) +``` + + +``` +Actual Gradient: Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622] +Computed Gradient via Reverse Mode AD (Zygote): Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622] +Computed Gradient via Forward Mode AD (ForwardDiff): Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622] + +``` + + +Note that `AD.gradient` will only work for scalar valued outputs. + + + + +### Jacobian-Vector Product + + +I will defer the discussion on forward-mode AD to [https://book.sciml.ai/notes/08/](https://book.sciml.ai/notes/08/). Here let us just look at a mini example on how to use it. + + +```julia +f(x) = x .* x ./ 2 +x = randn(rng, Float32, 5) +v = ones(Float32, 5) +``` + + +``` +5-element Vector{Float32}: + 1.0 + 1.0 + 1.0 + 1.0 + 1.0 +``` + + +Construct the pushforward function. We will write out the function here but in practice we recommend using [SparseDiffTools.auto_jacvec](https://docs.sciml.ai/SparseDiffTools/stable/#Jacobian-Vector-and-Hessian-Vector-Products)! + + +First we need to create a Tag for ForwardDiff. It is enough to know that this is something that you must do. For more details, see the [ForwardDiff Documentation](https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Custom-tags-and-tag-checking)! + + +```julia +struct TestTag end +``` + + +Going in the details of what is function is doing is beyond the scope of this tutorial. But in short, it is constructing a new Dual Vector with the partials set to the input to the pushforward function. When this is propagated through the original function we get the value and the jvp + + +```julia +function pushforward_forwarddiff(f, x) + T = eltype(x) + function pushforward(v) + v_ = reshape(v, axes(x)) + y = ForwardDiff.Dual{ + ForwardDiff.Tag{TestTag, T}, + T, + 1, + }.(x, ForwardDiff.Partials.(tuple.(v_))) + res = vec(f(y)) + return ForwardDiff.value.(res), vec(ForwardDiff.partials.(res, 1)) + end + return pushforward +end + +pf_f = pushforward_forwarddiff(f, x) +``` + + +``` +(::Main.var"##225".var"#pushforward#1"{typeof(Main.var"##225".f), Vector{Float32}, DataType}) (generic function with 1 method) +``` + + +Compute the jvp. + + +```julia +val, jvp = pf_f(v) +println("Computed Value: f(", x, ") = ", val) +println("JVP: ", jvp[1]) +``` + + +``` +Computed Value: f(Float32[-0.877497, 1.1953009, -0.057005208, 0.25055695, 0.09351656]) = Float32[0.3850005, 0.71437216, 0.0016247969, 0.031389393, 0.0043726736] +JVP: -0.877497 + +``` + + + + +### Vector-Jacobian Product + + +Using the same function and inputs, let us compute the VJP. + + +```julia +val, pb_f = Zygote.pullback(f, x) +``` + + +``` +(Float32[0.3850005, 0.71437216, 0.0016247969, 0.031389393, 0.0043726736], Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(Main.var"##225".f), Vector{Float32}}, Tuple{Zygote.var"#3798#back#1205"{Zygote.var"#1201#1204"{Vector{Float32}, Vector{Float32}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float32}}, Tuple{}}, Zygote.var"#3862#back#1231"{Zygote.ZBack{ChainRules.var"#slash_pullback_scalar#1560"{Vector{Float32}, Int64}}}}}}(∂(f))) +``` + + +Compute the vjp. + + +```julia +vjp = only(pb_f(v)) +println("Computed Value: f(", x, ") = ", val) +println("VJP: ", vjp[1]) +``` + + +``` +Computed Value: f(Float32[-0.877497, 1.1953009, -0.057005208, 0.25055695, 0.09351656]) = Float32[0.3850005, 0.71437216, 0.0016247969, 0.031389393, 0.0043726736] +VJP: -0.877497 + +``` + + + + +## Linear Regression + + +Finally, now let us consider a linear regression problem. From a set of data-points $\{ (x_i, y_i), i \in \{ 1, \dots, k \}, x_i \in \mathbb{R}^n, y_i \in \mathbb{R}^m \}$, we try to find a set of parameters $W$ and $b$, s.t. $f_{W,b}(x) = Wx + b$, which minimizes the mean squared error: + + +$$ +L(W, b) \longrightarrow \sum_{i = 1}^{k} \frac{1}{2} \| y_i - f_{W,b}(x_i) \|_2^2 +$$ + + +We can write `f` from scratch, but to demonstrate `Lux`, let us use the `Dense` layer. + + +```julia +model = Dense(10 => 5) + +rng = Random.default_rng() +Random.seed!(rng, 0) +``` + + +``` +Random.TaskLocalRNG() +``` + + +Let us initialize the parameters and states (in this case it is empty) for the model. + + +```julia +ps, st = Lux.setup(rng, model) +ps = ps |> ComponentArray +``` + + +``` +ComponentVector{Float32}(weight = Float32[-0.5583162 0.3457679 0.50863314 0.60294497 0.23095794 0.16602759 5.5791984f-6 0.61324424 -0.35419345 0.039559156; -0.05661944 -0.4899126 0.31236076 0.47100115 -0.5062956 -0.20445547 -0.03762182 0.5370978 0.22614014 0.27704597; 0.5198015 0.55730057 -0.34535396 -0.21587563 -0.12729146 -0.51019937 0.46597028 0.2918885 0.20849374 -0.4068233; 0.06026341 -0.11202827 0.31218112 0.14536527 -0.3413506 0.40088427 -0.48716235 -0.15096173 0.42526972 -0.3576447; 0.23414856 -0.5949539 -0.26137677 0.21756552 0.34443143 0.25046515 -0.049256783 -0.48404032 0.08254115 -0.5224755], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0;;]) +``` + + +Set problem dimensions. + + +```julia +n_samples = 20 +x_dim = 10 +y_dim = 5 +``` + + +``` +5 +``` + + +Generate random ground truth W and b. + + +```julia +W = randn(rng, Float32, y_dim, x_dim) +b = randn(rng, Float32, y_dim) +``` + + +``` +5-element Vector{Float32}: + 0.68468636 + -0.57578707 + 0.0594993 + -0.9436797 + 1.5164032 +``` + + +Generate samples with additional noise. + + +```julia +x_samples = randn(rng, Float32, x_dim, n_samples) +y_samples = W * x_samples .+ b .+ 0.01f0 .* randn(rng, Float32, y_dim, n_samples) +println("x shape: ", size(x_samples), "; y shape: ", size(y_samples)) +``` + + +``` +x shape: (10, 20); y shape: (5, 20) + +``` + + +For updating our parameters let's use [Optimisers.jl](https://github.com/FluxML/Optimisers.jl). We will use Stochastic Gradient Descent (SGD) with a learning rate of `0.01`. + + +```julia +using Optimisers + +opt = Optimisers.Descent(0.01f0) +``` + + +``` +Descent(0.01f0) +``` + + +Initialize the initial state of the optimiser + + +```julia +opt_state = Optimisers.setup(opt, ps) +``` + + +``` +Leaf(Descent(0.01), nothing) +``` + + +Define the loss function + + +```julia +mse(model, ps, st, X, y) = sum(abs2, model(X, ps, st)[1] .- y) +mse(weight, bias, X, y) = sum(abs2, weight * X .+ bias .- y) +loss_function(ps, X, y) = mse(model, ps, st, X, y) + +println("Loss Value with ground true parameters: ", mse(W, b, x_samples, y_samples)) + +for i in 1:100 + # In actual code, don't use globals. But here I will simply for the sake of + # demonstration + global ps, st, opt_state + # Compute the gradient + gs = gradient(loss_function, ps, x_samples, y_samples)[1] + # Update model parameters + opt_state, ps = Optimisers.update(opt_state, ps, gs) + if i % 10 == 1 || i == 100 + println("Loss Value after $i iterations: ", + mse(model, ps, st, x_samples, y_samples)) + end +end +``` + + +``` +Loss Value with ground true parameters: 0.009175307 +Loss Value after 1 iterations: 165.57005 +Loss Value after 11 iterations: 4.351237 +Loss Value after 21 iterations: 0.6856849 +Loss Value after 31 iterations: 0.15421417 +Loss Value after 41 iterations: 0.041469414 +Loss Value after 51 iterations: 0.014032223 +Loss Value after 61 iterations: 0.006883738 +Loss Value after 71 iterations: 0.004938521 +Loss Value after 81 iterations: 0.004391277 +Loss Value after 91 iterations: 0.0042331247 +Loss Value after 100 iterations: 0.0041888584 + +``` + + +--- + + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/previews/PR474/tutorials/beginner/2_PolynomialFitting-12.png b/previews/PR474/tutorials/beginner/2_PolynomialFitting-12.png new file mode 100644 index 000000000..a58322afa Binary files /dev/null and b/previews/PR474/tutorials/beginner/2_PolynomialFitting-12.png differ diff --git a/previews/PR474/tutorials/beginner/2_PolynomialFitting-30.png b/previews/PR474/tutorials/beginner/2_PolynomialFitting-30.png new file mode 100644 index 000000000..b0878d2c0 Binary files /dev/null and b/previews/PR474/tutorials/beginner/2_PolynomialFitting-30.png differ diff --git a/previews/PR474/tutorials/beginner/2_PolynomialFitting.md b/previews/PR474/tutorials/beginner/2_PolynomialFitting.md new file mode 100644 index 000000000..8315f04f4 --- /dev/null +++ b/previews/PR474/tutorials/beginner/2_PolynomialFitting.md @@ -0,0 +1,233 @@ + + + + + +# Fitting a Polynomial using MLP + + +In this tutorial we will fit a MultiLayer Perceptron (MLP) on data generated from a polynomial. + + + + +## Package Imports + + +```julia +using Lux, LuxAMDGPU, LuxCUDA, Optimisers, Random, Statistics, Zygote +using CairoMakie, MakiePublication +``` + + + + +## Dataset + + +Generate 128 datapoints from the polynomial $y = x^2 - 2x$. + + +```julia +function generate_data(rng::AbstractRNG) + x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128)) + y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, (1, 128)) .* 0.1f0 + return (x, y) +end +``` + + +``` +generate_data (generic function with 1 method) +``` + + +Initialize the random number generator and fetch the dataset. + + +```julia +rng = MersenneTwister() +Random.seed!(rng, 12345) + +(x, y) = generate_data(rng) +``` + + +``` +(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], [8.11723579535073 7.8972862806322315 7.667572185253954 7.493641443881164 7.328542256257643 7.1081451188446065 6.754145700236098 6.73844851250885 6.698323804024227 6.3637494708272655 6.270117709011731 6.2419372753805 5.816280759896085 5.718319527208828 5.741347639508506 5.258118446989299 5.268165780092538 5.195746082529355 5.032704772846244 4.733409783966572 4.520239616672976 4.369386593776045 4.107888442446331 4.182845399340577 4.002249800810884 3.8969011895086174 3.910820824989613 3.646440085736948 3.3343752660206305 3.3980378243437745 3.1887817476268587 2.9930802717826603 3.018980452144523 2.690492107796345 2.8576513349182378 2.4778283273281008 2.452401424624867 2.401875695877283 2.2896425232872755 2.2812518842985035 1.9742292519472466 1.7663454774622869 1.7829663021691418 1.6248666914928798 1.635090436697959 1.4887378757184528 1.4396068206428336 1.5047223947023354 1.2439428212858357 1.1770575798169982 1.0519113712665473 0.8008025630753797 0.8011788202541421 0.7702484835053167 0.9010273188596704 0.48114290312426095 0.4605012716399809 0.42308333113261615 0.2890108900859864 0.3324716507588617 0.2126899641074972 0.2560113968739265 0.08350192481301627 0.046225582753114294 -0.16118930624459 -0.013928769802494537 -0.030805824695545894 -0.10629780224701328 -0.17643440564041185 -0.2494508100897751 -0.3322350480467481 -0.45414851684613733 -0.6965624404632386 -0.38861245182183696 -0.4708530312086873 -0.6274991143463677 -0.5617763080815885 -0.6438360803492721 -0.7565600800322707 -0.5662591600023589 -0.6591533520776037 -0.9166793344639054 -0.8520467822193756 -0.9507226194240974 -1.0248823046771698 -0.97772916365376 -0.8199294436184201 -0.9080088282844027 -0.9682665790685976 -1.031816361263047 -0.9296919748814573 -1.1145618706755287 -1.2139119971536336 -1.0157839085777947 -0.9417175810509869 -0.9783498813733602 -0.9123675448444001 -1.138088633455826 -1.1212038088290894 -0.911429094488635 -1.023486657428913 -0.9287179111905346 -1.0396518660677925 -1.0370046468920306 -0.9846375721966646 -0.833026219703481 -0.8200258902651266 -0.789500663251252 -0.9068267920931062 -0.7284236770750803 -0.7093213401368348 -0.7048862544448803 -0.6215870033126495 -0.5892481295457608 -0.8462913756395639 -0.5544688796856879 -0.5805399434794658 -0.5761396334948753 -0.5851955365208916 -0.5561461874821676 -0.1969227628706652 -0.34073487813889014 -0.2738635064414512 -0.1425063756241582 -0.18330825579933746 -0.054321035831595324 -0.21213293699653427 0.049985105882301]) +``` + + +Let's visualize the dataset + + +```julia +with_theme(theme_web()) do + fig = Figure() + ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y") + + l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3) + s = scatter!(ax, x[1, :], y[1, :]; markersize=8, color=:orange, + strokecolor=:black, strokewidth=1) + + axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"]) + + return fig +end +``` + + +![](2_PolynomialFitting-12.png) + + + + +## Neural Network + + +For this problem, you should not be using a neural network. But let's still do that! + + +```julia +model = Chain(Dense(1 => 16, relu), Dense(16 => 1)) +``` + + +``` +Chain( + layer_1 = Dense(1 => 16, relu), # 32 parameters + layer_2 = Dense(16 => 1), # 17 parameters +) # Total: 49 parameters, + # plus 0 states. +``` + + + + +## Optimizer + + +We will use Adam from Optimisers.jl + + +```julia +opt = Adam(0.03f0) +``` + + +``` +Adam(0.03, (0.9, 0.999), 1.0e-8) +``` + + + + +## Loss Function + + +We will use the `Lux.Training` API so we need to ensure that our loss function takes 4 inputs – model, parameters, states and data. The function must return 3 values – loss, updated_state, and any computed statistics. + + +```julia +function loss_function(model, ps, st, data) + y_pred, st = Lux.apply(model, data[1], ps, st) + mse_loss = mean(abs2, y_pred .- data[2]) + return mse_loss, st, () +end +``` + + +``` +loss_function (generic function with 1 method) +``` + + + + +## Training + + +First we will create a [`Lux.Experimental.TrainState`](../../api/Lux/contrib#Lux.Experimental.TrainState) which is essentially a convenience wrapper over parameters, states and optimizer states. + + +```julia +tstate = Lux.Training.TrainState(rng, model, opt) +``` + + +``` +Lux.Experimental.TrainState{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.relu), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, layer_2::@NamedTuple{weight::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, bias::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Optimisers.Adam, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Optimisers.Adam, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}}}(Chain(), (layer_1 = (weight = Float32[0.36222202; 0.23371002; -0.49825558; -0.18142056; -0.13757975; -0.50849473; 0.13773328; -0.035294008; 0.21778254; 0.04964345; -0.56594235; -0.45329624; -0.08787567; 0.5648949; 0.5260752; -0.07562564;;], bias = Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.14330137 -0.39328107 -0.18253882 -0.55998546 -0.5919335 -0.3069779 -0.39085856 -0.4838621 0.3979575 0.5851314 0.24242708 0.35374007 0.10175798 0.29761198 -0.34761065 -0.05758927], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), (layer_1 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;], Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;], Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], Float32[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0) +``` + + +Now we will use Zygote for our AD requirements. + + +```julia +vjp_rule = Lux.Training.AutoZygote() +``` + + +``` +ADTypes.AutoZygote() +``` + + +Finally the training loop. + + +```julia +function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs) + data = data .|> gpu_device() + for epoch in 1:epochs + grads, loss, stats, tstate = Lux.Training.compute_gradients(vjp, + loss_function, data, tstate) + println("Epoch: $(epoch) || Loss: $(loss)") + tstate = Lux.Training.apply_gradients(tstate, grads) + end + return tstate +end + +dev_cpu = cpu_device() +dev_gpu = gpu_device() + +tstate = main(tstate, vjp_rule, (x, y), 250) +y_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)[1]) +``` + + +``` +1×128 Matrix{Float32}: + 7.93183 7.76661 7.60138 7.43616 7.27094 7.10571 6.94049 6.77526 6.61004 6.44482 6.27959 6.11437 5.94914 5.78392 5.61869 5.45347 5.28825 5.12302 4.9578 4.79257 4.62735 4.46213 4.29696 4.14682 4.01403 3.88123 3.74844 3.61565 3.48286 3.35007 3.21728 3.08449 2.9517 2.82191 2.70562 2.58933 2.47304 2.35675 2.24045 2.12416 2.00787 1.89158 1.77932 1.67136 1.5634 1.45544 1.34748 1.2629 1.18945 1.116 1.04255 0.969102 0.895652 0.822202 0.748752 0.675302 0.601852 0.528403 0.454953 0.381503 0.308053 0.234603 0.161153 0.0877038 0.014254 -0.0591961 -0.132646 -0.206095 -0.279545 -0.352995 -0.426445 -0.499895 -0.570313 -0.604513 -0.638713 -0.672913 -0.707113 -0.741312 -0.775512 -0.809712 -0.843911 -0.878112 -0.912311 -0.946511 -0.980711 -0.986985 -0.984268 -0.981551 -0.978835 -0.976118 -0.973401 -0.970685 -0.967968 -0.965252 -0.962535 -0.959818 -0.957101 -0.954385 -0.951669 -0.938957 -0.914585 -0.890213 -0.86584 -0.841468 -0.817095 -0.792723 -0.768351 -0.743978 -0.719606 -0.695233 -0.670861 -0.646488 -0.622116 -0.597744 -0.573371 -0.548999 -0.524627 -0.500254 -0.475882 -0.451509 -0.427137 -0.402765 -0.378392 -0.35402 -0.329648 -0.305275 -0.280903 -0.256531 +``` + + +Let's plot the results + + +```julia +with_theme(theme_web()) do + fig = Figure() + ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y") + + l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3) + s1 = scatter!(ax, x[1, :], y[1, :]; markersize=8, color=:orange, + strokecolor=:black, strokewidth=1) + s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=8, color=:green, + strokecolor=:black, strokewidth=1) + + axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"]) + + return fig +end +``` + + +![](2_PolynomialFitting-30.png) + + +--- + + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/previews/PR474/tutorials/beginner/3_SimpleRNN.md b/previews/PR474/tutorials/beginner/3_SimpleRNN.md new file mode 100644 index 000000000..d6bf436f9 --- /dev/null +++ b/previews/PR474/tutorials/beginner/3_SimpleRNN.md @@ -0,0 +1,272 @@ + + + + + +# Training a Simple LSTM + + +In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to: + + +1. Create custom Lux models. +2. Become familiar with the Lux recurrent neural network API. +3. Training using Optimisers.jl and Zygote.jl. + + + + +## Package Imports + + +```julia +using Lux, LuxAMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Random, Statistics +``` + + + + +## Dataset + + +We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create a `MLUtils.DataLoader`. Our dataloader will give us sequences of size 2 × seq*len × batch*size and we need to predict a binary value whether the sequence is clockwise or anticlockwise. + + +```julia +function get_dataloaders(; dataset_size=1000, sequence_length=50) + # Create the spirals + data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size] + # Get the labels + labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2)) + clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) + for d in data[1:(dataset_size ÷ 2)]] + anticlockwise_spirals = [reshape(d[1][:, (sequence_length + 1):end], :, + sequence_length, 1) for d in data[((dataset_size ÷ 2) + 1):end]] + x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3)) + # Split the dataset + (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true) + # Create DataLoaders + return ( + # Use DataLoader to automatically minibatch and shuffle the data + DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true), + # Don't shuffle the validation data + DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false)) +end +``` + + +``` +get_dataloaders (generic function with 1 method) +``` + + + + +## Creating a Classifier + + +We will be extending the `Lux.AbstractExplicitContainerLayer` type for our custom model since it will contain a lstm block and a classifier head. + + +We pass the fieldnames `lstm_cell` and `classifier` to the type to ensure that the parameters and states are automatically populated and we don't have to define `Lux.initialparameters` and `Lux.initialstates`. + + +To understand more about container layers, please look at [Container Layer](http://lux.csail.mit.edu/stable/manual/interface/#container-layer). + + +```julia +struct SpiralClassifier{L, C} <: + Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)} + lstm_cell::L + classifier::C +end +``` + + +We won't define the model from scratch but rather use the [`Lux.LSTMCell`](../../api/Lux/layers#Lux.LSTMCell) and [`Lux.Dense`](../../api/Lux/layers#Lux.Dense). + + +```julia +function SpiralClassifier(in_dims, hidden_dims, out_dims) + return SpiralClassifier(LSTMCell(in_dims => hidden_dims), + Dense(hidden_dims => out_dims, sigmoid)) +end +``` + + +``` +Main.var"##225".SpiralClassifier +``` + + +We can use default Lux blocks – `Recurrence(LSTMCell(in_dims => hidden_dims)` – instead of defining the following. But let's still do it for the sake of it. + + +Now we need to define the behavior of the Classifier when it is invoked. + + +```julia +function (s::SpiralClassifier)(x::AbstractArray{T, 3}, ps::NamedTuple, + st::NamedTuple) where {T} + # First we will have to run the sequence through the LSTM Cell + # The first call to LSTM Cell will create the initial hidden state + # See that the parameters and states are automatically populated into a field called + # `lstm_cell` We use `eachslice` to get the elements in the sequence without copying, + # and `Iterators.peel` to split out the first element for LSTM initialization. + x_init, x_rest = Iterators.peel(eachslice(x; dims=2)) + (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell) + # Now that we have the hidden state and memory in `carry` we will pass the input and + # `carry` jointly + for x in x_rest + (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm) + end + # After running through the sequence we will pass the output through the classifier + y, st_classifier = s.classifier(y, ps.classifier, st.classifier) + # Finally remember to create the updated state + st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm)) + return vec(y), st +end +``` + + + + +## Defining Accuracy, Loss and Optimiser + + +Now let's define the binarycrossentropy loss. Typically it is recommended to use `logitbinarycrossentropy` since it is more numerically stable, but for the sake of simplicity we will use `binarycrossentropy`. + + +```julia +function xlogy(x, y) + result = x * log(y) + return ifelse(iszero(x), zero(result), result) +end + +function binarycrossentropy(y_pred, y_true) + y_pred = y_pred .+ eps(eltype(y_pred)) + return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred)) +end + +function compute_loss(x, y, model, ps, st) + y_pred, st = model(x, ps, st) + return binarycrossentropy(y_pred, y), y_pred, st +end + +matches(y_pred, y_true) = sum((y_pred .> 0.5) .== y_true) +accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred) +``` + + +``` +accuracy (generic function with 1 method) +``` + + +Finally lets create an optimiser given the model parameters. + + +```julia +function create_optimiser(ps) + opt = Optimisers.ADAM(0.01f0) + return Optimisers.setup(opt, ps) +end +``` + + +``` +create_optimiser (generic function with 1 method) +``` + + + + +## Training the Model + + +```julia +function main() + # Get the dataloaders + (train_loader, val_loader) = get_dataloaders() + + # Create the model + model = SpiralClassifier(2, 8, 1) + rng = Random.default_rng() + Random.seed!(rng, 0) + ps, st = Lux.setup(rng, model) + + dev = gpu_device() + ps = ps |> dev + st = st |> dev + + # Create the optimiser + opt_state = create_optimiser(ps) + + for epoch in 1:25 + # Train the model + for (x, y) in train_loader + x = x |> dev + y = y |> dev + (loss, y_pred, st), back = pullback(compute_loss, x, y, model, ps, st) + gs = back((one(loss), nothing, nothing))[4] + opt_state, ps = Optimisers.update(opt_state, ps, gs) + + println("Epoch [$epoch]: Loss $loss") + end + + # Validate the model + st_ = Lux.testmode(st) + for (x, y) in val_loader + x = x |> dev + y = y |> dev + (loss, y_pred, st_) = compute_loss(x, y, model, ps, st_) + acc = accuracy(y_pred, y) + println("Validation: Loss $loss Accuracy $acc") + end + end + + return (ps, st) |> cpu_device() +end + +ps_trained, st_trained = main() +``` + + +``` +((lstm_cell = (weight_i = Float32[-0.86167264 -0.42293867; -0.25197864 -0.75622934; 0.80377936 0.8820403; -0.0547525 0.07492376; -0.14647458 -0.6231062; 0.29097492 0.5499769; -0.8304294 0.065542385; 0.07603514 0.057226676; -0.039134953 0.08563263; -0.66015226 0.45116213; 1.1452403 -0.025277287; -0.009173643 0.0865255; -0.10448508 0.24553284; -0.8490004 0.3217581; -0.13019586 -0.2794364; 1.2182649 0.09943227; 0.67666495 -0.6293481; 0.13180551 -0.34722963; 0.46946698 -0.32844293; -0.5604955 0.51667506; 0.58226407 -0.8396601; 0.10203429 0.29503736; 0.86597884 -0.6396308; 0.9413931 0.6190642; -1.2567298 -0.09239372; 0.4872737 0.6442594; 1.0555211 0.677901; -0.45782351 -0.16882007; 0.73598784 -0.70290995; -0.3351124 0.7386465; -0.21653703 0.61397123; 0.6072665 -0.30303505], weight_h = Float32[-0.5175666 -0.0540258 0.2841425 -0.25916955 0.30149105 0.0050748046 -0.7026908 0.5526991; -0.6490742 0.23569672 -0.0657121 0.63530207 0.38919583 0.31218508 0.13132168 -0.045637704; 0.020180171 0.059647303 -0.029842459 0.5560518 -0.7445428 0.32197398 -0.63064736 -0.16421439; 0.036692034 -0.2747645 0.8872755 -0.09527659 0.9130826 0.08058891 0.27669194 0.9094774; -0.37817487 0.47252735 0.7732971 0.30380586 0.40297952 0.6386529 -0.34773254 -0.07745119; -0.05665953 -0.33936316 0.33388793 -0.099190064 -0.2845129 -0.22204192 -0.49584073 0.055964217; -0.7831404 0.371064 0.7016749 0.49620107 -0.8038273 0.6366681 0.03247163 -0.6194672; 0.6832412 0.3045688 1.032591 -1.3146787 0.8194584 0.15439296 -0.69981533 1.1493646; -0.15157364 0.4725639 0.4027123 -0.5002129 0.46449873 0.6740645 0.19627932 0.64133644; -0.41709745 -0.18856978 -0.38025862 -0.013487852 0.24634175 0.000894413 -0.7477189 0.76205194; -0.14617918 0.5503976 -0.06335492 0.5272881 -0.52698165 0.29227322 -0.26541632 -0.38074842; 0.055357717 0.8829194 0.28737235 -0.10625662 0.82950836 0.6134693 0.38962427 0.5013018; 0.64761835 0.4976158 0.3712376 -0.06839726 0.92167103 0.17874587 -0.6569381 0.4423014; -0.10055038 0.443792 0.08596908 0.08124385 0.611622 -0.012558972 -0.13516638 -0.39537105; -0.83281535 0.32090607 0.095000885 -0.40616816 -0.25754517 0.80708814 -0.36402744 0.42187604; -0.65100646 0.7926142 0.41881317 1.0657226 -0.078389876 0.81525993 -0.10534099 0.77736783; 0.15182799 -0.636 -0.05088705 -0.43641883 -0.35166016 -0.21045214 -0.16133972 0.17930262; -0.43767962 0.22737151 0.1986822 0.71358293 0.39698264 -0.35199252 -0.36271957 0.68317175; -0.5515306 0.7061721 0.094453976 -0.43427184 -0.3254669 0.7382981 0.05300062 0.5999323; -0.7740696 0.34048513 -0.16695552 0.5024247 -0.6962158 -0.15788634 -0.19976024 -0.1962582; -0.2690459 -0.66632074 0.40241653 -0.64508957 -0.1676396 0.26341486 0.24788627 -0.12475211; -0.7222454 -0.42830938 0.5290629 -0.4817089 -0.14141986 0.028442098 -0.43946353 0.2232964; 0.48485422 -0.18870239 -0.6670593 0.29468715 0.3413936 0.06973601 -0.52349377 -0.41981265; -0.41206646 0.35407877 0.2138094 0.36741307 -0.32585493 -0.19491813 -0.7442241 0.040227354; -0.64932895 0.55335957 0.25934076 -0.45103806 0.60108 0.4349758 -0.9321656 0.8701914; -0.60544336 0.27684203 0.06540652 0.42771336 -0.26987347 0.56500614 -0.4150196 -0.31804633; -0.82066774 -0.27268243 0.46054098 -0.2730553 0.22719508 0.015610166 -0.88704866 0.7498642; 0.28800654 0.26777756 1.0138662 -0.8935669 0.23755403 0.17644142 -1.0990795 0.47627932; -0.36108485 0.2799544 0.7699313 -0.43001822 0.45007658 0.90909237 -1.186951 0.7066328; 0.43395674 0.3999122 -0.6922677 0.59945565 -1.1694397 -0.37203658 -0.15417767 -0.054846168; -0.4359103 0.5161594 -0.24220508 0.6989991 -0.8204905 -0.14530644 0.083356254 -0.08953531; -0.5547889 0.75421494 0.69012684 -0.53708357 0.63538307 0.056566037 -0.4801861 0.6526276], bias = Float32[0.29331326; 0.29186064; 0.13662067; 0.32338896; 0.34485674; 0.11691242; 0.012926703; 0.97716063; 0.3833292; 0.15369596; 0.0057431255; 0.35792914; 0.4336497; 0.06991861; -0.014430795; 0.32400763; 0.8478263; 1.1282555; 1.154213; 0.7697143; 0.68986756; 1.2385734; 0.6564312; 0.91835827; 0.4764083; 0.049980924; 0.30275682; 0.613904; 0.6341952; 0.024868174; -0.119812824; 0.8763963;;]), classifier = (weight = Float32[-1.4362625 0.7662956 1.2438403 1.2575152 -0.936867 0.11635197 -0.26727143 1.2291836], bias = Float32[-0.5941662;;])), (lstm_cell = (rng = Random.Xoshiro(0x2026f555c226bf09, 0x8a6bb764b93cadda, 0x5ba3c10439600514, 0x446f763658f71987, 0x22a21880af5dc689),), classifier = NamedTuple())) +``` + + + + +## Saving the Model + + +We can save the model using JLD2 (and any other serialization library of your choice) Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model + + +```julia +@save "trained_model.jld2" {compress = true} ps_trained st_trained +``` + + +Let's try loading the model + + +```julia +@load "trained_model.jld2" ps_trained st_trained +``` + + +``` +2-element Vector{Symbol}: + :ps_trained + :st_trained +``` + + +--- + + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/previews/PR474/tutorials/beginner/trained_model.jld2 b/previews/PR474/tutorials/beginner/trained_model.jld2 new file mode 100644 index 000000000..070255baf Binary files /dev/null and b/previews/PR474/tutorials/beginner/trained_model.jld2 differ diff --git a/previews/PR474/tutorials/index.md b/previews/PR474/tutorials/index.md new file mode 100644 index 000000000..2282766a7 --- /dev/null +++ b/previews/PR474/tutorials/index.md @@ -0,0 +1,129 @@ +--- +layout: page +--- + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/previews/PR474/tutorials/intermediate/1_NeuralODE.md b/previews/PR474/tutorials/intermediate/1_NeuralODE.md new file mode 100644 index 000000000..e84a5d57c --- /dev/null +++ b/previews/PR474/tutorials/intermediate/1_NeuralODE.md @@ -0,0 +1,436 @@ + + + + + +# MNIST Classification using Neural ODEs + + +To understand Neural ODEs, users should look up [these lecture notes](https://book.sciml.ai/notes/11-Differentiable_Programming_and_Neural_Differential_Equations/). We recommend users to directly use [DiffEqFlux.jl](https://docs.sciml.ai/DiffEqFlux/stable/), instead of implementing Neural ODEs from scratch. + + + + +## Package Imports + + +```julia +using Lux, ComponentArrays, SciMLSensitivity, LuxAMDGPU, LuxCUDA, Optimisers, + OrdinaryDiffEq, Random, Statistics, Zygote, OneHotArrays, InteractiveUtils +import MLDatasets: MNIST +import MLUtils: DataLoader, splitobs +CUDA.allowscalar(false) +``` + + + + +## Loading MNIST + + +```julia +function loadmnist(batchsize, train_split) + # Load MNIST: Only 1500 for demonstration purposes + N = 1500 + dataset = MNIST(; split=:train) + imgs = dataset.features[:, :, 1:N] + labels_raw = dataset.targets[1:N] + + # Process images into (H,W,C,BS) batches + x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) + y_data = onehotbatch(labels_raw, 0:9) + (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split) + + return ( + # Use DataLoader to automatically minibatch and shuffle the data + DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true), + # Don't shuffle the test data + DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false)) +end +``` + + +``` +loadmnist (generic function with 1 method) +``` + + + + +## Define the Neural ODE Layer + + +The NeuralODE is a ContainerLayer, which stores a `model`. The parameters and states of the NeuralODE are same as those of the underlying model. + + +```julia +struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, K} <: + Lux.AbstractExplicitContainerLayer{(:model,)} + model::M + solver::So + sensealg::Se + tspan::T + kwargs::K +end + +function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), + sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()), kwargs...) + return NeuralODE(model, solver, sensealg, tspan, kwargs) +end +``` + + +``` +Main.var"##225".NeuralODE +``` + + +OrdinaryDiffEq.jl can deal with non-Vector Inputs! However, certain discrete sensitivities like `ReverseDiffAdjoint` can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector. + + +```julia +function (n::NeuralODE)(x, ps, st) + function dudt(u, p, t) + u_, st = n.model(reshape(u, size(x)), p, st) + return vec(u_) + end + prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps) + return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st +end + +@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :)) +@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :)) +``` + + +``` +diffeqsol_to_array (generic function with 2 methods) +``` + + + + +## Create and Initialize the Neural ODE Layer + + +```julia +function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Bool=false, + sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP())) + # Construct the Neural ODE Model + model = Chain(FlattenLayer(), + Dense(784 => 20, tanh), + model_fn(Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh)); + save_everystep=false, reltol=1.0f-3, abstol=1.0f-3, save_start=false, + sensealg), + Base.Fix1(diffeqsol_to_array, 20), + Dense(20 => 10)) + + rng = Random.default_rng() + Random.seed!(rng, 0) + + ps, st = Lux.setup(rng, model) + ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev + st = st |> dev + + return model, ps, st +end +``` + + +``` +create_model (generic function with 2 methods) +``` + + + + +## Define Utility Functions + + +```julia +logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1)) + +function loss(x, y, model, ps, st) + y_pred, st = model(x, ps, st) + return logitcrossentropy(y_pred, y), st +end + +function accuracy(model, ps, st, dataloader; dev=gpu_device()) + total_correct, total = 0, 0 + st = Lux.testmode(st) + cpu_dev = cpu_device() + for (x, y) in dataloader + target_class = onecold(y) + predicted_class = onecold(cpu_dev(first(model(dev(x), ps, st)))) + total_correct += sum(target_class .== predicted_class) + total += length(target_class) + end + return total_correct / total +end +``` + + +``` +accuracy (generic function with 1 method) +``` + + + + +## Training + + +```julia +function train(model_function; cpu::Bool=false, kwargs...) + dev = cpu ? cpu_device() : gpu_device() + model, ps, st = create_model(model_function; dev, kwargs...) + + # Training + train_dataloader, test_dataloader = loadmnist(128, 0.9) + + opt = Adam(0.001f0) + st_opt = Optimisers.setup(opt, ps) + + ### Warmup the Model + img = dev(train_dataloader.data[1][:, :, :, 1:1]) + lab = dev(train_dataloader.data[2][:, 1:1]) + loss(img, lab, model, ps, st) + (l, _), back = pullback(p -> loss(img, lab, model, p, st), ps) + back((one(l), nothing)) + + ### Lets train the model + nepochs = 9 + for epoch in 1:nepochs + stime = time() + for (x, y) in train_dataloader + x = dev(x) + y = dev(y) + (l, st), back = pullback(p -> loss(x, y, model, p, st), ps) + ### We need to add `nothing`s equal to the number of returned values - 1 + gs = back((one(l), nothing))[1] + st_opt, ps = Optimisers.update(st_opt, ps, gs) + end + ttime = time() - stime + + println("[$epoch/$nepochs] \t Time $(round(ttime; digits=2))s \t Training Accuracy: " * + "$(round(accuracy(model, ps, st, train_dataloader; dev) * 100; digits=2))% \t " * + "Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader; dev) * 100; digits=2))%") + end +end + +train(NeuralODE) +``` + + +``` +[1/9] Time 2.88s Training Accuracy: 51.26% Test Accuracy: 42.67% +[2/9] Time 0.26s Training Accuracy: 71.11% Test Accuracy: 66.67% +[3/9] Time 0.28s Training Accuracy: 78.22% Test Accuracy: 70.67% +[4/9] Time 0.3s Training Accuracy: 80.81% Test Accuracy: 75.33% +[5/9] Time 0.53s Training Accuracy: 82.37% Test Accuracy: 78.0% +[6/9] Time 0.28s Training Accuracy: 84.67% Test Accuracy: 79.33% +[7/9] Time 0.28s Training Accuracy: 85.33% Test Accuracy: 80.67% +[8/9] Time 0.3s Training Accuracy: 86.81% Test Accuracy: 81.33% +[9/9] Time 0.57s Training Accuracy: 87.7% Test Accuracy: 83.33% + +``` + + +We can also change the sensealg and train the model! `GaussAdjoint` allows you to use any arbitrary parameter structure and not just a flat vector (`ComponentArray`). + + +```julia +train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), use_named_tuple=true) +``` + + +``` +[1/9] Time 2.18s Training Accuracy: 50.89% Test Accuracy: 42.67% +[2/9] Time 0.25s Training Accuracy: 70.22% Test Accuracy: 65.33% +[3/9] Time 0.27s Training Accuracy: 77.85% Test Accuracy: 72.0% +[4/9] Time 0.42s Training Accuracy: 80.37% Test Accuracy: 74.67% +[5/9] Time 0.28s Training Accuracy: 82.59% Test Accuracy: 78.0% +[6/9] Time 0.42s Training Accuracy: 84.44% Test Accuracy: 79.33% +[7/9] Time 0.28s Training Accuracy: 85.7% Test Accuracy: 80.67% +[8/9] Time 0.28s Training Accuracy: 87.19% Test Accuracy: 82.0% +[9/9] Time 0.28s Training Accuracy: 88.15% Test Accuracy: 82.67% + +``` + + +But remember some AD backends like `ReverseDiff` is not GPU compatible. For a model this size, you will notice that training time is significantly lower for training on CPU than on GPU. + + +```julia +train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true) +``` + + +``` +[1/9] Time 0.68s Training Accuracy: 50.96% Test Accuracy: 43.33% +[2/9] Time 0.09s Training Accuracy: 69.63% Test Accuracy: 66.0% +[3/9] Time 0.08s Training Accuracy: 77.93% Test Accuracy: 71.33% +[4/9] Time 0.07s Training Accuracy: 80.74% Test Accuracy: 76.67% +[5/9] Time 0.07s Training Accuracy: 82.52% Test Accuracy: 78.0% +[6/9] Time 0.07s Training Accuracy: 84.07% Test Accuracy: 78.67% +[7/9] Time 0.08s Training Accuracy: 85.33% Test Accuracy: 80.67% +[8/9] Time 0.08s Training Accuracy: 86.59% Test Accuracy: 81.33% +[9/9] Time 0.08s Training Accuracy: 87.7% Test Accuracy: 82.0% + +``` + + +For completeness, let's also test out discrete sensitivities! + + +```julia +train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true) +``` + + +``` +[1/9] Time 7.43s Training Accuracy: 50.96% Test Accuracy: 43.33% +[2/9] Time 7.53s Training Accuracy: 69.63% Test Accuracy: 66.0% +[3/9] Time 6.3s Training Accuracy: 77.93% Test Accuracy: 71.33% +[4/9] Time 6.98s Training Accuracy: 80.74% Test Accuracy: 76.67% +[5/9] Time 8.2s Training Accuracy: 82.52% Test Accuracy: 78.0% +[6/9] Time 8.95s Training Accuracy: 84.07% Test Accuracy: 78.67% +[7/9] Time 10.1s Training Accuracy: 85.33% Test Accuracy: 80.67% +[8/9] Time 8.91s Training Accuracy: 86.59% Test Accuracy: 81.33% +[9/9] Time 9.58s Training Accuracy: 87.7% Test Accuracy: 82.0% + +``` + + + + +## Alternate Implementation using Stateful Layer + + +Starting `v0.5.5`, Lux provides a `Lux.Experimental.StatefulLuxLayer` which can be used to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276). + + +```julia +struct StatefulNeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, K} <: + Lux.AbstractExplicitContainerLayer{(:model,)} + model::M + solver::So + sensealg::Se + tspan::T + kwargs::K +end + +function StatefulNeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(), + tspan=(0.0f0, 1.0f0), sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()), + kwargs...) + return StatefulNeuralODE(model, solver, sensealg, tspan, kwargs) +end + +function (n::StatefulNeuralODE)(x, ps, st) + st_model = Lux.Experimental.StatefulLuxLayer(n.model, ps, st) + dudt(u, p, t) = st_model(u, p) + prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps) + return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st_model.st +end +``` + + + + +## Train the new Stateful Neural ODE + + +```julia +train(StatefulNeuralODE) +``` + + +``` +[1/9] Time 1.02s Training Accuracy: 48.89% Test Accuracy: 38.0% +[2/9] Time 0.26s Training Accuracy: 69.41% Test Accuracy: 64.0% +[3/9] Time 0.25s Training Accuracy: 77.04% Test Accuracy: 73.33% +[4/9] Time 0.27s Training Accuracy: 80.0% Test Accuracy: 75.33% +[5/9] Time 0.3s Training Accuracy: 82.37% Test Accuracy: 76.67% +[6/9] Time 0.28s Training Accuracy: 84.37% Test Accuracy: 78.0% +[7/9] Time 0.28s Training Accuracy: 85.33% Test Accuracy: 80.67% +[8/9] Time 0.28s Training Accuracy: 86.81% Test Accuracy: 81.33% +[9/9] Time 0.28s Training Accuracy: 87.7% Test Accuracy: 82.67% + +``` + + +We might not see a significant difference in the training time, but let us investigate the type stabilities of the layers. + + + + +## Type Stability + + +```julia +model, ps, st = create_model(NeuralODE) + +model_stateful, ps_stateful, st_stateful = create_model(StatefulNeuralODE) + +x = gpu_device()(ones(Float32, 28, 28, 1, 3)); +``` + + +NeuralODE is not type stable due to the boxing of `st` + + +```julia +@code_warntype model(x, ps, st) +``` + + +``` +MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Main.var"##225".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784), NamedTuple())), bias = ViewAxis(15681:15700, ShapedAxis((20, 1), NamedTuple())))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10), NamedTuple())), bias = ViewAxis(201:220, ShapedAxis((20, 1), NamedTuple())))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}) + from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/src/layers/containers.jl:478 +Arguments + c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Main.var"##225".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing} + x::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer} + ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784), NamedTuple())), bias = ViewAxis(15681:15700, ShapedAxis((20, 1), NamedTuple())))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10), NamedTuple())), bias = ViewAxis(201:220, ShapedAxis((20, 1), NamedTuple())))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))))}}} + st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple())) +Body::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.MEM.DEVICEBUFFER}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}} +1 ─ %1 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Main.var"##225".NeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}} +│ %2 = Lux.applychain(%1, x, ps, st)::TUPLE{CUDA.CUARRAY{FLOAT32, 2, CUDA.MEM.DEVICEBUFFER}, NAMEDTUPLE{(:LAYER_1, :LAYER_2, :LAYER_3, :LAYER_4, :LAYER_5), <:TUPLE{@NAMEDTUPLE{}, @NAMEDTUPLE{}, ANY, @NAMEDTUPLE{}, @NAMEDTUPLE{}}}} +└── return %2 + + +``` + + +We avoid the problem entirely by using `StatefulNeuralODE` + + +```julia +@code_warntype model_stateful(x, ps_stateful, st_stateful) +``` + + +``` +MethodInstance for (::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Main.var"##225".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing})(::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, ::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784), NamedTuple())), bias = ViewAxis(15681:15700, ShapedAxis((20, 1), NamedTuple())))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10), NamedTuple())), bias = ViewAxis(201:220, ShapedAxis((20, 1), NamedTuple())))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))))}}}, ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}) + from (c::Lux.Chain)(x, ps, st::NamedTuple) @ Lux /var/lib/buildkite-agent/builds/gpuci-15/julialang/lux-dot-jl/src/layers/containers.jl:478 +Arguments + c::Lux.Chain{@NamedTuple{layer_1::Lux.FlattenLayer, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Main.var"##225".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing} + x::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer} + ps::ComponentArrays.ComponentVector{Float32, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:15700, Axis(weight = ViewAxis(1:15680, ShapedAxis((20, 784), NamedTuple())), bias = ViewAxis(15681:15700, ShapedAxis((20, 1), NamedTuple())))), layer_3 = ViewAxis(15701:16240, Axis(layer_1 = ViewAxis(1:210, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(211:320, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(321:540, Axis(weight = ViewAxis(1:200, ShapedAxis((20, 10), NamedTuple())), bias = ViewAxis(201:220, ShapedAxis((20, 1), NamedTuple())))))), layer_4 = 16241:16240, layer_5 = ViewAxis(16241:16450, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20), NamedTuple())), bias = ViewAxis(201:210, ShapedAxis((10, 1), NamedTuple())))))}}} + st::Core.Const((layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), layer_4 = NamedTuple(), layer_5 = NamedTuple())) +Body::Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}} +1 ─ %1 = Base.getproperty(c, :layers)::@NamedTuple{layer_1::Lux.FlattenLayer, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Main.var"##225".StatefulNeuralODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_3::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLSensitivity.InterpolatingAdjoint{0, true, Val{:central}, SciMLSensitivity.ZygoteVJP}, Tuple{Float32, Float32}, Base.Pairs{Symbol, Real, NTuple{4, Symbol}, @NamedTuple{save_everystep::Bool, reltol::Float32, abstol::Float32, save_start::Bool}}}, layer_4::Lux.WrappedFunction{Base.Fix1{typeof(Main.var"##225".diffeqsol_to_array), Int64}}, layer_5::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}} +│ %2 = Lux.applychain(%1, x, ps, st)::Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}, layer_4::@NamedTuple{}, layer_5::@NamedTuple{}}} +└── return %2 + + +``` + + +Note, that we still recommend using this layer internally and not exposing this as the default API to the users. + + +--- + + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/previews/PR474/tutorials/intermediate/2_BayesianNN-22.png b/previews/PR474/tutorials/intermediate/2_BayesianNN-22.png new file mode 100644 index 000000000..0dcf73b75 Binary files /dev/null and b/previews/PR474/tutorials/intermediate/2_BayesianNN-22.png differ diff --git a/previews/PR474/tutorials/intermediate/2_BayesianNN-29.png b/previews/PR474/tutorials/intermediate/2_BayesianNN-29.png new file mode 100644 index 000000000..6ba498cb4 Binary files /dev/null and b/previews/PR474/tutorials/intermediate/2_BayesianNN-29.png differ diff --git a/previews/PR474/tutorials/intermediate/2_BayesianNN-8.png b/previews/PR474/tutorials/intermediate/2_BayesianNN-8.png new file mode 100644 index 000000000..a485a2cc5 Binary files /dev/null and b/previews/PR474/tutorials/intermediate/2_BayesianNN-8.png differ diff --git a/previews/PR474/tutorials/intermediate/2_BayesianNN.md b/previews/PR474/tutorials/intermediate/2_BayesianNN.md new file mode 100644 index 000000000..67f089985 --- /dev/null +++ b/previews/PR474/tutorials/intermediate/2_BayesianNN.md @@ -0,0 +1,359 @@ + + + + + +# Bayesian Neural Network + + +We borrow this tutorial from the [official Turing Docs](https://turing.ml/dev/tutorials/03-bayesian-neural-network/). We will show how the explicit parameterization of Lux enables first-class composability with packages which expect flattened out parameter vectors. + + +We will use [Turing.jl](https://turing.ml) with [Lux.jl](https://lux.csail.mit.edu/stable) to implement implementing a classification algorithm. Lets start by importing the relevant libraries. + + +```julia +# Import libraries +using Lux, Turing, CairoMakie, Random, ReverseDiff, Functors, MakiePublication + +# Hide sampling progress +Turing.setprogress!(false); + +# Use reverse_diff due to the number of parameters in neural networks +Turing.setadbackend(:reversediff) +``` + + +``` +:reversediff +``` + + + + +## Generating data + + +Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with. + + +```julia +# Number of points to generate +N = 80 +M = round(Int, N / 4) +rng = Random.default_rng() +Random.seed!(rng, 1234) + +# Generate artificial data +x1s = rand(rng, Float32, M) * 4.5f0; +x2s = rand(rng, Float32, M) * 4.5f0; +xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M]) +x1s = rand(rng, Float32, M) * 4.5f0; +x2s = rand(rng, Float32, M) * 4.5f0; +append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M])) + +x1s = rand(rng, Float32, M) * 4.5f0; +x2s = rand(rng, Float32, M) * 4.5f0; +xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M]) +x1s = rand(rng, Float32, M) * 4.5f0; +x2s = rand(rng, Float32, M) * 4.5f0; +append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M])) + +# Store all the data for later +xs = [xt1s; xt0s] +ts = [ones(2 * M); zeros(2 * M)] + +# Plot data points + +function plot_data() + x1 = first.(xt1s) + y1 = last.(xt1s) + x2 = first.(xt0s) + y2 = last.(xt0s) + + fig = with_theme(theme_web()) do + fig = Figure() + ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y") + + scatter!(ax, x1, y1; markersize=8, color=:red, strokecolor=:black, strokewidth=1) + scatter!(ax, x2, y2; markersize=8, color=:blue, strokecolor=:black, strokewidth=1) + + return fig + end + + return fig +end + +plot_data() +``` + + +![](2_BayesianNN-8.png) + + + + +## Building the Neural Network + + +The next step is to define a feedforward neural network where we express our parameters as distributions, and not single points as with traditional neural networks. For this we will use `Dense` to define liner layers and compose them via `Chain`, both are neural network primitives from `Lux`. The network `nn` we will create will have two hidden layers with `tanh` activations and one output layer with `sigmoid` activation, as shown below. + + +The `nn` is an instance that acts as a function and can take data, parameters and current state as inputs and output predictions. We will define distributions on the neural network parameters. + + +```julia +# Construct a neural network using Lux +nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid)) + +# Initialize the model weights and state +ps, st = Lux.setup(rng, nn) + +Lux.parameterlength(nn) # number of paraemters in NN +``` + + +``` +20 +``` + + +The probabilistic model specification below creates a parameters variable, which has IID normal variables. The parameters represents all parameters of our neural net (weights and biases). + + +```julia +# Create a regularization term and a Gaussian prior variance term. +alpha = 0.09 +sig = sqrt(1.0 / alpha) +``` + + +``` +3.3333333333333335 +``` + + +Construct named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid doing this. But let's do it this way to avoid dependencies. + + +```julia +function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple) + @assert length(ps_new) == Lux.parameterlength(ps) + i = 1 + function get_ps(x) + z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x)) + i += length(x) + return z + end + return fmap(get_ps, ps) +end + +# Specify the probabilistic model. +@model function bayes_nn(xs, ts) + global st + + # Sample the parameters + nparameters = Lux.parameterlength(nn) + parameters ~ MvNormal(zeros(nparameters), sig .* ones(nparameters)) + + # Forward NN to make predictions + preds, st = nn(xs, vector_to_parameters(parameters, ps), st) + + # Observe each prediction. + for i in 1:length(ts) + ts[i] ~ Bernoulli(preds[i]) + end +end +``` + + +``` +bayes_nn (generic function with 2 methods) +``` + + +Inference can now be performed by calling sample. We use the HMC sampler here. + + +```julia +# Perform inference. +N = 5000 +ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4), N) +``` + + +``` +Chains MCMC chain (5000×30×1 Array{Float64, 3}): + +Iterations = 1:1:5000 +Number of chains = 1 +Samples per chain = 5000 +Wall duration = 41.06 seconds +Compute duration = 41.06 seconds +parameters = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20] +internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size + +Summary Statistics + parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec + Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64 + + parameters[1] -2.0521 0.7625 0.1721 20.6080 90.3161 1.0078 0.5019 + parameters[2] -3.0083 3.1557 0.9621 12.7726 21.0999 1.5426 0.3110 + parameters[3] 1.3245 0.8417 0.2430 13.3461 24.3246 1.4587 0.3250 + parameters[4] -1.4877 1.0772 0.2493 15.1287 26.2870 1.1494 0.3684 + parameters[5] 0.7760 0.5463 0.1115 27.1685 102.1179 1.2247 0.6616 + parameters[6] 3.4906 1.7171 0.4559 14.8031 45.0051 1.0628 0.3605 + parameters[7] -3.9710 1.4956 0.3945 15.0640 45.7012 1.1078 0.3669 + parameters[8] -3.4420 1.3931 0.3730 14.6316 21.0172 1.4495 0.3563 + parameters[9] -4.5405 3.0887 0.9352 12.0159 20.0739 1.6221 0.2926 + parameters[10] -3.2319 2.7507 0.8110 12.1252 22.4200 1.3478 0.2953 + parameters[11] -3.5758 1.4822 0.4076 13.6637 33.8975 1.1729 0.3328 + parameters[12] -1.0570 2.4954 0.7492 11.7223 21.0856 1.9121 0.2855 + parameters[13] 2.9443 1.5863 0.4630 12.8882 20.7551 1.4709 0.3139 + parameters[14] 1.7753 3.9004 1.1683 11.7170 20.8646 1.4595 0.2853 + parameters[15] -2.5492 0.9557 0.2088 22.5673 38.5675 1.1274 0.5496 + parameters[16] 1.3716 2.0800 0.5691 14.0441 34.5833 1.1682 0.3420 + parameters[17] -0.4643 1.9644 0.5871 11.5634 21.2536 1.8976 0.2816 + parameters[18] 1.1676 1.4714 0.3611 16.5438 23.5525 1.0044 0.4029 + parameters[19] -5.9863 0.9522 0.1526 38.9319 54.0819 1.0467 0.9481 + parameters[20] -1.7391 1.2860 0.2821 21.6130 50.2786 1.0414 0.5263 + +Quantiles + parameters 2.5% 25.0% 50.0% 75.0% 97.5% + Symbol Float64 Float64 Float64 Float64 Float64 + + parameters[1] -3.3949 -2.6340 -2.0720 -1.4720 -0.6806 + parameters[2] -10.8256 -4.7942 -1.1807 -0.8554 -0.4700 + parameters[3] 0.2152 0.6476 1.1033 1.9170 3.1882 + parameters[4] -5.7669 -1.7303 -1.3366 -0.9107 -0.3575 + parameters[5] -0.4868 0.5344 0.8637 1.1188 1.7370 + parameters[6] 0.6259 2.2367 3.5342 4.5333 6.7872 + parameters[7] -6.6437 -5.1932 -3.9708 -2.7398 -1.3800 + parameters[8] -5.8277 -4.4686 -3.5359 -2.5644 -0.2680 + parameters[9] -11.4662 -7.0356 -3.2984 -2.0051 -0.8413 + parameters[10] -8.1245 -4.8659 -3.5562 -1.3859 2.0080 + parameters[11] -6.4989 -4.7170 -3.5606 -2.3902 -1.0114 + parameters[12] -4.6375 -3.1774 -1.3213 0.7895 4.0594 + parameters[13] 0.8759 1.7717 2.6272 3.5623 6.9655 + parameters[14] -5.8894 -2.1307 3.2297 4.9656 7.0390 + parameters[15] -4.6926 -3.0672 -2.5009 -1.9073 -0.9006 + parameters[16] -2.5810 -0.0622 0.9090 2.9472 5.4754 + parameters[17] -3.6208 -2.0939 -0.6436 1.0709 3.5339 + parameters[18] -2.1827 0.1799 1.2132 2.2497 3.7561 + parameters[19] -7.7264 -6.5933 -6.0417 -5.4946 -3.7762 + parameters[20] -4.6060 -2.5530 -1.6430 -0.8152 0.5254 + +``` + + +Now we extract the parameter samples from the sampled chain as θ (this is of size `5000 x 20` where `5000` is the number of iterations and `20` is the number of parameters). We'll use these primarily to determine how good our model's classifier is. + + +```julia +# Extract all weight and bias parameters. +θ = MCMCChains.group(ch, :parameters).value; +``` + + + + +## Prediction Visualization + + +```julia +# A helper to run the nn through data `x` using parameters `θ` +nn_forward(x, θ) = first(nn(x, vector_to_parameters(θ, ps), st)) + +# Plot the data we have. +fig = plot_data() + +# Find the index that provided the highest log posterior in the chain. +_, i = findmax(ch[:lp]) + +# Extract the max row value from i. +i = i.I[1] + +# Plot the posterior distribution with a contour plot +x1_range = collect(range(-6; stop=6, length=25)) +x2_range = collect(range(-6; stop=6, length=25)) +Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range] +contour!(x1_range, x2_range, Z) +fig +``` + + +![](2_BayesianNN-22.png) + + +The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions. + + +$$ +p(\tilde{x} | X, \alpha) = \int_{\theta} p(\tilde{x} | \theta) p(\theta | X, \alpha) \approx \sum_{\theta \sim p(\theta | X, \alpha)}f_{\theta}(\tilde{x}) +$$ + + +The `nn_predict` function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain. + + +```julia +# Return the average predicted value across multiple weights. +nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num]) +``` + + +``` +nn_predict (generic function with 1 method) +``` + + +Next, we use the `nn_predict` function to predict the value at a sample of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and more importantly, we can also see where the neural network is uncertain about its predictions much easier–-those regions between cluster boundaries. + + +Plot the average prediction. + + +```julia +fig = plot_data() + +n_end = 1500 +x1_range = collect(range(-6; stop=6, length=25)) +x2_range = collect(range(-6; stop=6, length=25)) +Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range] +contour!(x1_range, x2_range, Z) +fig +``` + + +![](2_BayesianNN-29.png) + + +Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 to 5,000. + + +```julia +fig = plot_data() +Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range] +c = contour!(x1_range, x2_range, Z) +record(fig, "results.gif", 1:250:size(θ, 1)) do i + fig.current_axis[].title = "Iteration: $i" + Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range] + c[3] = Z + return fig +end +``` + + +``` +"results.gif" +``` + + +![](results.gif) + + +--- + + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/previews/PR474/tutorials/intermediate/3_HyperNet.md b/previews/PR474/tutorials/intermediate/3_HyperNet.md new file mode 100644 index 000000000..2808a90f2 --- /dev/null +++ b/previews/PR474/tutorials/intermediate/3_HyperNet.md @@ -0,0 +1,250 @@ + + + + + +# Training a HyperNetwork on MNIST and FashionMNIST + + + + +## Package Imports + + +```julia +using Lux, ComponentArrays, LuxAMDGPU, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, + Optimisers, Random, Setfield, Statistics, Zygote +CUDA.allowscalar(false) +``` + + + + +## Loading Datasets + + +```julia +function _load_dataset(dset, n_train::Int, n_eval::Int, batchsize::Int) + imgs, labels = dset(:train)[1:n_train] + x_train, y_train = reshape(imgs, 28, 28, 1, n_train), onehotbatch(labels, 0:9) + + imgs, labels = dset(:test)[1:n_eval] + x_test, y_test = reshape(imgs, 28, 28, 1, n_eval), onehotbatch(labels, 0:9) + + return (DataLoader((x_train, y_train); batchsize=min(batchsize, n_train), shuffle=true), + DataLoader((x_test, y_test); batchsize=min(batchsize, n_eval), shuffle=false)) +end + +function load_datasets(n_train=1024, n_eval=32, batchsize=256) + return _load_dataset.((MNIST, FashionMNIST), n_train, n_eval, batchsize) +end +``` + + +``` +load_datasets (generic function with 4 methods) +``` + + + + +## Implement a HyperNet Layer + + +```julia +struct HyperNet{W <: Lux.AbstractExplicitLayer, C <: Lux.AbstractExplicitLayer, A} <: + Lux.AbstractExplicitContainerLayer{(:weight_generator, :core_network)} + weight_generator::W + core_network::C + ca_axes::A +end + +function HyperNet(w::Lux.AbstractExplicitLayer, c::Lux.AbstractExplicitLayer) + ca_axes = Lux.initialparameters(Random.default_rng(), c) |> ComponentArray |> getaxes + return HyperNet(w, c, ca_axes) +end + +function Lux.initialparameters(rng::AbstractRNG, h::HyperNet) + return (weight_generator=Lux.initialparameters(rng, h.weight_generator),) +end + +function (hn::HyperNet)(x, ps, st::NamedTuple) + ps_new, st_ = hn.weight_generator(x, ps.weight_generator, st.weight_generator) + @set! st.weight_generator = st_ + return ComponentArray(vec(ps_new), hn.ca_axes), st +end + +function (hn::HyperNet)((x, y)::T, ps, st::NamedTuple) where {T <: Tuple} + ps_ca, st = hn(x, ps, st) + pred, st_ = hn.core_network(y, ps_ca, st.core_network) + @set! st.core_network = st_ + return pred, st +end +``` + + + + +## Create and Initialize the HyperNet + + +```julia +function create_model() + # Doesn't need to be a MLP can have any Lux Layer + core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10)) + weight_generator = Chain(Embedding(2 => 32), + Dense(32, 64, relu), + Dense(64, Lux.parameterlength(core_network))) + + model = HyperNet(weight_generator, core_network) + + rng = Random.default_rng() + Random.seed!(rng, 0) + + ps, st = Lux.setup(rng, model) .|> gpu_device() + + return model, ps, st +end +``` + + +``` +create_model (generic function with 1 method) +``` + + + + +## Define Utility Functions + + +```julia +logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1)) + +function loss(data_idx, x, y, model, ps, st) + y_pred, st = model((data_idx, x), ps, st) + return logitcrossentropy(y_pred, y), st +end + +function accuracy(model, ps, st, dataloader, data_idx) + total_correct, total = 0, 0 + st = Lux.testmode(st) + dev = gpu_device() + cpu_dev = cpu_device() + for (x, y) in dataloader + x = x |> dev + y = y |> dev + target_class = onecold(cpu_dev(y)) + predicted_class = onecold(cpu_dev(model((data_idx, x), ps, st)[1])) + total_correct += sum(target_class .== predicted_class) + total += length(target_class) + end + return total_correct / total +end +``` + + +``` +accuracy (generic function with 1 method) +``` + + + + +## Training + + +```julia +function train() + model, ps, st = create_model() + + # Training + dataloaders = load_datasets() + + opt = Adam(0.001f0) + st_opt = Optimisers.setup(opt, ps) + + dev = gpu_device() + + ### Warmup the Model + img, lab = dev(dataloaders[1][1].data[1][:, :, :, 1:1]), + dev(dataloaders[1][1].data[2][:, 1:1]) + loss(1, img, lab, model, ps, st) + (l, _), back = pullback(p -> loss(1, img, lab, model, p, st), ps) + back((one(l), nothing)) + + ### Lets train the model + nepochs = 9 + for epoch in 1:nepochs + for data_idx in 1:2 + train_dataloader, test_dataloader = dataloaders[data_idx] + + stime = time() + for (x, y) in train_dataloader + x = x |> dev + y = y |> dev + (l, st), back = pullback(p -> loss(data_idx, x, y, model, p, st), ps) + gs = back((one(l), nothing))[1] + st_opt, ps = Optimisers.update(st_opt, ps, gs) + end + ttime = time() - stime + + train_acc = round(accuracy(model, ps, st, train_dataloader, data_idx) * 100; + digits=2) + test_acc = round(accuracy(model, ps, st, test_dataloader, data_idx) * 100; + digits=2) + + data_name = data_idx == 1 ? "MNIST" : "FashionMNIST" + + println("[$epoch/$nepochs] \t $data_name Time $(round(ttime; digits=2))s \t " * + "Training Accuracy: $(train_acc)% \t Test Accuracy: $(test_acc)%") + end + end + + for data_idx in 1:2 + train_dataloader, test_dataloader = dataloaders[data_idx] + train_acc = round(accuracy(model, ps, st, train_dataloader, data_idx) * 100; + digits=2) + test_acc = round(accuracy(model, ps, st, test_dataloader, data_idx) * 100; digits=2) + + data_name = data_idx == 1 ? "MNIST" : "FashionMNIST" + + println("[FINAL] \t $data_name Training Accuracy: $(train_acc)% \t " * + "Test Accuracy: $(test_acc)%") + end +end + +train() +``` + + +``` +[1/9] MNIST Time 2.73s Training Accuracy: 54.49% Test Accuracy: 56.25% +[1/9] FashionMNIST Time 0.02s Training Accuracy: 57.13% Test Accuracy: 53.12% +[2/9] MNIST Time 0.02s Training Accuracy: 77.73% Test Accuracy: 62.5% +[2/9] FashionMNIST Time 0.02s Training Accuracy: 63.18% Test Accuracy: 68.75% +[3/9] MNIST Time 0.02s Training Accuracy: 83.11% Test Accuracy: 87.5% +[3/9] FashionMNIST Time 0.02s Training Accuracy: 60.55% Test Accuracy: 59.38% +[4/9] MNIST Time 0.02s Training Accuracy: 90.43% Test Accuracy: 84.38% +[4/9] FashionMNIST Time 0.02s Training Accuracy: 67.19% Test Accuracy: 65.62% +[5/9] MNIST Time 0.02s Training Accuracy: 90.53% Test Accuracy: 87.5% +[5/9] FashionMNIST Time 0.02s Training Accuracy: 71.88% Test Accuracy: 62.5% +[6/9] MNIST Time 0.04s Training Accuracy: 93.36% Test Accuracy: 87.5% +[6/9] FashionMNIST Time 0.01s Training Accuracy: 75.49% Test Accuracy: 68.75% +[7/9] MNIST Time 0.02s Training Accuracy: 94.34% Test Accuracy: 87.5% +[7/9] FashionMNIST Time 0.01s Training Accuracy: 76.56% Test Accuracy: 71.88% +[8/9] MNIST Time 0.01s Training Accuracy: 95.21% Test Accuracy: 90.62% +[8/9] FashionMNIST Time 0.01s Training Accuracy: 79.3% Test Accuracy: 71.88% +[9/9] MNIST Time 0.02s Training Accuracy: 96.97% Test Accuracy: 90.62% +[9/9] FashionMNIST Time 0.02s Training Accuracy: 73.54% Test Accuracy: 68.75% +[FINAL] MNIST Training Accuracy: 96.19% Test Accuracy: 87.5% +[FINAL] FashionMNIST Training Accuracy: 73.54% Test Accuracy: 68.75% + +``` + + +--- + + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/previews/PR474/tutorials/intermediate/results.gif b/previews/PR474/tutorials/intermediate/results.gif new file mode 100644 index 000000000..7dcbc9bf9 Binary files /dev/null and b/previews/PR474/tutorials/intermediate/results.gif differ