From 99af5c081d4810932195a69b2a9a55c338987f16 Mon Sep 17 00:00:00 2001 From: "Auriane R." <48684432+aurianer@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:01:36 +0200 Subject: [PATCH] Allow to pass architectures like 90a, without being overriden (#1178) * Allow to pass architectures like 90a, without being overriden Signed-off-by: aurianer * Review suggestion from @timmoon10 Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: aurianer Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- build_tools/pytorch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 3725e58c87..7aff2db324 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -80,10 +80,10 @@ def setup_pytorch_extension( ) ) - if "80" in cuda_architectures: - nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) - if "90" in cuda_architectures: - nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + for arch in cuda_architectures.split(";"): + if arch == "70": + continue # Already handled + nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"]) # Libraries library_dirs = []