From 8c6392016e955a4dfa2bff45a899f395d92ecd9e Mon Sep 17 00:00:00 2001
From: Parsiad Azimzadeh
Date: Fri, 23 Aug 2024 15:00:20 -0400
Subject: [PATCH] Skip test if import missing

---
 tests/test_func.py | 5 ++++-
 tests/test_nn.py   | 4 +++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/tests/test_func.py b/tests/test_func.py
index 0e3cdff..e420670 100644
--- a/tests/test_func.py
+++ b/tests/test_func.py
@@ -1,10 +1,13 @@
 import numpy as np
-import scipy.special
+import pytest
 
 import micrograd_pp as mpp
 
 
+@pytest.mark.skipif(not pytest.importorskip("scipy.special"), reason="Unable to import scipy.special")
 def test_softmax() -> None:
+    import scipy.special
+
     a = np.random.randn(5, 4, 3)
     actual = mpp.softmax(mpp.Constant(a), dim=1).value
     desired = scipy.special.softmax(a, axis=1)
diff --git a/tests/test_nn.py b/tests/test_nn.py
index adb5ef9..8aed0b3 100644
--- a/tests/test_nn.py
+++ b/tests/test_nn.py
@@ -2,7 +2,6 @@
 import numpy as np
 import pytest
-import torch
 
 import micrograd_pp as mpp
 
@@ -88,7 +87,10 @@ def test_layer_norm() -> None:
 
 @pytest.mark.parametrize("is_causal", (False, True))
+@pytest.mark.skipif(not pytest.importorskip("torch"), reason="Unable to import torch")
 def test_multihead_attention(is_causal: bool) -> None:
     # Test against PyTorch implementation
+    import torch
+
     torch_attn_mask = None
     mpp_attn_mask = None
     if is_causal:
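
For reference, a minimal sketch of the module-level variant of the skip-on-missing-import pattern this patch applies per test. Calling pytest.importorskip at module scope skips every test in the file when the dependency cannot be imported, whereas the patch scopes the skip to individual tests with @pytest.mark.skipif and defers the real import into the test body. The test name and body below are hypothetical illustrations, not part of the patch.

import numpy as np
import pytest

# Skips all tests in this module at collection time if SciPy is missing;
# when the import succeeds, importorskip returns the imported module.
scipy_special = pytest.importorskip("scipy.special")


def test_softmax_rows_sum_to_one() -> None:  # hypothetical example test
    a = np.random.randn(5, 4, 3)
    out = scipy_special.softmax(a, axis=1)
    # Softmax along axis 1 should produce distributions that sum to 1.
    np.testing.assert_allclose(out.sum(axis=1), np.ones((5, 3)))

The per-test marker used in the patch keeps the rest of the module runnable when only one test's dependency is absent, at the cost of repeating the guard on each affected test.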