Skip to content

Commit

Permalink
Add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 6, 2024
1 parent 6293e2a commit b26ed78
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ LuxCore = "0.1.12"
LuxDeviceUtils = "0.1.19"
LuxLib = "0.3.11"
LuxTestUtils = "0.1.15"
MLUtils = "0.4"
MLUtils = "0.4.3"
MacroTools = "0.5.13"
Markdown = "1.10"
NCCL = "0.1.1"
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6 changes: 4 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ pages = [
"manual/freezing_model_parameters.md",
"manual/gpu_management.md",
"manual/migrate_from_flux.md",
"manual/weight_initializers.md"
"manual/weight_initializers.md",
"manual/distributed_utils.md"
],
"API Reference" => [
"Lux" => [
"api/Lux/layers.md",
"api/Lux/utilities.md",
"api/Lux/contrib.md",
"api/Lux/switching_frameworks.md"
"api/Lux/switching_frameworks.md",
"api/Lux/distributed_utils.md",
],
"Accelerator Support" => [
"api/Accelerator_Support/LuxAMDGPU.md",
Expand Down
9 changes: 6 additions & 3 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ export default defineConfig({
{ text: 'Built-In Layers', link: '/api/Lux/layers' },
{ text: 'Utilities', link: '/api/Lux/utilities' },
{ text: 'Experimental', link: '/api/Lux/contrib' },
{ text: 'InterOp', link: '/api/Lux/switching_frameworks' }
{ text: 'InterOp', link: '/api/Lux/switching_frameworks' },
{ text: 'DistributedUtils', link: '/api/Lux/distributed_utils' }
]
},
{
Expand Down Expand Up @@ -146,7 +147,8 @@ export default defineConfig({
{ text: 'Freezing Model Parameters', link: '/manual/freezing_model_parameters' },
{ text: 'GPU Management', link: '/manual/gpu_management' },
{ text: 'Migrating from Flux to Lux', link: '/manual/migrate_from_flux' },
{ text: 'Initializing Weights', link: '/manual/weight_initializers' }]
{ text: 'Initializing Weights', link: '/manual/weight_initializers' },
{ text: 'Distributed Data Parallel Training', link: '/manual/distributed_utils' },]
},
"/api/": {
text: 'API Reference', collapsed: false, items: [
Expand All @@ -155,7 +157,8 @@ export default defineConfig({
{ text: 'Built-In Layers', link: '/api/Lux/layers' },
{ text: 'Utilities', link: '/api/Lux/utilities' },
{ text: 'Experimental Features', link: '/api/Lux/contrib' },
{ text: 'Switching between Deep Learning Frameworks', link: '/api/Lux/switching_frameworks' }]
{ text: 'Switching between Deep Learning Frameworks', link: '/api/Lux/switching_frameworks' },
{ text: 'DistributedUtils', link: '/api/Lux/distributed_utils' }]
},
{
text: 'Accelerator Support', collapsed: false, items: [
Expand Down
58 changes: 58 additions & 0 deletions docs/src/api/Lux/distributed_utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Distributed Utils

!!! note

These functionalities are available via the `Lux.DistributedUtils` module.

```@meta
CurrentModule = Lux
```

## Index

```@index
Pages = ["distributed_utils.md"]
```

## [Backends](@id communication-backends)

```@docs
MPIBackend
NCCLBackend
```

## Initialization

```@docs
DistributedUtils.initialize
DistributedUtils.initialized
DistributedUtils.get_distributed_backend
```

## Helper Functions

```@docs
DistributedUtils.local_rank
DistributedUtils.total_workers
```

## Communication Primitives

```@docs
DistributedUtils.allreduce!
DistributedUtils.bcast!
DistributedUtils.reduce!
DistributedUtils.synchronize!!
```

## Optimizers.jl Integration

```@docs
DistributedUtils.DistributedOptimizer
```

## MLUtils.jl Integration

```@docs
DistributedUtils.DistributedDataLoader
```
3 changes: 2 additions & 1 deletion docs/src/introduction/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import Pkg; Pkg.add("Lux")
```@example quickstart
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, LuxAMDGPU, Metal # Optional packages for GPU support
using Printf # For pretty printing
```

We take randomness very seriously
Expand Down Expand Up @@ -117,7 +118,7 @@ for epoch in 1:1000
return sum(abs2, y .- y_data), st_
end
gs = only(pb((one(loss), nothing)))
epoch % 100 == 1 && println("Epoch: $(epoch) | Loss: $(loss)")
epoch % 100 == 1 && @printf "Epoch: %04d \t Loss: %10.9g\n" epoch loss
Optimisers.update!(st_opt, ps, gs)
end
```
Expand Down
55 changes: 55 additions & 0 deletions docs/src/manual/distributed_utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Distributed Data Parallel Training

!!! tip

For a fully functional example, see the
[ImageNet Training Example](https://github.com/LuxDL/Lux.jl/tree/main/examples/ImageNet)

DDP Training using `Lux.DistributedUtils` is a spiritual successor to
[FluxMPI.jl](https://github.com/avik-pal/FluxMPI.jl), but has some key differences.

## Guide to Integrating DistributedUtils into your code

## [GPU-Aware MPI](@id gpu-aware-mpi)

If you are using a custom MPI build that supports CUDA or ROCM, you can use the following
preferences with [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl):

1. `LuxDistributedMPICUDAAware` - Set this to `true` if your MPI build is CUDA aware.
2. `LuxDistributedMPIROCMAware` - Set this to `true` if your MPI build is ROCM aware.

By default, both of these values are set to `false`.

## Migration Guide from `FluxMPI.jl`

Let's compare the changes we need to make wrt the
[FluxMPI.jl integration guide](https://avik-pal.github.io/FluxMPI.jl/dev/guide/).

1. `FluxMPI.Init` is now [`DistributedUtils.initialize`](@ref).
2. `FluxMPI.synchronize!(x)` needs to be changed to
`x_new = DistributedUtils.synchronize!!(backend, x)`.
3. [`DistributedUtils.DistributedDataContainer`](@ref),
[`DistributedUtils.local_rank`](@ref), and
[`DistributedUtils.DistributedOptimizer`](@ref) need `backend` as the first input.

And that's pretty much it!

### Removed Functionality

1. `FluxMPI.allreduce_gradients` no longer exists. Previously this was needed when CUDA
communication was flaky, with `NCCL.jl` this is no longer the case.
2. `FluxMPIFluxModel` has been removed. `DistributedUtils` no longer works with `Flux`.

### Key Differences

1. `FluxMPI.synchronize!` is now `DistributedUtils.synchronize!!` to highlight the fact
that some of the inputs are not updated in-place.
2. All of the functions now require a [communication backend](@ref communication-backends)
as input.
3. We don't automatically determine if the MPI Implementation is CUDA or ROCM aware. See
[GPU-aware MPI](@ref gpu-aware-mpi) for more information.
4. Older [`Lux.gpu`](@ref) implementations used to "just work" with `FluxMPI.jl`. We expect
[`gpu_device`](@ref) to continue working as expected, however, we recommend using
[`gpu_device`](@ref) after calling [`DistributedUtils.initialize`](@ref) to avoid any
mismatch between the device set via `DistributedUtils` and the device stores in
[`LuxCUDADevice`](@ref) or [`LuxAMDGPUDevice`](@ref).
2 changes: 1 addition & 1 deletion ext/LuxComponentArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LuxComponentArraysExt

using ComponentArrays: ComponentArrays, ComponentArray, FlatAxis
using Lux: Lux
using Lux: Lux, DistributedUtils

# Empty NamedTuple: Hack to avoid breaking precompilation
function ComponentArrays.ComponentArray(data::Vector{Any}, axes::Tuple{FlatAxis})
Expand Down
4 changes: 2 additions & 2 deletions ext/LuxFluxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ function __from_flux_adaptor(l::Flux.ConvTranspose; preserve_ps_st::Bool=false,
if preserve_ps_st
_bias = l.bias isa Bool ? nothing :
reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad,
l.dilation, groups, use_bias=!(l.bias isa Bool),
return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride,
pad, l.dilation, groups, use_bias=!(l.bias isa Bool),
init_weight=__copy_anonymous_closure(Lux._maybe_flip_conv_weight(l.weight)),
init_bias=__copy_anonymous_closure(_bias))
else
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxMLUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ function MLUtils.getobs(dc::DistributedUtils.DistributedDataContainer, idx)
return MLUtils.getobs(dc.data, dc.idxs[idx])
end

end
end
6 changes: 3 additions & 3 deletions ext/LuxMPINCCLExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ using Setfield: @set!

function DistributedUtils.__initialize(
::Type{NCCLBackend}; cuda_devices=nothing, amdgpu_devices=missing)
DistributedUtils.NCCL_Initialized[] = true
@assert amdgpu_devices===missing "`AMDGPU` is not supported by `NCCL`."
DistributedUtils.__initialize(Val(:MPI); cuda_devices, amdgpu_devices)
DistributedUtils.__initialize(MPIBackend; cuda_devices, amdgpu_devices)
DistributedUtils.NCCL_Initialized[] = true
return
end

function DistributedUtils.__get_distributed_backend(::Type{NCCLBackend})
unique_id = NCCL.UniqueID() # Generate on all ranks to know the type
mpi_backend = DistributedUtils.__get_distributed_backend(Val(:MPI))
mpi_backend = DistributedUtils.__get_distributed_backend(MPIBackend)
buf = [unique_id.internal...]
DistributedUtils.bcast!(mpi_backend, buf; root=0)
@set! unique_id.internal = Tuple(buf)
Expand Down
4 changes: 2 additions & 2 deletions src/distributed/backend.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ abstract type AbstractLuxDistributedBackend end
MPIBackend(comm = nothing)
Create an MPI backend for distributed training. Users should not use this function directly.
Instead use [`DistributedUtils.get_distributed_backend(Val(:NCCL))`](@ref).
Instead use [`DistributedUtils.get_distributed_backend(MPIBackend)`](@ref).
"""
struct MPIBackend{C} <: AbstractLuxDistributedBackend
comm::C
Expand All @@ -21,7 +21,7 @@ end
NCCLBackend(comm = nothing, mpi_backend = nothing)
Create an NCCL backend for distributed training. Users should not use this function
directly. Instead use [`DistributedUtils.get_distributed_backend(Val(:NCCL))`](@ref).
directly. Instead use [`DistributedUtils.get_distributed_backend(NCCLBackend)`](@ref).
"""
struct NCCLBackend{C, M <: Union{Nothing, MPIBackend}} <: AbstractLuxDistributedBackend
comm::C
Expand Down
8 changes: 8 additions & 0 deletions src/distributed/public_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ end
`data` must be compatible with `MLUtils` interface. The returned container is compatible
with `MLUtils` interface and is used to partition the dataset across the available
processes.
!!! danger
`MLUtils.jl` must be installed and loaded before using this.
"""
@concrete struct DistributedDataContainer
data
Expand Down Expand Up @@ -250,6 +254,10 @@ averages the gradients across the processes using Allreduce.
## Arguments
- `optimizer`: An Optimizer compatible with the Optimisers.jl package
!!! danger
`Optimisers.jl` must be installed and loaded before using this.
"""
function DistributedOptimizer(backend::AbstractLuxDistributedBackend, opt)
mod = Base.get_extension(@__MODULE__, :LuxOptimisersExt)
Expand Down

0 comments on commit b26ed78

Please sign in to comment.