
Commit

Add documentation
avik-pal committed Apr 6, 2024
1 parent 6293e2a commit 43ad012
Showing 7 changed files with 136 additions and 10 deletions.
6 changes: 4 additions & 2 deletions docs/make.jl
@@ -38,14 +38,16 @@ pages = [
"manual/freezing_model_parameters.md",
"manual/gpu_management.md",
"manual/migrate_from_flux.md",
"manual/weight_initializers.md"
"manual/weight_initializers.md",
"manual/distributed_utils.md"
],
"API Reference" => [
"Lux" => [
"api/Lux/layers.md",
"api/Lux/utilities.md",
"api/Lux/contrib.md",
"api/Lux/switching_frameworks.md"
"api/Lux/switching_frameworks.md",
"api/Lux/distributed_utils.md",
],
"Accelerator Support" => [
"api/Accelerator_Support/LuxAMDGPU.md",
9 changes: 6 additions & 3 deletions docs/src/.vitepress/config.mts
@@ -73,7 +73,8 @@ export default defineConfig({
{ text: 'Built-In Layers', link: '/api/Lux/layers' },
{ text: 'Utilities', link: '/api/Lux/utilities' },
{ text: 'Experimental', link: '/api/Lux/contrib' },
{ text: 'InterOp', link: '/api/Lux/switching_frameworks' }
{ text: 'InterOp', link: '/api/Lux/switching_frameworks' },
{ text: 'DistributedUtils', link: '/api/Lux/distributed_utils' }
]
},
{
@@ -146,7 +147,8 @@ export default defineConfig({
{ text: 'Freezing Model Parameters', link: '/manual/freezing_model_parameters' },
{ text: 'GPU Management', link: '/manual/gpu_management' },
{ text: 'Migrating from Flux to Lux', link: '/manual/migrate_from_flux' },
{ text: 'Initializing Weights', link: '/manual/weight_initializers' }]
{ text: 'Initializing Weights', link: '/manual/weight_initializers' },
{ text: 'Distributed Data Parallel Training', link: '/manual/distributed_utils' },]
},
"/api/": {
text: 'API Reference', collapsed: false, items: [
@@ -155,7 +157,8 @@
{ text: 'Built-In Layers', link: '/api/Lux/layers' },
{ text: 'Utilities', link: '/api/Lux/utilities' },
{ text: 'Experimental Features', link: '/api/Lux/contrib' },
{ text: 'Switching between Deep Learning Frameworks', link: '/api/Lux/switching_frameworks' }]
{ text: 'Switching between Deep Learning Frameworks', link: '/api/Lux/switching_frameworks' },
{ text: 'DistributedUtils', link: '/api/Lux/distributed_utils' }]
},
{
text: 'Accelerator Support', collapsed: false, items: [
58 changes: 58 additions & 0 deletions docs/src/api/Lux/distributed_utils.md
@@ -0,0 +1,58 @@
# Distributed Utils

!!! note

These functionalities are available via the `Lux.DistributedUtils` module.

```@meta
CurrentModule = Lux
```

## Index

```@index
Pages = ["distributed_utils.md"]
```

## [Backends](@id communication-backends)

```@docs
MPIBackend
NCCLBackend
```

## Initialization

```@docs
DistributedUtils.initialize
DistributedUtils.initialized
DistributedUtils.get_distributed_backend
```

## Helper Functions

```@docs
DistributedUtils.local_rank
DistributedUtils.total_workers
```

## Communication Primitives

```@docs
DistributedUtils.allreduce!
DistributedUtils.bcast!
DistributedUtils.reduce!
DistributedUtils.synchronize!!
```

## Optimizers.jl Integration

```@docs
DistributedUtils.DistributedOptimizer
```

## MLUtils.jl Integration

```@docs
DistributedUtils.DistributedDataContainer
```
55 changes: 55 additions & 0 deletions docs/src/manual/distributed_utils.md
@@ -0,0 +1,55 @@
# Distributed Data Parallel Training

!!! tip

For a fully functional example, see the
[ImageNet Training Example](https://github.com/LuxDL/Lux.jl/tree/main/examples/ImageNet)

DDP training using `Lux.DistributedUtils` is a spiritual successor to
[FluxMPI.jl](https://github.com/avik-pal/FluxMPI.jl), but with some key differences.

## Guide to Integrating DistributedUtils into your code
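
A minimal sketch of the typical flow is shown below, using the NCCL backend. The model,
optimizer, and dataset are illustrative placeholders, not part of the API; run the script
like any other MPI program, e.g. with `mpiexecjl` from MPI.jl.

```julia
using Lux, MLUtils, MPI, NCCL, Optimisers, Random
using Lux: DistributedUtils, NCCLBackend

DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)

# Toy model and data purely for illustration.
model = Chain(Dense(4 => 32, tanh), Dense(32 => 2))
ps, st = Lux.setup(Random.default_rng(), model)

# Every rank must start from identical parameters and states.
ps = DistributedUtils.synchronize!!(backend, ps)
st = DistributedUtils.synchronize!!(backend, st)

# Wrap the optimizer so gradients are averaged across ranks (Allreduce) at every step.
opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(0.001f0))
opt_state = Optimisers.setup(opt, ps)
opt_state = DistributedUtils.synchronize!!(backend, opt_state)

# Each rank iterates only over its own shard of the dataset (requires MLUtils.jl).
full_data = rand(Float32, 4, 128)
ddata = DistributedUtils.DistributedDataContainer(backend, full_data)

@info "Running on rank $(DistributedUtils.local_rank(backend)) of $(DistributedUtils.total_workers(backend))"
```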

## [GPU-Aware MPI](@id gpu-aware-mpi)

If you are using a custom MPI build that supports CUDA or ROCm, you can set the following
preferences via [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl):

1. `LuxDistributedMPICUDAAware` - Set this to `true` if your MPI build is CUDA-aware.
2. `LuxDistributedMPIROCMAware` - Set this to `true` if your MPI build is ROCm-aware.

By default, both of these values are set to `false`.
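
For example, assuming `Lux` and `Preferences` are available in the active project
environment, the preference can be set from the REPL as sketched below (editing
`LocalPreferences.toml` by hand works just as well):

```julia
using Preferences, Lux

# Mark the MPI build as CUDA-aware; takes effect the next time Julia is started.
set_preferences!(Lux, "LuxDistributedMPICUDAAware" => true; force=true)
```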

## Migration Guide from `FluxMPI.jl`

Let's compare the changes we need to make relative to the
[FluxMPI.jl integration guide](https://avik-pal.github.io/FluxMPI.jl/dev/guide/); a minimal
sketch of the new calls follows the list.

1. `FluxMPI.Init` is now [`DistributedUtils.initialize`](@ref).
2. `FluxMPI.synchronize!(x)` needs to be changed to
`x_new = DistributedUtils.synchronize!!(backend, x)`.
3. [`DistributedUtils.DistributedDataContainer`](@ref),
[`DistributedUtils.local_rank`](@ref), and
[`DistributedUtils.DistributedOptimizer`](@ref) need `backend` as the first input.

And that's pretty much it!
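
The sketch below spells out items 1-3 with placeholder values (no real model or dataset is
constructed); the key point is that `backend` is now threaded through every call:

```julia
using Lux, MPI, NCCL, Optimisers
using Lux: DistributedUtils, NCCLBackend

DistributedUtils.initialize(NCCLBackend)            # replaces FluxMPI.Init()
backend = DistributedUtils.get_distributed_backend(NCCLBackend)

ps = (; weight=rand(Float32, 2, 2))                 # stand-in for model parameters
ps = DistributedUtils.synchronize!!(backend, ps)    # replaces FluxMPI.synchronize!(ps)

# `backend` is now the first argument of the optimizer wrapper (item 3 above).
opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Descent(0.01f0))
```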

### Removed Functionality

1. `FluxMPI.allreduce_gradients` no longer exists. Previously, this was needed when CUDA
communication was flaky; with `NCCL.jl` this is no longer the case.
2. `FluxMPIFluxModel` has been removed. `DistributedUtils` no longer works with `Flux`.

### Key Differences

1. `FluxMPI.synchronize!` is now `DistributedUtils.synchronize!!` to highlight the fact
that some of the inputs are not updated in-place.
2. All of the functions now require a [communication backend](@ref communication-backends)
as input.
3. We don't automatically determine whether the MPI implementation is CUDA- or ROCm-aware.
See [GPU-aware MPI](@ref gpu-aware-mpi) for more information.
4. Older [`Lux.gpu`](@ref) implementations used to "just work" with `FluxMPI.jl`. We expect
[`gpu_device`](@ref) to continue working as expected; however, we recommend calling
[`gpu_device`](@ref) after [`DistributedUtils.initialize`](@ref) to avoid any mismatch
between the device set via `DistributedUtils` and the device stored in
[`LuxCUDADevice`](@ref) or [`LuxAMDGPUDevice`](@ref), as sketched below.
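
A minimal sketch of the ordering recommended in point 4, assuming a CUDA setup with
`LuxCUDA` installed:

```julia
using Lux, LuxCUDA, MPI, NCCL
using Lux: DistributedUtils, NCCLBackend

# Initialize the distributed backend first so the correct local GPU is selected ...
DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)

# ... and only then query the device, so it agrees with the device chosen for this rank.
dev = gpu_device()
x = dev(rand(Float32, 4, 16))   # data now lives on this rank's GPU
```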
6 changes: 3 additions & 3 deletions ext/LuxMPINCCLExt.jl
@@ -8,15 +8,15 @@ using Setfield: @set!

function DistributedUtils.__initialize(
::Type{NCCLBackend}; cuda_devices=nothing, amdgpu_devices=missing)
DistributedUtils.NCCL_Initialized[] = true
@assert amdgpu_devices===missing "`AMDGPU` is not supported by `NCCL`."
DistributedUtils.__initialize(Val(:MPI); cuda_devices, amdgpu_devices)
DistributedUtils.__initialize(MPIBackend; cuda_devices, amdgpu_devices)
DistributedUtils.NCCL_Initialized[] = true
return
end

function DistributedUtils.__get_distributed_backend(::Type{NCCLBackend})
unique_id = NCCL.UniqueID() # Generate on all ranks to know the type
mpi_backend = DistributedUtils.__get_distributed_backend(Val(:MPI))
mpi_backend = DistributedUtils.__get_distributed_backend(MPIBackend)
buf = [unique_id.internal...]
DistributedUtils.bcast!(mpi_backend, buf; root=0)
@set! unique_id.internal = Tuple(buf)
4 changes: 2 additions & 2 deletions src/distributed/backend.jl
@@ -4,7 +4,7 @@ abstract type AbstractLuxDistributedBackend end
MPIBackend(comm = nothing)
Create an MPI backend for distributed training. Users should not use this function directly.
Instead use [`DistributedUtils.get_distributed_backend(Val(:NCCL))`](@ref).
Instead use [`DistributedUtils.get_distributed_backend(MPIBackend)`](@ref).
"""
struct MPIBackend{C} <: AbstractLuxDistributedBackend
comm::C
@@ -21,7 +21,7 @@ end
NCCLBackend(comm = nothing, mpi_backend = nothing)
Create an NCCL backend for distributed training. Users should not use this function
directly. Instead use [`DistributedUtils.get_distributed_backend(Val(:NCCL))`](@ref).
directly. Instead use [`DistributedUtils.get_distributed_backend(NCCLBackend)`](@ref).
"""
struct NCCLBackend{C, M <: Union{Nothing, MPIBackend}} <: AbstractLuxDistributedBackend
comm::C
8 changes: 8 additions & 0 deletions src/distributed/public_api.jl
@@ -220,6 +220,10 @@ end
`data` must be compatible with `MLUtils` interface. The returned container is compatible
with `MLUtils` interface and is used to partition the dataset across the available
processes.
!!! danger
`MLUtils.jl` must be installed and loaded before using this.
"""
@concrete struct DistributedDataContainer
data
@@ -250,6 +254,10 @@ averages the gradients across the processes using Allreduce.
## Arguments
- `optimizer`: An Optimizer compatible with the Optimisers.jl package
!!! danger
`Optimisers.jl` must be installed and loaded before using this.
"""
function DistributedOptimizer(backend::AbstractLuxDistributedBackend, opt)
mod = Base.get_extension(@__MODULE__, :LuxOptimisersExt)
