Add tests for the distributed backends
avik-pal committed Apr 7, 2024
1 parent 2743579 commit 6cbdc76
Showing 17 changed files with 272 additions and 102 deletions.
1 change: 1 addition & 0 deletions .buildkite/pipeline.yml
@@ -17,6 +17,7 @@ steps:
cuda: "*"
env:
GROUP: "CUDA"
JULIA_MPI_TEST_NPROCS: 2 # Needs to be same as number of GPUs for NCCL
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 240
matrix:
5 changes: 4 additions & 1 deletion Project.toml
@@ -120,6 +120,9 @@ LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -135,4 +138,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Adapt", "Aqua", "ChainRules", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "Flux", "Functors", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "Optimisers", "Random", "ReTestItems", "Reexport", "ReverseDiff", "Setfield", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
test = ["ADTypes", "Adapt", "Aqua", "ChainRules", "ChainRulesCore", "ComponentArrays", "ExplicitImports", "Flux", "Functors", "Logging", "LuxAMDGPU", "LuxCUDA", "LuxCore", "LuxLib", "LuxTestUtils", "MLUtils", "MPI", "NCCL", "Optimisers", "Pkg", "Random", "ReTestItems", "Reexport", "ReverseDiff", "Setfield", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"]
2 changes: 1 addition & 1 deletion docs/src/manual/distributed_utils.md
@@ -3,7 +3,7 @@
!!! tip

For a fully functional example, see the
[ImageNet Training Example](https://github.com/LuxDL/Lux.jl/tree/main/examples/ImageNet)
[ImageNet Training Example](https://github.com/LuxDL/Lux.jl/tree/main/examples/ImageNet).

DDP Training using `Lux.DistributedUtils` is a spiritual successor to
[FluxMPI.jl](https://github.com/avik-pal/FluxMPI.jl), but has some key differences.
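For orientation, the tests added below exercise roughly the following user-facing workflow. This is a minimal sketch, not the manual's exact example: the model, sizes, and learning rate are made up, and `NCCLBackend` assumes CUDA (use `MPIBackend` otherwise).

    using Lux, MPI, NCCL, Optimisers, Random

    DistributedUtils.initialize(NCCLBackend)
    backend = DistributedUtils.get_distributed_backend(NCCLBackend)

    model = Dense(4 => 2)
    ps, st = Lux.setup(Xoshiro(0), model)
    ps = DistributedUtils.synchronize!!(backend, ps)   # broadcast rank 0's parameters ...
    st = DistributedUtils.synchronize!!(backend, st)   # ... and states to every worker

    data = DistributedUtils.DistributedDataContainer(backend, randn(Float32, 4, 64))

    opt = DistributedUtils.DistributedOptimizer(backend, Adam(3.0f-4))   # averages gradients across workers
    opt_state = DistributedUtils.synchronize!!(backend, Optimisers.setup(opt, ps))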
13 changes: 10 additions & 3 deletions examples/ImageNet/Project.toml
@@ -5,7 +5,7 @@ Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
@@ -30,17 +30,24 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Augmentor = "0.6"
Boltz = "0.1, 0.2, 0.3"
Configurations = "0.17"
Formatting = "0.4"
FLoops = "0.2"
FileIO = "1.16"
Format = "1.3"
Functors = "0.2, 0.3, 0.4"
Images = "0.26"
JLD2 = "0.4.46"
JpegTurbo = "0.1"
Lux = "0.4, 0.5"
LuxAMDGPU = "0.1, 0.2"
LuxCUDA = "0.2, 0.3"
MLUtils = "0.2.10, 0.3, 0.4"
MPI = "0.20.19"
Metalhead = "0.9"
NCCL = "0.1.1"
OneHotArrays = "0.1, 0.2"
Optimisers = "0.2, 0.3"
Setfield = "0.8.2, 1"
ParameterSchedulers = "0.4"
Setfield = "1"
SimpleConfig = "0.1"
Statistics = "1"
Zygote = "0.6"
11 changes: 6 additions & 5 deletions examples/ImageNet/main.jl
@@ -6,8 +6,7 @@ import FLoops: ThreadedEx
import Metalhead
import MPI, NCCL
using LuxAMDGPU, LuxCUDA

using Formatting # TODO: Get rid of this
using Format

# Distributed Training: NCCL for NVIDIA GPUs and MPI for anything else
if LuxCUDA.functional()
@@ -79,10 +78,12 @@ function construct(cfg::OptimizerConfig)
end

if cfg.scheduler.name == "cosine"
scheduler = CosineAnnealSchedule(cfg.learning_rate, cfg.learning_rate / 100,
cfg.scheduler.cycle_length; dampen=cfg.scheduler.damp_factor)
l0 = cfg.learning_rate
l1 = cfg.learning_rate / 100
scheduler = ComposedSchedule(CosAnneal(l0, l1, cfg.scheduler.cycle_length),
Step(l0, cfg.scheduler.damp_factor, cfg.scheduler.cycle_length))
elseif cfg.scheduler.name == "constant"
scheduler = ConstantSchedule(cfg.learning_rate)
scheduler = Constant(cfg.learning_rate)
elseif cfg.scheduler.name == "step"
scheduler = Step(
cfg.learning_rate, cfg.scheduler.lr_step_decay, cfg.scheduler.lr_step)
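The hand-rolled schedules removed from utils.jl below are replaced here by ParameterSchedulers.jl types. A minimal sketch of how such a schedule might be queried once per epoch; the values are arbitrary, and driving `Optimisers.adjust!` this way is an assumption about how the training loop consumes it:

    using Optimisers, ParameterSchedulers

    ps = (; w=zeros(Float32, 4))                  # stand-in parameters
    opt_state = Optimisers.setup(Adam(0.1f0), ps)
    scheduler = CosAnneal(0.1f0, 0.001f0, 30)     # λ0, λ1, period — same positional form as above
    for epoch in 1:90
        Optimisers.adjust!(opt_state, scheduler(epoch))   # schedules are callable with the step index
        # ... run this epoch's training ...
    end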
57 changes: 5 additions & 52 deletions examples/ImageNet/utils.jl
@@ -77,51 +77,6 @@ function load_checkpoint(fname::String)
end
end

# Parameter Scheduling
## Copied from ParameterSchedulers.jl due to its heavy dependencies
struct CosineAnnealSchedule{restart, T, S <: Integer}
range::T
offset::T
dampen::T
period::S

function CosineAnnealSchedule(
lambda_0, lambda_1, period; restart::Bool=true, dampen=1.0f0)
range = abs(lambda_0 - lambda_1)
offset = min(lambda_0, lambda_1)
return new{restart, typeof(range), typeof(period)}(range, offset, dampen, period)
end
end

function (s::CosineAnnealSchedule{true})(t)
d = s.dampen^div(t - 1, s.period)
return (s.range * (1 + cos(pi * mod(t - 1, s.period) / s.period)) / 2 + s.offset) / d
end

function (s::CosineAnnealSchedule{false})(t)
return s.range * (1 + cos(pi * (t - 1) / s.period)) / 2 + s.offset
end

struct Step{T, S}
start::T
decay::T
step_sizes::S

function Step(start::T, decay::T, step_sizes::S) where {T, S}
_step_sizes = (S <: Integer) ? Iterators.repeated(step_sizes) : step_sizes

return new{T, typeof(_step_sizes)}(start, decay, _step_sizes)
end
end

(s::Step)(t) = s.start * s.decay^(searchsortedfirst(s.step_sizes, t - 1) - 1)

struct ConstantSchedule{T}
val::T
end

(s::ConstantSchedule)(t) = s.val

# Tracking
@kwdef mutable struct AverageMeter
fmtstr
@@ -132,7 +87,7 @@ end
end

function AverageMeter(name::String, fmt::String)
fmtstr = Formatting.FormatExpr("$name {1:$fmt} ({2:$fmt})")
fmtstr = FormatExpr("$name {1:$fmt} ({2:$fmt})")
return AverageMeter(; fmtstr=fmtstr)
end

@@ -160,7 +115,7 @@ function reset_meter!(meter::AverageMeter)
end

function print_meter(meter::AverageMeter)
return Formatting.printfmt(meter.fmtstr, meter.val, meter.average)
return printfmt(meter.fmtstr, meter.val, meter.average)
end

# ProgressMeter
@@ -171,10 +126,9 @@ end

function ProgressMeter(num_batches::Int, meters::NTuple{N}, prefix::String="") where {N}
fmt = "%" * string(length(string(num_batches))) * "d"
fmt2 = "{1:" * string(length(string(num_batches))) * "d}"
prefix = prefix != "" ? endswith(prefix, " ") ? prefix : prefix * " " : ""
batch_fmtstr = Formatting.generate_formatter("$prefix[$fmt/" *
Formatting.sprintf1(fmt, num_batches) *
"]")
batch_fmtstr = FormatExpr("$prefix[$fmt2/" * cfmt(fmt, num_batches) * "]")
return ProgressMeter{N}(batch_fmtstr, meters)
end

@@ -184,8 +138,7 @@ function reset_meter!(meter::ProgressMeter)
end

function print_meter(meter::ProgressMeter, batch::Int)
base_str = meter.batch_fmtstr(batch)
print(base_str)
printfmt(meter.batch_fmtstr, batch)
foreach(x -> (print("\t"); print_meter(x)), meter.meters[1:end])
println()
return nothing
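The meters above now go through Format.jl instead of Formatting.jl. A minimal sketch of the three calls they rely on; the format strings and values here are made up:

    using Format

    fe = FormatExpr("Loss {1:6.4f} ({2:6.4f})")   # same shape as an AverageMeter format string
    printfmt(fe, 0.1234, 0.2345)                  # prints "Loss 0.1234 (0.2345)"
    println()
    print(cfmt("%4d", 90))                        # C-style one-off formatting, as in ProgressMeter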
4 changes: 2 additions & 2 deletions ext/LuxComponentArraysExt.jl
@@ -14,8 +14,8 @@ Lux.__named_tuple(ca::ComponentArray) = NamedTuple(ca)
# Distributed Functionality
function DistributedUtils.synchronize!!(
backend::Lux.AbstractLuxDistributedBackend, ps::ComponentArray; root::Int=0)
ps_synced = DistributedUtils.synchronize!!(backend, getdata(ps); root)
return ComponentArray(ps_synced, getaxes(ps))
ps_synced = DistributedUtils.synchronize!!(backend, ComponentArrays.getdata(ps); root)
return ComponentArray(ps_synced, ComponentArrays.getaxes(ps))
end

end
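Because the buffer-level routines work on plain arrays, the extension flattens the ComponentArray with `getdata`, synchronizes the flat vector, and rewraps it with the original axes. A minimal sketch under the MPI backend (a single process is enough to run it); the parameter names and sizes are made up:

    using ComponentArrays, Lux, MPI

    DistributedUtils.initialize(MPIBackend)
    backend = DistributedUtils.get_distributed_backend(MPIBackend)

    ps = ComponentArray(a=zeros(Float32, 4), b=ones(Float32, 2))
    ps_synced = DistributedUtils.synchronize!!(backend, ps; root=0)   # broadcasts rank 0's values
    @assert ps_synced isa ComponentArray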
23 changes: 12 additions & 11 deletions ext/LuxMPIExt.jl
@@ -39,7 +39,6 @@ DistributedUtils.local_rank(backend::MPIBackend) = MPI.Comm_rank(backend.comm)
DistributedUtils.total_workers(backend::MPIBackend) = MPI.Comm_size(backend.comm)

# Broadcast

function DistributedUtils.__bcast!(
backend::MPIBackend, sendrecvbuf, dev::AbstractLuxDevice; root=0)
MPI.Bcast!(sendrecvbuf, backend.comm; root)
@@ -78,20 +77,21 @@ for (aware, dType) in ((MPI_CUDA_AWARE, LuxCUDADevice), (MPI_ROCM_AWARE, LuxAMDG
end

# Allreduce

function DistributedUtils.__allreduce!(
backend::MPIBackend, sendrecvbuf, op::F, dev::AbstractLuxDevice) where {F}
MPI.Allreduce!(sendrecvbuf, op, backend.comm)
if op === typeof(DistributedUtils.avg)
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Allreduce!(sendrecvbuf, mpiop, backend.comm)
if op === DistributedUtils.avg
sendrecvbuf ./= DistributedUtils.total_workers(backend)
end
return sendrecvbuf
end

function DistributedUtils.__allreduce!(
backend::MPIBackend, sendbuf, recvbuf, op::F, dev::AbstractLuxDevice) where {F}
MPI.Allreduce!(sendbuf, recvbuf, op, backend.comm)
if op === typeof(DistributedUtils.avg)
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Allreduce!(sendbuf, recvbuf, mpiop, backend.comm)
if op === DistributedUtils.avg
recvbuf ./= DistributedUtils.total_workers(backend)
end
return recvbuf
@@ -123,20 +123,21 @@ for (aware, dType) in ((MPI_CUDA_AWARE, LuxCUDADevice), (MPI_ROCM_AWARE, LuxAMDG
end

# Reduce

function DistributedUtils.__reduce!(backend::MPIBackend, sendrecvbuf, op::F,
dev::AbstractLuxDevice; root::Int) where {F}
MPI.Reduce!(sendrecvbuf, op, backend.comm; root)
if op === typeof(DistributedUtils.avg)
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Reduce!(sendrecvbuf, mpiop, backend.comm; root)
if op === DistributedUtils.avg
sendrecvbuf ./= DistributedUtils.total_workers(backend)
end
return sendrecvbuf
end

function DistributedUtils.__reduce!(backend::MPIBackend, sendbuf, recvbuf, op::F,
dev::AbstractLuxDevice; root::Int) where {F}
MPI.Reduce!(sendbuf, recvbuf, op, backend.comm; root)
if op === typeof(DistributedUtils.avg)
mpiop = ifelse(op === DistributedUtils.avg, +, op)
MPI.Reduce!(sendbuf, recvbuf, mpiop, backend.comm; root)
if op === DistributedUtils.avg
recvbuf ./= DistributedUtils.total_workers(backend)
end
return recvbuf
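The change above makes `DistributedUtils.avg` usable with MPI: the op is mapped to `+` for the actual `MPI.Allreduce!`/`MPI.Reduce!` call and the result is then divided by `total_workers` (the old `op === typeof(DistributedUtils.avg)` check compared a function against a type, so the division never ran). A minimal sketch of the user-facing call, assuming the script is launched under mpiexec:

    using Lux, MPI

    DistributedUtils.initialize(MPIBackend)
    backend = DistributedUtils.get_distributed_backend(MPIBackend)

    buf = ones(Float32, 4)
    buf = DistributedUtils.allreduce!(backend, buf, DistributedUtils.avg)
    # Every rank contributed ones(4); summing and dividing by the worker count gives ones(4) back.
    @assert all(buf .≈ 1.0f0)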
3 changes: 0 additions & 3 deletions ext/LuxMPINCCLExt.jl
@@ -33,7 +33,6 @@ DistributedUtils.total_workers(backend::NCCLBackend) = NCCL.size(backend.comm)

# For non-CUDA Arrays, fallback to MPI
# Broadcast

function DistributedUtils.__bcast!(
backend::NCCLBackend, sendrecvbuf, ::LuxCUDADevice; root=0)
NCCL.Broadcast!(sendrecvbuf, backend.comm; root)
@@ -57,7 +56,6 @@ function DistributedUtils.__bcast!(
end

# Allreduce

function DistributedUtils.__allreduce!(
backend::NCCLBackend, sendrecvbuf, op::F, ::LuxCUDADevice) where {F}
op = ifelse(op === DistributedUtils.avg, NCCL.avg, op)
@@ -83,7 +81,6 @@ function DistributedUtils.__allreduce!(
end

# Reduce

function DistributedUtils.__reduce!(
backend::NCCLBackend, sendrecvbuf, op::F, ::LuxCUDADevice; root::Int) where {F}
op = ifelse(op === DistributedUtils.avg, NCCL.avg, op)
2 changes: 1 addition & 1 deletion ext/LuxOptimisersExt.jl
@@ -49,7 +49,7 @@ end
end

function Optimisers.apply!(opt::DistributedOptimizer, state, x, y)
y_avg = allreduce!(opt.backend, y, DistributedUtils.avg)
y_avg = DistributedUtils.allreduce!(opt.backend, y, DistributedUtils.avg)
return Optimisers.apply!(opt.opt, state, x, y_avg)
end

15 changes: 15 additions & 0 deletions test/distributed/common_distributedtest.jl
@@ -0,0 +1,15 @@
using Lux, MPI, NCCL, Test

const input_args = length(ARGS) == 3 ? ARGS[1:3] :
(length(ARGS) == 2 ? (ARGS[2], "mpi") : ("CPU", "mpi"))
const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
const dev = input_args[1] == "CPU" ? LuxCPUDevice() :
(input_args[1] == "CUDA" ? LuxCUDADevice() : LuxAMDGPUDevice())

DistributedUtils.initialize(backend_type)
backend = DistributedUtils.get_distributed_backend(backend_type)

@test DistributedUtils.initialized(backend_type)

# Should always hold true
@test DistributedUtils.local_rank(backend) < DistributedUtils.total_workers(backend)
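These scripts are meant to be launched across several MPI ranks rather than included directly. A sketch of how a test runner might invoke them, assuming MPI.jl's `mpiexec` wrapper and the `JULIA_MPI_TEST_NPROCS` variable added to the Buildkite pipeline above; the path is hypothetical:

    using MPI

    nprocs = parse(Int, get(ENV, "JULIA_MPI_TEST_NPROCS", "2"))
    testfile = joinpath(@__DIR__, "distributed", "common_distributedtest.jl")
    # With no extra ARGS the script falls back to the CPU device and the MPI backend.
    run(`$(MPI.mpiexec()) -n $nprocs $(Base.julia_cmd()) --project=$(Base.active_project()) $testfile`)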
28 changes: 28 additions & 0 deletions test/distributed/data_distributedtest.jl
@@ -0,0 +1,28 @@
using Lux, MLUtils, MPI, NCCL, Random, Test

const input_args = length(ARGS) == 3 ? ARGS[1:3] :
(length(ARGS) == 2 ? (ARGS[2], "mpi") : ("CPU", "mpi"))
const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
const dev = input_args[1] == "CPU" ? LuxCPUDevice() :
(input_args[1] == "CUDA" ? LuxCUDADevice() : LuxAMDGPUDevice())

rng = Xoshiro(1234)

DistributedUtils.initialize(backend_type)
backend = DistributedUtils.get_distributed_backend(backend_type)

data = randn(rng, Float32, 10)
dcontainer = DistributedUtils.DistributedDataContainer(backend, data)

rank = DistributedUtils.local_rank(backend)
tworkers = DistributedUtils.total_workers(backend)

if rank != tworkers - 1
@test length(dcontainer) == ceil(length(data) / tworkers)
else
@test length(dcontainer) ==
length(data) - (tworkers - 1) * ceil(length(data) / tworkers)
end

dsum = sum(Base.Fix1(MLUtils.getobs, dcontainer), 1:MLUtils.numobs(dcontainer))
@test DistributedUtils.allreduce!(backend, [dsum], +)[1] ≈ sum(data)
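In a training loop the container would typically be wrapped in an MLUtils.DataLoader so that each rank iterates only over its own shard; a minimal sketch continuing from the script above (the batch size is arbitrary):

    using MLUtils

    loader = DataLoader(dcontainer; batchsize=4, shuffle=true)   # `dcontainer` from the script above
    for x in loader
        # run the forward/backward pass on this rank's slice of `data`
    end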
30 changes: 30 additions & 0 deletions test/distributed/optimizer_distributedtest.jl
@@ -0,0 +1,30 @@
using Lux, MPI, NCCL, Optimisers, Random, Test

const input_args = length(ARGS) == 3 ? ARGS[1:3] :
(length(ARGS) == 2 ? (ARGS[2], "mpi") : ("CPU", "mpi"))
const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
const dev = input_args[1] == "CPU" ? LuxCPUDevice() :
(input_args[1] == "CUDA" ? LuxCUDADevice() : LuxAMDGPUDevice())

DistributedUtils.initialize(backend_type)
backend = DistributedUtils.get_distributed_backend(backend_type)

opt = Adam(0.001f0)
ps = (a=zeros(4), b=zeros(4)) |> dev
st_opt = Optimisers.setup(opt, ps)

dopt = DistributedUtils.DistributedOptimizer(backend, opt)
st_dopt = Optimisers.setup(dopt, ps)

@test st_dopt.a.state == st_opt.a.state
@test st_dopt.b.state == st_opt.b.state

@test_nowarn DistributedUtils.synchronize!!(backend, st_dopt)

gs = (a=ones(4), b=ones(4)) |> dev

_, ps_dopt = Optimisers.update(st_dopt, ps, gs)
_, ps_opt = Optimisers.update(st_opt, ps, gs)

@test ps_dopt.a ≈ ps_opt.a atol=1.0e-5 rtol=1.0e-5
@test ps_dopt.b ≈ ps_opt.b atol=1.0e-5 rtol=1.0e-5