
Commit 8a8744e
Environments: Work well with external libraries that set their own GPU device
tbennun committed Aug 30, 2023
1 parent 30fdcf7 commit 8a8744e
Showing 10 changed files with 57 additions and 43 deletions.

dace/libraries/blas/environments/cublas.py (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@ class cuBLAS:
     def handle_setup_code(node):
         location = node.location
         if not location or "gpu" not in node.location:
-            location = 0
+            location = -1  # -1 means current device
         else:
             try:
                 location = int(location["gpu"])

dace/libraries/blas/environments/rocblas.py (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@ class rocBLAS:
     def handle_setup_code(node):
         location = node.location
         if not location or "gpu" not in node.location:
-            location = 0
+            location = -1  # -1 means current device
         else:
             try:
                 location = int(location["gpu"])

dace/libraries/blas/include/dace_cublas.h (12 changes: 8 additions & 4 deletions)
@@ -21,8 +21,10 @@ static void CheckCublasError(cublasStatus_t const& status) {
 }
 
 static cublasHandle_t CreateCublasHandle(int device) {
-  if (cudaSetDevice(device) != cudaSuccess) {
-    throw std::runtime_error("Failed to set CUDA device.");
+  if (device >= 0) {
+    if (cudaSetDevice(device) != cudaSuccess) {
+      throw std::runtime_error("Failed to set CUDA device.");
+    }
   }
   cublasHandle_t handle;
   CheckCublasError(cublasCreate(&handle));
@@ -65,8 +67,10 @@ class _CublasConstants {
   }
 
   _CublasConstants(int device) {
-    if (cudaSetDevice(device) != cudaSuccess) {
-      throw std::runtime_error("Failed to set CUDA device.");
+    if (device >= 0) {
+      if (cudaSetDevice(device) != cudaSuccess) {
+        throw std::runtime_error("Failed to set CUDA device.");
+      }
     }
     // Allocate constant zero with the largest used size
     cudaMalloc(&zero_, sizeof(cuDoubleComplex) * 1);
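
For context, a minimal, self-contained sketch of the guard introduced above: a negative device index now means "create the handle on whatever device is already current", so a device selected earlier by the caller or by an external library is left untouched. The main(), the CreateHandleOnDevice name, and the choice of device 0 are illustrative assumptions, not part of this commit.

// Illustrative sketch only (not part of the commit). Build with: nvcc sketch_handle.cu -lcublas
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <stdexcept>

// Mirrors the guard from dace_cublas.h: device < 0 means "use the current device".
static cublasHandle_t CreateHandleOnDevice(int device) {
  if (device >= 0) {
    if (cudaSetDevice(device) != cudaSuccess) {
      throw std::runtime_error("Failed to set CUDA device.");
    }
  }
  cublasHandle_t handle;
  if (cublasCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
    throw std::runtime_error("Failed to create cuBLAS handle.");
  }
  return handle;
}

int main() {
  // Explicit device: the creator switches to device 0 before creating the handle.
  cublasHandle_t explicit_handle = CreateHandleOnDevice(0);

  // Sentinel -1 (what the environments now pass by default): no device switch,
  // so whatever device is currently bound to this thread stays active.
  cublasHandle_t current_handle = CreateHandleOnDevice(-1);

  cublasDestroy(explicit_handle);
  cublasDestroy(current_handle);
  std::printf("created handles on an explicit device and on the current device\n");
  return 0;
}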

dace/libraries/blas/include/dace_rocblas.h (60 changes: 32 additions & 28 deletions)
@@ -24,8 +24,10 @@ static void CheckRocblasError(rocblas_status const& status) {
 }
 
 static rocblas_handle CreateRocblasHandle(int device) {
-  if (hipSetDevice(device) != hipSuccess) {
-    throw std::runtime_error("Failed to set HIP device.");
+  if (device >= 0) {
+    if (hipSetDevice(device) != hipSuccess) {
+      throw std::runtime_error("Failed to set HIP device.");
+    }
   }
   rocblas_handle handle;
   CheckRocblasError(rocblas_create_handle(&handle));
@@ -68,53 +70,55 @@ class _RocblasConstants {
   }
 
   _RocblasConstants(int device) {
-    if (hipSetDevice(device) != hipSuccess) {
-      throw std::runtime_error("Failed to set HIP device.");
+    if (device >= 0) {
+      if (hipSetDevice(device) != hipSuccess) {
+        throw std::runtime_error("Failed to set HIP device.");
+      }
     }
     // Allocate constant zero with the largest used size
-    hipMalloc(&zero_, sizeof(hipDoubleComplex) * 1);
-    hipMemset(zero_, 0, sizeof(hipDoubleComplex) * 1);
+    (void)hipMalloc(&zero_, sizeof(hipDoubleComplex) * 1);
+    (void)hipMemset(zero_, 0, sizeof(hipDoubleComplex) * 1);
 
     // Allocate constant one
-    hipMalloc(&half_pone_, sizeof(__half) * 1);
+    (void)hipMalloc(&half_pone_, sizeof(__half) * 1);
     __half half_pone = __float2half(1.0f);
-    hipMemcpy(half_pone_, &half_pone, sizeof(__half) * 1,
+    (void)hipMemcpy(half_pone_, &half_pone, sizeof(__half) * 1,
               hipMemcpyHostToDevice);
-    hipMalloc(&float_pone_, sizeof(float) * 1);
+    (void)hipMalloc(&float_pone_, sizeof(float) * 1);
     float float_pone = 1.0f;
-    hipMemcpy(float_pone_, &float_pone, sizeof(float) * 1,
+    (void)hipMemcpy(float_pone_, &float_pone, sizeof(float) * 1,
              hipMemcpyHostToDevice);
-    hipMalloc(&double_pone_, sizeof(double) * 1);
+    (void)hipMalloc(&double_pone_, sizeof(double) * 1);
     double double_pone = 1.0;
-    hipMemcpy(double_pone_, &double_pone, sizeof(double) * 1,
+    (void)hipMemcpy(double_pone_, &double_pone, sizeof(double) * 1,
              hipMemcpyHostToDevice);
-    hipMalloc(&complex64_pone_, sizeof(hipComplex) * 1);
+    (void)hipMalloc(&complex64_pone_, sizeof(hipComplex) * 1);
     hipComplex complex64_pone = make_hipFloatComplex(1.0f, 0.0f);
-    hipMemcpy(complex64_pone_, &complex64_pone, sizeof(hipComplex) * 1,
+    (void)hipMemcpy(complex64_pone_, &complex64_pone, sizeof(hipComplex) * 1,
              hipMemcpyHostToDevice);
-    hipMalloc(&complex128_pone_, sizeof(hipDoubleComplex) * 1);
+    (void)hipMalloc(&complex128_pone_, sizeof(hipDoubleComplex) * 1);
     hipDoubleComplex complex128_pone = make_hipDoubleComplex(1.0, 0.0);
-    hipMemcpy(complex128_pone_, &complex128_pone, sizeof(hipDoubleComplex) * 1,
+    (void)hipMemcpy(complex128_pone_, &complex128_pone, sizeof(hipDoubleComplex) * 1,
              hipMemcpyHostToDevice);
 
     // Allocate custom factors and default to zero
-    hipMalloc(&custom_alpha_, sizeof(hipDoubleComplex) * 1);
-    hipMemset(custom_alpha_, 0, sizeof(hipDoubleComplex) * 1);
-    hipMalloc(&custom_beta_, sizeof(hipDoubleComplex) * 1);
-    hipMemset(custom_beta_, 0, sizeof(hipDoubleComplex) * 1);
+    (void)hipMalloc(&custom_alpha_, sizeof(hipDoubleComplex) * 1);
+    (void)hipMemset(custom_alpha_, 0, sizeof(hipDoubleComplex) * 1);
+    (void)hipMalloc(&custom_beta_, sizeof(hipDoubleComplex) * 1);
+    (void)hipMemset(custom_beta_, 0, sizeof(hipDoubleComplex) * 1);
   }
 
   _RocblasConstants(_RocblasConstants const&) = delete;
 
   ~_RocblasConstants() {
-    hipFree(zero_);
-    hipFree(half_pone_);
-    hipFree(float_pone_);
-    hipFree(double_pone_);
-    hipFree(complex64_pone_);
-    hipFree(complex128_pone_);
-    hipFree(custom_alpha_);
-    hipFree(custom_beta_);
+    (void)hipFree(zero_);
+    (void)hipFree(half_pone_);
+    (void)hipFree(float_pone_);
+    (void)hipFree(double_pone_);
+    (void)hipFree(complex64_pone_);
+    (void)hipFree(complex128_pone_);
+    (void)hipFree(custom_alpha_);
+    (void)hipFree(custom_beta_);
   }
 
   _RocblasConstants& operator=(_RocblasConstants const&) = delete;
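
The HIP hunks above also prefix every hipMalloc/hipMemset/hipMemcpy/hipFree call with a (void) cast, presumably to suppress compiler warnings about discarded return values, since HIP declares these entry points so that silently ignoring their hipError_t result can warn under hipcc. A small stand-alone illustration of the pattern; fake_device_malloc and fake_device_free are stand-ins, not HIP APIs:

// Illustrative sketch only (not part of the commit); plain C++17, no HIP required.
#include <cstdio>
#include <cstdlib>

// Stand-ins declared [[nodiscard]], so that dropping the result triggers a warning,
// much like the HIP runtime calls used in dace_rocblas.h.
[[nodiscard]] static int fake_device_malloc(void** ptr, std::size_t size) {
  *ptr = std::malloc(size);
  return *ptr != nullptr ? 0 : 1;
}

[[nodiscard]] static int fake_device_free(void* ptr) {
  std::free(ptr);
  return 0;
}

int main() {
  void* buffer = nullptr;

  // fake_device_malloc(&buffer, 64);      // would warn: return value discarded
  (void)fake_device_malloc(&buffer, 64);   // the cast states that ignoring it is intentional

  if (buffer != nullptr) {
    std::printf("allocated 64 bytes\n");
  }

  (void)fake_device_free(buffer);          // same pattern on the teardown path
  return 0;
}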

dace/libraries/lapack/environments/cusolverdn.py (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ class cuSolverDn:
     def handle_setup_code(node):
         location = node.location
         if not location or "gpu" not in node.location:
-            location = 0
+            location = -1  # -1 means current device
         else:
             try:
                 location = int(location["gpu"])

dace/libraries/lapack/include/dace_cusolverdn.h (6 changes: 4 additions & 2 deletions)
@@ -21,8 +21,10 @@ static void CheckCusolverDnError(cusolverStatus_t const& status) {
 }
 
 static cusolverDnHandle_t CreateCusolverDnHandle(int device) {
-  if (cudaSetDevice(device) != cudaSuccess) {
-    throw std::runtime_error("Failed to set CUDA device.");
+  if (device >= 0) {
+    if (cudaSetDevice(device) != cudaSuccess) {
+      throw std::runtime_error("Failed to set CUDA device.");
+    }
   }
   cusolverDnHandle_t handle;
   CheckCusolverDnError(cusolverDnCreate(&handle));

dace/libraries/linalg/environments/cutensor.py (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ class cuTensor:
     def handle_setup_code(node):
         location = node.location
         if not location or "gpu" not in node.location:
-            location = 0
+            location = -1  # -1 means current device
         else:
             try:
                 location = int(location["gpu"])

dace/libraries/linalg/include/dace_cutensor.h (6 changes: 4 additions & 2 deletions)
@@ -20,8 +20,10 @@ static void CheckCuTensorError(cutensorStatus_t const& status) {
 }
 
 static cutensorHandle_t CreateCuTensorHandle(int device) {
-  if (cudaSetDevice(device) != cudaSuccess) {
-    throw std::runtime_error("Failed to set CUDA device.");
+  if (device >= 0) {
+    if (cudaSetDevice(device) != cudaSuccess) {
+      throw std::runtime_error("Failed to set CUDA device.");
+    }
   }
   cutensorHandle_t handle;
   CheckCuTensorError(cutensorInit(&handle));

dace/libraries/sparse/environments/cusparse.py (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ class cuSPARSE:
     def handle_setup_code(node):
         location = node.location
         if not location or "gpu" not in node.location:
-            location = 0
+            location = -1  # -1 means current device
         else:
             try:
                 location = int(location["gpu"])

dace/libraries/sparse/include/dace_cusparse.h (6 changes: 4 additions & 2 deletions)
@@ -20,8 +20,10 @@ static void CheckCusparseError(cusparseStatus_t const& status) {
 }
 
 static cusparseHandle_t CreateCusparseHandle(int device) {
-  if (cudaSetDevice(device) != cudaSuccess) {
-    throw std::runtime_error("Failed to set CUDA device.");
+  if (device >= 0) {
+    if (cudaSetDevice(device) != cudaSuccess) {
+      throw std::runtime_error("Failed to set CUDA device.");
+    }
   }
   cusparseHandle_t handle;
   CheckCusparseError(cusparseCreate(&handle));
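
Putting the pieces together: when a library node carries no explicit GPU location, the environments now emit -1, and every handle creator leaves the caller's current device untouched, so generated code can coexist with an external library that has already bound the thread to a GPU. A hedged end-to-end sketch of that behavior; ExternalLibrarySelectsDevice and SetDeviceIfRequested are illustrative names, not DaCe APIs:

// Illustrative sketch only (not part of the commit). Build with: nvcc sketch_device.cu
#include <cuda_runtime.h>
#include <cstdio>

// Stand-in for an external library (e.g., the framework hosting the generated code)
// that has already bound this thread to a particular GPU.
static void ExternalLibrarySelectsDevice(int device) {
  cudaSetDevice(device);
}

// The pattern every handle creator in this commit now follows:
// a negative index means "keep the caller's current device selection".
static void SetDeviceIfRequested(int device) {
  if (device >= 0) {
    cudaSetDevice(device);
  }
}

int main() {
  int count = 0;
  cudaGetDeviceCount(&count);
  const int external_choice = (count > 1) ? 1 : 0;

  ExternalLibrarySelectsDevice(external_choice);
  SetDeviceIfRequested(-1);  // default location emitted by the environments

  int current = -1;
  cudaGetDevice(&current);
  std::printf("externally selected device: %d, device after setup: %d\n",
              external_choice, current);
  return 0;
}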
