diff --git a/.github/scripts/utils_pip.bash b/.github/scripts/utils_pip.bash index 59c9a878f2..8522fa7476 100644 --- a/.github/scripts/utils_pip.bash +++ b/.github/scripts/utils_pip.bash @@ -166,11 +166,39 @@ __prepare_pip_arguments () { __export_pip_arguments "$([ "$package_variant_type_version" != "" ] && echo "true" || echo "false")" } +__check_package_variant () { + # shellcheck disable=SC2155 + local env_prefix=$(env_name_or_prefix "${env_name}") + + # Check applies to installation of packages with variants, and only to non-CPU variants + if [ "$package_variant_type_version" != "" ] && [ "$package_variant_type" != "cpu" ]; then + # Ensure that the package build is of the correct variant + # This test usually applies to the nightly builds + # shellcheck disable=SC2086 + if conda run ${env_prefix} pip list | grep "${package_name_raw}" | grep "${package_variant}"; then + local check_passed=1 + elif conda run ${env_prefix} pip list | grep "${package_name}" | grep "${package_variant}"; then + local check_passed=1 + else + local check_passed=0 + fi + + if [ $check_passed -eq 1 ]; then + echo "[CHECK] The installed package [${package_name}, ${package_channel}/${package_version:-LATEST}] is the correct variant (${package_variant})" + return 0 + else + echo "[CHECK] The installed package [${package_name}, ${package_channel}/${package_version:-LATEST}] appears to be an incorrect variant as it is missing references to ${package_variant}!" + echo "[CHECK] This can happen if the variant of the package (e.g. GPU, nightly) for the MAJOR.MINOR version of CUDA or ROCm presently installed on the system is not available." + return 1 + fi + fi +} + install_from_pytorch_pip () { - local env_name="$1" - local package_name_raw="$2" - local package_channel_version="$3" - local package_variant_type_version="$4" + env_name="$1" + package_name_raw="$2" + package_channel_version="$3" + package_variant_type_version="$4" if [ "$package_channel_version" == "" ]; then echo "Usage: ${FUNCNAME[0]} ENV_NAME PACKAGE_NAME PACKAGE_CHANNEL[/VERSION] [PACKAGE_VARIANT_TYPE[/VARIANT_VERSION]]" echo "Example(s):" @@ -203,22 +231,10 @@ install_from_pytorch_pip () { # shellcheck disable=SC2086 (exec_with_retries 3 conda run ${env_prefix} pip install ${pip_package} --index-url ${pip_channel}) || return 1 - # Check applies to installation of packages with variants, and only to non-CPU variants - if [ "$package_variant_type_version" != "" ] && [ "$package_variant_type" != "cpu" ]; then - # Ensure that the package build is of the correct variant - # This test usually applies to the nightly builds - # shellcheck disable=SC2086 - if conda run ${env_prefix} pip list | grep "${package_name}" | grep "${package_variant}"; then - echo "[CHECK] The installed package [${package_name}, ${package_channel}/${package_version:-LATEST}] is the correct variant (${package_variant})" - else - echo "[CHECK] The installed package [${package_name}, ${package_channel}/${package_version:-LATEST}] appears to be an incorrect variant as it is missing references to ${package_variant}!" - echo "[CHECK] This can happen if the variant of the package (e.g. GPU, nightly) for the MAJOR.MINOR version of CUDA or ROCm presently installed on the system is not available." - return 1 - fi - fi + # Ensure that the correct package variant has been installed + __check_package_variant || return 1 } - ################################################################################ # PyTorch PIP Download Functions ################################################################################ diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/gqa_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/gqa_test.py index 32d96fad23..98659460ba 100755 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/gqa_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/gqa_test.py @@ -120,8 +120,8 @@ def gqa_reference( class Int4GQATest(unittest.TestCase): @unittest.skipIf( - not torch.version.cuda, - "Skip when CUDA is not available", + not torch.version.cuda or torch.cuda.get_device_capability()[0] < 8, + "Skip when CUDA is not available or CUDA compute capability is less than 8", ) @settings(verbosity=VERBOSITY, max_examples=40, deadline=None) # pyre-ignore