
More informative error messages for failing to load devices
avik-pal committed Apr 7, 2024
1 parent e0089c4 commit 3a4f1d9
Showing 8 changed files with 17 additions and 7 deletions.
1 change: 0 additions & 1 deletion .buildkite/pipeline.yml
@@ -17,7 +17,6 @@ steps:
       cuda: "*"
     env:
       GROUP: "CUDA"
-      JULIA_MPI_TEST_NPROCS: 2 # Needs to be same as number of GPUs for NCCL
     if: build.message !~ /\[skip tests\]/
     timeout_in_minutes: 240
     matrix:
1 change: 1 addition & 0 deletions .gitignore
@@ -31,3 +31,4 @@ docs/src/tutorials/advanced
 *.log

 bench/benchmark_results.json
+*.cov
15 changes: 10 additions & 5 deletions ext/LuxMPIExt.jl
@@ -7,26 +7,31 @@ using LuxDeviceUtils: AbstractLuxDevice, LuxCUDADevice, LuxAMDGPUDevice, cpu_dev
 using MPI: MPI

 function DistributedUtils.__initialize(
-        ::Type{MPIBackend}; cuda_devices=nothing, amdgpu_devices=nothing)
+        ::Type{MPIBackend}; cuda_devices=nothing, amdgpu_devices=nothing,
+        force_cuda::Bool=false, caller::String="", force_amdgpu::Bool=false) # Undocumented internal kwarg
     !MPI.Initialized() && MPI.Init()
     DistributedUtils.MPI_Initialized[] = true

     local_rank = MPI.Comm_rank(MPI.COMM_WORLD)

     if cuda_devices !== missing && __is_functional(LuxCUDADevice)
         if cuda_devices === nothing
-            set_device!(LuxCUDADevice, nothing, local_rank)
+            set_device!(LuxCUDADevice, nothing, local_rank + 1)
         else
-            set_device!(LuxCUDADevice, cuda_devices[local_rank])
+            set_device!(LuxCUDADevice, cuda_devices[local_rank + 1])
         end
+    elseif force_cuda
+        error(lazy"CUDA devices are not functional (or `LuxCUDA.jl` not loaded) and `force_cuda` is set to `true`. This is caused by backend: $(caller).")
     end

     if amdgpu_devices !== missing && __is_functional(LuxAMDGPUDevice)
         if amdgpu_devices === nothing
-            set_device!(LuxAMDGPUDevice, nothing, local_rank)
+            set_device!(LuxAMDGPUDevice, nothing, local_rank + 1)
         else
-            set_device!(LuxAMDGPUDevice, amdgpu_devices[local_rank])
+            set_device!(LuxAMDGPUDevice, amdgpu_devices[local_rank + 1])
         end
+    elseif force_amdgpu
+        error(lazy"AMDGPU devices are not functional (or `LuxAMDGPU.jl` not loaded) and `force_amdgpu` is set to `true`. This is caused by backend: $(caller).")
     end

     return
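Two things happen in this hunk: the device index passed to `set_device!` becomes `local_rank + 1` (MPI ranks are 0-based, while Julia indexing, including into the user-supplied `cuda_devices`/`amdgpu_devices` collections, is 1-based), and the new `force_cuda`/`force_amdgpu` kwargs turn a silently skipped GPU setup into a hard error naming the calling backend. A minimal sketch of the indexing convention, with hypothetical device ids not taken from the diff:

using MPI: MPI

MPI.Init()

# MPI ranks are 0-based...
local_rank = MPI.Comm_rank(MPI.COMM_WORLD)  # 0, 1, 2, ...

# ...but Julia collections (like the `cuda_devices` passed to
# `__initialize` above) are 1-based, hence `local_rank + 1`.
cuda_devices = [0, 1]                       # hypothetical: one device id per rank
device_id = cuda_devices[local_rank + 1]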
3 changes: 2 additions & 1 deletion ext/LuxMPINCCLExt.jl
@@ -9,7 +9,8 @@ using Setfield: @set!
 function DistributedUtils.__initialize(
         ::Type{NCCLBackend}; cuda_devices=nothing, amdgpu_devices=missing)
     @assert amdgpu_devices===missing "`AMDGPU` is not supported by `NCCL`."
-    DistributedUtils.__initialize(MPIBackend; cuda_devices, amdgpu_devices)
+    DistributedUtils.__initialize(
+        MPIBackend; cuda_devices, force_cuda=true, caller="NCCLBackend", amdgpu_devices)
     DistributedUtils.NCCL_Initialized[] = true
     return
 end
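With the forwarded `force_cuda=true` and `caller="NCCLBackend"`, requesting the NCCL backend on a machine without a functional CUDA stack now fails up front instead of silently falling back to CPU. A hedged usage sketch, assuming the public `DistributedUtils.initialize`/`get_distributed_backend` entry points wrap the `__initialize` internals shown above:

using Lux, MPI, NCCL, LuxCUDA

# With functional CUDA this assigns one device per MPI rank; without it,
# the change above makes initialization throw:
#   "CUDA devices are not functional (or `LuxCUDA.jl` not loaded) and
#    `force_cuda` is set to `true`. This is caused by backend: NCCLBackend."
DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)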
1 change: 1 addition & 0 deletions test/distributed/common_distributedtest.jl
@@ -1,4 +1,5 @@
 using Lux, MPI, NCCL, Test
+using LuxAMDGPU, LuxCUDA

 const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi")
 const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
1 change: 1 addition & 0 deletions test/distributed/data_distributedtest.jl
@@ -1,4 +1,5 @@
 using Lux, MLUtils, MPI, NCCL, Random, Test
+using LuxAMDGPU, LuxCUDA

 const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi")
 const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
1 change: 1 addition & 0 deletions test/distributed/optimizer_distributedtest.jl
@@ -1,4 +1,5 @@
 using Lux, MPI, NCCL, Optimisers, Random, Test
+using LuxAMDGPU, LuxCUDA

 const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi")
 const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
1 change: 1 addition & 0 deletions test/distributed/synchronize_distributedtest.jl
@@ -1,4 +1,5 @@
 using ComponentArrays, Lux, MPI, NCCL, Optimisers, Random, Test
+using LuxAMDGPU, LuxCUDA

 const input_args = length(ARGS) == 2 ? ARGS : ("CPU", "mpi")
 const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend
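Each of these test scripts reads two positional `ARGS` (a device group and a backend name, defaulting to `("CPU", "mpi")`). A hypothetical invocation sketch using MPI.jl's `mpiexec`; the process count and paths are illustrative, not taken from the repository's test runner:

using MPI

nprocs = 2  # hypothetical; for NCCL this should match the number of GPUs
testfile = joinpath("test", "distributed", "common_distributedtest.jl")
run(`$(MPI.mpiexec()) -n $nprocs $(Base.julia_cmd()) --project $testfile CUDA nccl`)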
