Skip to content

Commit

Permalink
[flang][AMDGPU] Convert math ops to AMD GPU library calls instead of …
Browse files Browse the repository at this point in the history
…libm calls (llvm#99517)

This patch invokes a pass when compiling for an AMDGPU target to lower
math operations to AMD GPU library calls library calls instead of libm
calls.
  • Loading branch information
jsjodin authored Sep 10, 2024
1 parent f58312e commit 4290e34
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 1 deletion.
1 change: 1 addition & 0 deletions flang/lib/Optimizer/CodeGen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen
MLIRMathToFuncs
MLIRMathToLLVM
MLIRMathToLibm
MLIRMathToROCDL
MLIROpenMPToLLVM
MLIROpenACCDialect
MLIRBuiltinToLLVMIRTranslation
Expand Down
12 changes: 11 additions & 1 deletion flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -3671,6 +3672,14 @@ class FIRToLLVMLowering
// as passes here.
mlir::OpPassManager mathConvertionPM("builtin.module");

bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
// If compiling for AMD target some math operations must be lowered to AMD
// GPU library calls, the rest can be converted to LLVM intrinsics, which
// is handled in the mathToLLVM conversion. The lowering to libm calls is
// not needed since all math operations are handled this way.
if (isAMDGCN)
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());

// Convert math::FPowI operations to inline implementation
// only if the exponent's width is greater than 32, otherwise,
// it will be lowered to LLVM intrinsic operation by a later conversion.
Expand Down Expand Up @@ -3710,7 +3719,8 @@ class FIRToLLVMLowering
pattern);
// Math operations that have not been converted yet must be converted
// to Libm.
mlir::populateMathToLibmConversionPatterns(pattern);
if (!isAMDGCN)
mlir::populateMathToLibmConversionPatterns(pattern);
mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern);

Expand Down
184 changes: 184 additions & 0 deletions flang/test/Lower/OpenMP/math-amdgpu.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
!REQUIRES: amdgpu-registered-target
!RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s

subroutine omp_pow_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_pow_f32(float {{.*}}, float {{.*}})
y = x ** x
end subroutine omp_pow_f32

subroutine omp_pow_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_pow_f64(double {{.*}}, double {{.*}})
y = x ** x
end subroutine omp_pow_f64

subroutine omp_sin_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_sin_f32(float {{.*}})
y = sin(x)
end subroutine omp_sin_f32

subroutine omp_sin_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_sin_f64(double {{.*}})
y = sin(x)
end subroutine omp_sin_f64

subroutine omp_abs_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call contract float @llvm.fabs.f32(float {{.*}})
y = abs(x)
end subroutine omp_abs_f32

subroutine omp_abs_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call contract double @llvm.fabs.f64(double {{.*}})
y = abs(x)
end subroutine omp_abs_f64

subroutine omp_atan_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_atan_f32(float {{.*}})
y = atan(x)
end subroutine omp_atan_f32

subroutine omp_atan_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_atan_f64(double {{.*}})
y = atan(x)
end subroutine omp_atan_f64

subroutine omp_atan2_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_atan2_f32(float {{.*}}, float {{.*}})
y = atan2(x, x)
end subroutine omp_atan2_f32

subroutine omp_atan2_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_atan2_f64(double {{.*}}, double {{.*}})
y = atan2(x ,x)
end subroutine omp_atan2_f64

subroutine omp_cos_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_cos_f32(float {{.*}})
y = cos(x)
end subroutine omp_cos_f32

subroutine omp_cos_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_cos_f64(double {{.*}})
y = cos(x)
end subroutine omp_cos_f64

subroutine omp_erf_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_erf_f32(float {{.*}})
y = erf(x)
end subroutine omp_erf_f32

subroutine omp_erf_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_erf_f64(double {{.*}})
y = erf(x)
end subroutine omp_erf_f64

subroutine omp_exp_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call contract float @llvm.exp.f32(float {{.*}})
y = exp(x)
end subroutine omp_exp_f32

subroutine omp_exp_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_exp_f64(double {{.*}})
y = exp(x)
end subroutine omp_exp_f64

subroutine omp_log_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call contract float @llvm.log.f32(float {{.*}})
y = log(x)
end subroutine omp_log_f32

subroutine omp_log_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_log_f64(double {{.*}})
y = log(x)
end subroutine omp_log_f64

subroutine omp_log10_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_log10_f32(float {{.*}})
y = log10(x)
end subroutine omp_log10_f32

subroutine omp_log10_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_log10_f64(double {{.*}})
y = log10(x)
end subroutine omp_log10_f64

subroutine omp_sqrt_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call contract float @llvm.sqrt.f32(float {{.*}})
y = sqrt(x)
end subroutine omp_sqrt_f32

subroutine omp_sqrt_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call contract double @llvm.sqrt.f64(double {{.*}})
y = sqrt(x)
end subroutine omp_sqrt_f64

subroutine omp_tan_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_tan_f32(float {{.*}})
y = tan(x)
end subroutine omp_tan_f32

subroutine omp_tan_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_tan_f64(double {{.*}})
y = tan(x)
end subroutine omp_tan_f64

subroutine omp_tanh_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_tanh_f32(float {{.*}})
y = tanh(x)
end subroutine omp_tanh_f32

subroutine omp_tanh_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_tanh_f64(double {{.*}})
y = tanh(x)
end subroutine omp_tanh_f64

0 comments on commit 4290e34

Please sign in to comment.