More informative error messages for failing to load devices
avik-pal committed Apr 7, 2024
1 parent e0089c4 commit 93cedad
Showing 10 changed files with 101 additions and 17 deletions.
1 change: 0 additions & 1 deletion .buildkite/pipeline.yml
@@ -17,7 +17,6 @@ steps:
cuda: "*"
env:
GROUP: "CUDA"
- JULIA_MPI_TEST_NPROCS: 2 # Needs to be same as number of GPUs for NCCL
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 240
matrix:
1 change: 1 addition & 0 deletions .gitignore
@@ -31,3 +31,4 @@ docs/src/tutorials/advanced
*.log

bench/benchmark_results.json
+ *.cov
24 changes: 15 additions & 9 deletions ext/LuxMPIExt.jl
@@ -1,32 +1,37 @@
module LuxMPIExt

- using Lux: MPIBackend, NCCLBackend, DistributedUtils, __is_extension_loaded, __unwrap_val,
-     MPI_CUDA_AWARE, MPI_ROCM_AWARE
+ using Lux: MPIBackend, NCCLBackend, DistributedUtils, __unwrap_val, MPI_CUDA_AWARE,
+     MPI_ROCM_AWARE
using LuxDeviceUtils: AbstractLuxDevice, LuxCUDADevice, LuxAMDGPUDevice, cpu_device,
set_device!, __is_functional
using MPI: MPI

function DistributedUtils.__initialize(
-     ::Type{MPIBackend}; cuda_devices=nothing, amdgpu_devices=nothing)
+     ::Type{MPIBackend}; cuda_devices=nothing, amdgpu_devices=nothing,
+     force_cuda::Bool=false, caller::String="", force_amdgpu::Bool=false) # Undocumented internal kwarg
!MPI.Initialized() && MPI.Init()
DistributedUtils.MPI_Initialized[] = true

local_rank = MPI.Comm_rank(MPI.COMM_WORLD)

if cuda_devices !== missing && __is_functional(LuxCUDADevice)
if cuda_devices === nothing
- set_device!(LuxCUDADevice, nothing, local_rank)
+ set_device!(LuxCUDADevice, nothing, local_rank + 1)
else
- set_device!(LuxCUDADevice, cuda_devices[local_rank])
+ set_device!(LuxCUDADevice, cuda_devices[local_rank + 1])
end
+ elseif force_cuda
+ error(lazy"CUDA devices are not functional (or `LuxCUDA.jl` not loaded) and `force_cuda` is set to `true`. This is caused by backend: $(caller).")
end

if amdgpu_devices !== missing && __is_functional(LuxAMDGPUDevice)
if amdgpu_devices === nothing
- set_device!(LuxAMDGPUDevice, nothing, local_rank)
+ set_device!(LuxAMDGPUDevice, nothing, local_rank + 1)
else
- set_device!(LuxAMDGPUDevice, amdgpu_devices[local_rank])
+ set_device!(LuxAMDGPUDevice, amdgpu_devices[local_rank + 1])
end
+ elseif force_amdgpu
+ error(lazy"AMDGPU devices are not functional (or `LuxAMDGPU.jl` not loaded) and `force_amdgpu` is set to `true`. This is caused by backend: $(caller).")
end

return
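The device-selection change above fixes an off-by-one: `MPI.Comm_rank` is 0-based while Julia arrays are 1-based, so indexing a user-supplied device list with `local_rank` is off by one (and out of bounds for rank 0). The new `force_cuda`/`force_amdgpu` and `caller` keywords (undocumented internals, per the comment in the signature) let a backend that requires a GPU fail loudly instead of silently proceeding without one. A minimal sketch of the indexing fix, with made-up device ids that are not part of the commit:

```julia
# Hypothetical illustration of the 1-based indexing fix (device ids are made up).
cuda_devices = [0, 1, 2, 3]  # one CUDA device id per MPI rank
local_rank = 2               # as reported by MPI.Comm_rank (0-based)

# Before: cuda_devices[local_rank] == 1, so rank 2 would grab rank 1's device.
# After:
@assert cuda_devices[local_rank + 1] == 2  # rank 2 gets the third entry, device id 2
```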
@@ -47,8 +52,9 @@ end

function DistributedUtils.__bcast!(
backend::MPIBackend, sendbuf, recvbuf, dev::AbstractLuxDevice; root=0)
- MPI.Bcast!(sendbuf, recvbuf, backend.comm; root)
- return recvbuf
+ return DistributedUtils.__bcast!(
+ backend, ifelse(DistributedUtils.local_rank(backend) == root, sendbuf, recvbuf),
+ dev; root)
end

for (aware, dType) in ((MPI_CUDA_AWARE, LuxCUDADevice), (MPI_ROCM_AWARE, LuxAMDGPUDevice))
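The `__bcast!` change replaces a call that did not match MPI.jl's API: `MPI.Bcast!` broadcasts a single buffer in place, it does not take separate send and receive buffers. The two-buffer method now delegates to the single-buffer method, passing `sendbuf` on the root rank and `recvbuf` on every other rank. A sketch of the resulting semantics, mirroring the `bcast!` test added in `common_distributedtest.jl` (the call itself needs a live, initialized backend, so it is shown commented out):

```julia
# Sketch only: `backend`, `dev`, and `local_rank` are assumed to come from an
# initialized distributed setup; the buffer size follows the test file.
local_rank = 0  # placeholder for DistributedUtils.local_rank(backend)
buf = local_rank == 0 ? ones(512) : zeros(512)  # sendbuf on root, recvbuf elsewhere

# DistributedUtils.__bcast!(backend, buf, dev; root=0)
# After the call, `all(buf .== 1)` holds on every rank.
```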
3 changes: 2 additions & 1 deletion ext/LuxMPINCCLExt.jl
@@ -9,7 +9,8 @@ using Setfield: @set!
function DistributedUtils.__initialize(
::Type{NCCLBackend}; cuda_devices=nothing, amdgpu_devices=missing)
@assert amdgpu_devices===missing "`AMDGPU` is not supported by `NCCL`."
- DistributedUtils.__initialize(MPIBackend; cuda_devices, amdgpu_devices)
+ DistributedUtils.__initialize(
+ MPIBackend; cuda_devices, force_cuda=true, caller="NCCLBackend", amdgpu_devices)
DistributedUtils.NCCL_Initialized[] = true
return
end
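With this forwarding, initializing the NCCL backend now passes `force_cuda=true` and `caller="NCCLBackend"` down to the MPI initialization path, so a missing `LuxCUDA` load or non-functional CUDA devices surface the informative error added in `ext/LuxMPIExt.jl` instead of a silent CPU-only initialization that NCCL cannot use. A usage sketch via the public entry point used by the tests:

```julia
using Lux, MPI, NCCL  # as in the distributed test files

# On a machine without working CUDA devices (or with LuxCUDA not loaded),
# this is now expected to error with a message naming `NCCLBackend` as the caller:
DistributedUtils.initialize(NCCLBackend)
```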
2 changes: 0 additions & 2 deletions src/Lux.jl
@@ -34,8 +34,6 @@ const CRC = ChainRulesCore

const NAME_TYPE = Union{Nothing, String, Symbol}

- @inline __is_extension_loaded(x) = Val(false)
-
# Utilities
include("utils.jl")

79 changes: 77 additions & 2 deletions test/distributed/common_distributedtest.jl
@@ -1,16 +1,91 @@
using Lux, MPI, NCCL, Test
+ using LuxAMDGPU, LuxCUDA

const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi")
const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
const dev = input_args[1] == "CPU" ? LuxCPUDevice() :
(input_args[1] == "CUDA" ? LuxCUDADevice() : LuxAMDGPUDevice())
+ const aType = input_args[1] == "CPU" ? Array :
+ (input_args[1] == "CUDA" ? CuArray : ROCArray)

DistributedUtils.initialize(backend_type)
backend = DistributedUtils.get_distributed_backend(backend_type)

@test DistributedUtils.initialized(backend_type)

# Should always hold true
- @test DistributedUtils.local_rank(backend) < DistributedUtils.total_workers(backend)
+ rank = DistributedUtils.local_rank(backend)
+ nworkers = DistributedUtils.total_workers(backend)
+ @test rank < nworkers

- # Test the communication primitives
+ # Test the communication primitives
+ ## broacast!
+ for arrType in (Array, aType)
+ sendbuf = (rank == 0) ? arrType(ones(512)) : arrType(zeros(512))
+ recvbuf = arrType(zeros(512))
+
+ DistributedUtils.bcast!(backend, sendbuf, recvbuf; root=0)
+
+ rank != 0 && @test all(recvbuf .== 1)
+
+ sendrecvbuf = (rank == 0) ? arrType(ones(512)) : arrType(zeros(512))
+ DistributedUtils.bcast!(backend, sendrecvbuf; root=0)
+
+ @test all(sendrecvbuf .== 1)
+ end
+
+ ## reduce!
+ for arrType in (Array, aType)
+ sendbuf = arrType(fill(Float64(rank + 1), 512))
+ recvbuf = arrType(zeros(512))
+
+ DistributedUtils.reduce!(backend, sendbuf, recvbuf, +; root=0)
+
+ rank == 0 && @test all(recvbuf .≈ sum(1:nworkers))
+
+ sendbuf .= rank + 1
+
+ DistributedUtils.reduce!(backend, sendbuf, recvbuf, DistributedUtils.avg; root=0)
+
+ rank == 0 && @test all(recvbuf .≈ sum(1:nworkers) / nworkers)
+
+ sendrecvbuf = arrType(fill(Float64(rank + 1), 512))
+
+ DistributedUtils.reduce!(backend, sendrecvbuf, +; root=0)
+
+ rank == 0 && @test all(sendrecvbuf .≈ sum(1:nworkers))
+
+ sendrecvbuf .= rank + 1
+
+ DistributedUtils.reduce!(backend, sendrecvbuf, DistributedUtils.avg; root=0)
+
+ rank == 0 && @test all(sendrecvbuf .≈ sum(1:nworkers) / nworkers)
+ end
+
+ ## allreduce!
+ for arrType in (Array, aType)
+ sendbuf = arrType(fill(Float64(rank + 1), 512))
+ recvbuf = arrType(zeros(512))
+
+ DistributedUtils.allreduce!(backend, sendbuf, recvbuf, +)
+
+ @test all(recvbuf .≈ sum(1:nworkers))
+
+ sendbuf .= rank + 1
+
+ DistributedUtils.allreduce!(backend, sendbuf, recvbuf, DistributedUtils.avg)
+
+ @test all(recvbuf .≈ sum(1:nworkers) / nworkers)
+
+ sendrecvbuf = arrType(fill(Float64(rank + 1), 512))
+
+ DistributedUtils.allreduce!(backend, sendrecvbuf, +)
+
+ @test all(sendrecvbuf .≈ sum(1:nworkers))
+
+ sendrecvbuf .= rank + 1
+
+ DistributedUtils.allreduce!(backend, sendrecvbuf, DistributedUtils.avg)
+
+ @test all(sendrecvbuf .≈ sum(1:nworkers) / nworkers)
+ end
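The expected values in the `reduce!`/`allreduce!` tests follow from each rank contributing `fill(Float64(rank + 1), 512)`: a `+` reduction over ranks yields `sum(1:nworkers)`, and `DistributedUtils.avg` divides that by the worker count. A quick arithmetic check under an assumed two-worker run:

```julia
# Worked check of the expected reduction results, assuming nworkers == 2.
nworkers = 2
contributions = [rank + 1 for rank in 0:(nworkers - 1)]  # each rank sends fill(rank + 1, 512)
@assert sum(contributions) == sum(1:nworkers)            # 3, expected for the `+` reduction
@assert sum(contributions) / nworkers == 1.5             # expected for DistributedUtils.avg
```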
1 change: 1 addition & 0 deletions test/distributed/data_distributedtest.jl
@@ -1,4 +1,5 @@
using Lux, MLUtils, MPI, NCCL, Random, Test
+ using LuxAMDGPU, LuxCUDA

const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi")
const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
1 change: 1 addition & 0 deletions test/distributed/optimizer_distributedtest.jl
@@ -1,4 +1,5 @@
using Lux, MPI, NCCL, Optimisers, Random, Test
+ using LuxAMDGPU, LuxCUDA

const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi")
const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
1 change: 1 addition & 0 deletions test/distributed/synchronize_distributedtest.jl
@@ -1,4 +1,5 @@
using ComponentArrays, Lux, MPI, NCCL, Optimisers, Random, Test
+ using LuxAMDGPU, LuxCUDA

const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi")
const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
5 changes: 3 additions & 2 deletions test/runtests.jl
@@ -1,6 +1,6 @@
using ReTestItems

- # ReTestItems.runtests(@__DIR__)
+ ReTestItems.runtests(@__DIR__)

# Distributed Tests
using MPI, Pkg, Test
@@ -27,10 +27,11 @@ include("setup_modes.jl")
@testset "MODE: $(mode)" for (mode, aType, dev, ongpu) in MODES
backends = mode == "CUDA" ? ("mpi", "nccl") : ("mpi",)
for backend_type in backends
+ np = backend_type == "nccl" ? min(nprocs, length(CUDA.devices())) : nprocs
@testset "Backend: $(backend_type)" begin
@testset "$(basename(file))" for file in distributedtestfiles
@info "Running $file with $backend_type backend on $mode device"
- run(`$(MPI.mpiexec()) -n $(nprocs) $(Base.julia_cmd()) --color=yes \
+ run(`$(MPI.mpiexec()) -n $(np) $(Base.julia_cmd()) --color=yes \
--code-coverage=user --project=$(cur_proj) --startup-file=no $(file) \
$(mode) $(backend_type)`)
Test.@test true
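The `runtests.jl` change re-enables the regular `ReTestItems` run and, for the NCCL backend, caps the MPI process count at the number of visible CUDA devices, matching the one-rank-per-GPU constraint that the deleted `JULIA_MPI_TEST_NPROCS` comment in `.buildkite/pipeline.yml` described. A small sketch of that process-count logic under assumed values:

```julia
# Sketch of the per-backend process count with assumed values; in runtests.jl the
# device count comes from length(CUDA.devices()).
nprocs = 2                        # requested MPI test processes
ndevices = 1                      # assume one visible CUDA device
np_nccl = min(nprocs, ndevices)   # 1: NCCL needs one rank per GPU
np_mpi = nprocs                   # 2: the plain MPI backend is not tied to the GPU count
```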
