diff --git a/setup.py b/setup.py
index 58cfd6c1a..02636cf1f 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,4 @@
 import torch
-from torch.utils import cpp_extension
 from setuptools import setup, find_packages
 import subprocess
 import sys
@@ -7,51 +6,23 @@
 import warnings
 import os
 
+from torch.utils.hipify import hipify_python
+
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-
-    return raw_output, bare_metal_major, bare_metal_minor
-
-def check_if_rocm_pytorch():
-    is_rocm_pytorch = False
-    if torch.__version__ >= '1.5':
-        from torch.utils.cpp_extension import ROCM_HOME
-        is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
-
-    return is_rocm_pytorch
-
-IS_ROCM_PYTORCH = check_if_rocm_pytorch()
-
-if not torch.cuda.is_available() and not IS_ROCM_PYTORCH:
+if not torch.cuda.is_available():
     # https://github.com/NVIDIA/apex/issues/486
     # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
     # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
     print('\nWarning: Torch did not find available GPUs on this system.\n',
           'If your intention is to cross-compile, this is not an error.\n'
           'By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
-          'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
-          'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
+          'Volta (compute capability 7.0), and Turing (compute capability 7.5).\n'
           'If you wish to cross-compile for a single specific architecture,\n'
           'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
     if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-        if int(bare_metal_major) == 11:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-        else:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
-elif not torch.cuda.is_available() and IS_ROCM_PYTORCH:
-    print('\nWarning: Torch did not find available GPUs on this system.\n',
-          'If your intention is to cross-compile, this is not an error.\n'
-          'By default, Apex will cross-compile for the same gfx targets\n'
-          'used by default in ROCm PyTorch\n')
+        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
@@ -95,18 +66,13 @@ def check_if_rocm_pytorch():
         CppExtension('apex_C',
                      ['csrc/flatten_unflatten.cpp',]))
 
-def get_cuda_bare_metal_version(cuda_dir):
+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
     release = output[release_idx].split(".")
     bare_metal_major = release[0]
     bare_metal_minor = release[1][0]
-
-    return raw_output, bare_metal_major, bare_metal_minor
-
-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
 
     torch_binary_major = torch.version.cuda.split(".")[0]
     torch_binary_minor = torch.version.cuda.split(".")[1]
@@ -121,6 +87,14 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
                            "You can try commenting out this check (at your own risk).")
 
+def check_if_rocm_pytorch():
+    is_rocm_pytorch = False
+    if torch.__version__ >= '1.5':
+        from torch.utils.cpp_extension import ROCM_HOME
+        is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
+
+    return is_rocm_pytorch
+
 # Set up macros for forward/backward compatibility hack around
 # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
 # and
@@ -137,25 +111,6 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
 version_ge_1_5 = ['-DVERSION_GE_1_5']
 version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
 
-if "--distributed_adam" in sys.argv:
-    from torch.utils.cpp_extension import CUDAExtension
-    sys.argv.remove("--distributed_adam")
-
-    from torch.utils.cpp_extension import BuildExtension
-    cmdclass['build_ext'] = BuildExtension
-
-    if torch.utils.cpp_extension.CUDA_HOME is None:
-        raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
-    else:
-        ext_modules.append(
-            CUDAExtension(name='distributed_adam_cuda',
-                          sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
-                                   'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
-                          include_dirs=[os.path.join(this_dir, 'csrc')],
-                          extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
-                                              'nvcc':['-O3',
-                                                      '--use_fast_math'] + version_dependent_macros}))
-
 if "--distributed_lamb" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
     sys.argv.remove("--distributed_lamb")
@@ -179,56 +134,110 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
     from torch.utils.cpp_extension import CUDAExtension
     sys.argv.remove("--cuda_ext")
 
-    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
+    is_rocm_pytorch = False
+    if torch.__version__ >= '1.5':
+        from torch.utils.cpp_extension import ROCM_HOME
+        is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False
+
+    if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
         raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
-        if not IS_ROCM_PYTORCH:
+        if not is_rocm_pytorch:
             check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
 
-        print ("INFO: Building the multi-tensor apply extension.")
-        nvcc_args_multi_tensor = ['-lineinfo', '-O3', '--use_fast_math'] + version_dependent_macros
-        hipcc_args_multi_tensor = ['-O3'] + version_dependent_macros
-        ext_modules.append(
-            CUDAExtension(name='amp_C',
-                          sources=['csrc/amp_C_frontend.cpp',
-                                   'csrc/multi_tensor_sgd_kernel.cu',
-                                   'csrc/multi_tensor_scale_kernel.cu',
-                                   'csrc/multi_tensor_axpby_kernel.cu',
-                                   'csrc/multi_tensor_l2norm_kernel.cu',
-                                   'csrc/multi_tensor_lamb_stage_1.cu',
-                                   'csrc/multi_tensor_lamb_stage_2.cu',
-                                   'csrc/multi_tensor_adam.cu',
-                                   'csrc/multi_tensor_adagrad.cu',
-                                   'csrc/multi_tensor_novograd.cu',
-                                   'csrc/multi_tensor_lamb.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc': nvcc_args_multi_tensor if not IS_ROCM_PYTORCH else hipcc_args_multi_tensor}))
-
-        print ("INFO: Building syncbn extension.")
-        ext_modules.append(
-            CUDAExtension(name='syncbn',
-                          sources=['csrc/syncbn.cpp',
-                                   'csrc/welford.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-O3'] + version_dependent_macros}))
-
-        nvcc_args_layer_norm = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros
-        hipcc_args_layer_norm = ['-O3'] + version_dependent_macros
-        print ("INFO: Building fused layernorm extension.")
-        ext_modules.append(
-            CUDAExtension(name='fused_layer_norm_cuda',
-                          sources=['csrc/layer_norm_cuda.cpp',
-                                   'csrc/layer_norm_cuda_kernel.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc': nvcc_args_layer_norm if not IS_ROCM_PYTORCH else hipcc_args_layer_norm}))
-
-        print ("INFO: Building the MLP Extension.")
-        ext_modules.append(
-            CUDAExtension(name='mlp_cuda',
-                          sources=['csrc/mlp.cpp',
-                                   'csrc/mlp_cuda.cu'],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-O3'] + version_dependent_macros}))
+        if is_rocm_pytorch:
+            import shutil
+            with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx:
+                hipify_python.hipify(project_directory=this_dir, output_directory=this_dir, includes="csrc/*",
+                                     show_detailed=True, is_pytorch_extension=True, clean_ctx=clean_ctx)
+            shutil.copy("csrc/compat.h", "csrc/hip/compat.h")
+            shutil.copy("csrc/type_shim.h", "csrc/hip/type_shim.h")
+
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='amp_C',
+                              sources=['csrc/amp_C_frontend.cpp',
+                                       'csrc/multi_tensor_sgd_kernel.cu',
+                                       'csrc/multi_tensor_scale_kernel.cu',
+                                       'csrc/multi_tensor_axpby_kernel.cu',
+                                       'csrc/multi_tensor_l2norm_kernel.cu',
+                                       'csrc/multi_tensor_lamb_stage_1.cu',
+                                       'csrc/multi_tensor_lamb_stage_2.cu',
+                                       'csrc/multi_tensor_adam.cu',
+                                       'csrc/multi_tensor_adagrad.cu',
+                                       'csrc/multi_tensor_novograd.cu',
+                                       'csrc/multi_tensor_lamb.cu'],
+                              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                                  'nvcc':['-lineinfo',
+                                                          '-O3',
+                                                          # '--resource-usage',
+                                                          '--use_fast_math'] + version_dependent_macros}))
+        else:
+            print ("INFO: Building Multitensor apply extension")
+            ext_modules.append(
+                CUDAExtension(name='amp_C',
+                              sources=['csrc/amp_C_frontend.cpp',
+                                       'csrc/hip/multi_tensor_sgd_kernel.hip',
+                                       'csrc/hip/multi_tensor_scale_kernel.hip',
+                                       'csrc/hip/multi_tensor_axpby_kernel.hip',
+                                       'csrc/hip/multi_tensor_l2norm_kernel.hip',
+                                       'csrc/hip/multi_tensor_lamb_stage_1.hip',
+                                       'csrc/hip/multi_tensor_lamb_stage_2.hip',
+                                       'csrc/hip/multi_tensor_adam.hip',
+                                       'csrc/hip/multi_tensor_adagrad.hip',
+                                       'csrc/hip/multi_tensor_novograd.hip',
+                                       'csrc/hip/multi_tensor_lamb.hip'],
+                              extra_compile_args=['-O3'] + version_dependent_macros))
+
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='syncbn',
+                              sources=['csrc/syncbn.cpp',
+                                       'csrc/welford.cu'],
+                              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                                  'nvcc':['-O3'] + version_dependent_macros}))
+        else:
+            print ("INFO: Building syncbn extension.")
+            ext_modules.append(
+                CUDAExtension(name='syncbn',
+                              sources=['csrc/syncbn.cpp',
+                                       'csrc/hip/welford.hip'],
+                              extra_compile_args=['-O3'] + version_dependent_macros))
+
+
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='fused_layer_norm_cuda',
+                              sources=['csrc/layer_norm_cuda.cpp',
+                                       'csrc/layer_norm_cuda_kernel.cu'],
+                              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                                  'nvcc':['-maxrregcount=50',
+                                                          '-O3',
+                                                          '--use_fast_math'] + version_dependent_macros}))
+        else:
+            print ("INFO: Building FusedLayerNorm extension.")
+            ext_modules.append(
+                CUDAExtension(name='fused_layer_norm_cuda',
+                              sources=['csrc/layer_norm_cuda.cpp',
+                                       'csrc/hip/layer_norm_hip_kernel.hip'],
+                              extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
+                                                  'nvcc' : []}))
+
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='mlp_cuda',
+                              sources=['csrc/mlp.cpp',
+                                       'csrc/mlp_cuda.cu'],
+                              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                                  'nvcc':['-O3'] + version_dependent_macros}))
+        else:
+            print ("INFO: Building MLP extension")
+            ext_modules.append(
+                CUDAExtension(name='mlp_cuda',
+                              sources=['csrc/mlp.cpp',
+                                       'csrc/hip/mlp_hip.hip'],
+                              extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
+                                                  'nvcc' : []}))
 
 if "--bnp" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
@@ -260,18 +269,27 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
 
-    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
+    is_rocm_pytorch = check_if_rocm_pytorch()
+
+    if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
         raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
-        print ("INFO: Building the xentropy extension.")
-        ext_modules.append(
-            CUDAExtension(name='xentropy_cuda',
-                          sources=['apex/contrib/csrc/xentropy/interface.cpp',
-                                   'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
-                          include_dirs=[os.path.join(this_dir, 'csrc')],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc':['-O3'] + version_dependent_macros}))
-
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='xentropy_cuda',
+                              sources=['apex/contrib/csrc/xentropy/interface.cpp',
+                                       'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
+                              include_dirs=[os.path.join(this_dir, 'csrc')],
+                              extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
+                                                  'nvcc':['-O3'] + version_dependent_macros}))
+        else:
+            ext_modules.append(
+                CUDAExtension(name='xentropy_cuda',
+                              sources=['apex/contrib/csrc/xentropy/interface.cpp',
+                                       'apex/contrib/csrc/xentropy/hip/xentropy_kernel.hip'],
+                              include_dirs=[os.path.join(this_dir, 'csrc/hip')],
+                              extra_compile_args=['-O3'] + version_dependent_macros))
+
 
 if "--deprecated_fused_adam" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
@@ -280,19 +298,28 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
 
-    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
+    is_rocm_pytorch = check_if_rocm_pytorch()
+
+    if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
         raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
-        print ("INFO: Building deprecated fused adam extension.")
-        nvcc_args_fused_adam = ['-O3', '--use_fast_math'] + version_dependent_macros
-        hipcc_args_fused_adam = ['-O3'] + version_dependent_macros
-        ext_modules.append(
-            CUDAExtension(name='fused_adam_cuda',
-                          sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
-                                   'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
-                          include_dirs=[os.path.join(this_dir, 'csrc')],
-                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
-                                              'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='fused_adam_cuda',
+                              sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
+                                       'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
+                              include_dirs=[os.path.join(this_dir, 'csrc')],
+                              extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
+                                                  'nvcc':['-O3',
+                                                          '--use_fast_math'] + version_dependent_macros}))
+        else:
+            print ("INFO: Building deprecated fused adam.")
+            ext_modules.append(
+                CUDAExtension(name='fused_adam_cuda',
+                              sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
+                                       'apex/contrib/csrc/optimizers/hip/fused_adam_hip_kernel.hip'],
+                              include_dirs=[os.path.join(this_dir, 'csrc/hip')],
+                              extra_compile_args=['-O3'] + version_dependent_macros))
 
 if "--deprecated_fused_lamb" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
@@ -301,19 +328,30 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
     from torch.utils.cpp_extension import BuildExtension
     cmdclass['build_ext'] = BuildExtension
 
-    if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH:
+    is_rocm_pytorch = check_if_rocm_pytorch()
+
+    if torch.utils.cpp_extension.CUDA_HOME is None and (not is_rocm_pytorch):
         raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
-        print ("INFO: Building deprecated fused lamb extension.")
-        nvcc_args_fused_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros
-        hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros
-        ext_modules.append(
-            CUDAExtension(name='fused_lamb_cuda',
-                          sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
-                                   'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
-                                   'csrc/multi_tensor_l2norm_kernel.cu'],
-                          include_dirs=[os.path.join(this_dir, 'csrc')],
-                          extra_compile_args = nvcc_args_fused_lamb if not IS_ROCM_PYTORCH else hipcc_args_fused_lamb))
+        if not is_rocm_pytorch:
+            ext_modules.append(
+                CUDAExtension(name='fused_lamb_cuda',
+                              sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
+                                       'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
+                                       'csrc/multi_tensor_l2norm_kernel.cu'],
+                              include_dirs=[os.path.join(this_dir, 'csrc')],
+                              extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
+                                                  'nvcc':['-O3',
+                                                          '--use_fast_math'] + version_dependent_macros}))
+        else:
+            print ("INFO: Building deprecated fused lamb.")
+            ext_modules.append(
+                CUDAExtension(name='fused_lamb_cuda',
+                              sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
+                                       'apex/contrib/csrc/optimizers/hip/fused_lamb_hip_kernel.hip',
+                                       'csrc/hip/multi_tensor_l2norm_kernel.hip'],
+                              include_dirs=[os.path.join(this_dir, 'csrc/hip')],
+                              extra_compile_args=['-O3'] + version_dependent_macros))
 
 # Check, if ATen/CUDAGenerator.h is found, otherwise use the new ATen/CUDAGeneratorImpl.h, due to breaking change in https://github.com/pytorch/pytorch/pull/36026
 generator_flag = []
@@ -321,7 +359,6 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
 if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')):
     generator_flag = ['-DOLD_GENERATOR']
 
-
 if "--fast_multihead_attn" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
     sys.argv.remove("--fast_multihead_attn")
@@ -332,13 +369,6 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
     if torch.utils.cpp_extension.CUDA_HOME is None:
         raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
     else:
-        # Check, if CUDA11 is installed for compute capability 8.0
-        cc_flag = []
-        _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-        if int(bare_metal_major) >= 11:
-            cc_flag.append('-gencode')
-            cc_flag.append('arch=compute_80,code=sm_80')
-
         subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
         ext_modules.append(
             CUDAExtension(name='fast_additive_mask_softmax_dropout',
@@ -352,7 +382,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                                                       '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                       '--expt-relaxed-constexpr',
                                                       '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
         ext_modules.append(
             CUDAExtension(name='fast_mask_softmax_dropout',
                           sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
@@ -365,7 +395,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                                                       '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                       '--expt-relaxed-constexpr',
                                                       '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
         ext_modules.append(
             CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
                           sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
@@ -378,7 +408,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                                                       '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                       '--expt-relaxed-constexpr',
                                                       '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
         ext_modules.append(
             CUDAExtension(name='fast_self_multihead_attn_bias',
                           sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
@@ -391,7 +421,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                                                       '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                       '--expt-relaxed-constexpr',
                                                       '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
         ext_modules.append(
             CUDAExtension(name='fast_self_multihead_attn',
                           sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
@@ -404,7 +434,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                                                       '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                       '--expt-relaxed-constexpr',
                                                       '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
         ext_modules.append(
            CUDAExtension(name='fast_self_multihead_attn_norm_add',
                          sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
@@ -417,7 +447,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                                                       '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                       '--expt-relaxed-constexpr',
                                                       '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
         ext_modules.append(
             CUDAExtension(name='fast_encdec_multihead_attn',
                           sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
@@ -430,7 +460,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                                                       '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                       '--expt-relaxed-constexpr',
                                                       '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
         ext_modules.append(
             CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
                           sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
@@ -443,7 +473,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                                                       '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                       '--expt-relaxed-constexpr',
                                                       '--expt-extended-lambda',
-                                                      '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                                      '--use_fast_math'] + version_dependent_macros + generator_flag}))
 
 setup(
     name='apex',
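
For reference, a minimal sketch of the ROCm detection plus hipify flow that the --cuda_ext branch of this patch adds, assuming a ROCm build of PyTorch >= 1.5 and the csrc/ layout used above (everything here mirrors calls that appear verbatim in the diff; it is illustrative, not a drop-in replacement for setup.py):

    import os
    import shutil

    import torch
    from torch.utils.hipify import hipify_python


    def check_if_rocm_pytorch():
        # ROCm builds of PyTorch (>= 1.5) expose torch.version.hip and ROCM_HOME;
        # both must be present for the build to take the HIP path.
        if torch.__version__ >= '1.5':
            from torch.utils.cpp_extension import ROCM_HOME
            return (torch.version.hip is not None) and (ROCM_HOME is not None)
        return False


    this_dir = os.path.dirname(os.path.abspath(__file__))

    if check_if_rocm_pytorch():
        # Translate the CUDA sources under csrc/ into HIP sources under csrc/hip/.
        # keep_intermediates=True leaves the generated .hip files on disk so the
        # CUDAExtension blocks in the patch can list them as sources.
        with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx:
            hipify_python.hipify(project_directory=this_dir, output_directory=this_dir,
                                 includes="csrc/*", show_detailed=True,
                                 is_pytorch_extension=True, clean_ctx=clean_ctx)
        # Plain headers are not rewritten by hipify, so copy them next to the
        # generated kernels where the .hip files expect to include them.
        shutil.copy("csrc/compat.h", "csrc/hip/compat.h")
        shutil.copy("csrc/type_shim.h", "csrc/hip/type_shim.h")

On a CUDA build of PyTorch none of this runs: the extensions compile from the original .cu files, as in the "if not is_rocm_pytorch:" branches of the patch. Running hipify at build time keeps a single CUDA source tree as the source of truth rather than maintaining parallel CUDA and HIP kernels.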