diff --git a/tests/test_test_example.py b/tests/test_test_example.py index fbcfa709..0e6ad8e2 100644 --- a/tests/test_test_example.py +++ b/tests/test_test_example.py @@ -2,7 +2,7 @@ import unittest import torch -from zeta import MultiheadAttention +from zeta.nn.attention import MultiheadAttention class TestMultiheadAttention(unittest.TestCase):