diff --git a/src/para_attn/para_attn_interface.py b/src/para_attn/para_attn_interface.py index 5b53952..c1d3197 100644 --- a/src/para_attn/para_attn_interface.py +++ b/src/para_attn/para_attn_interface.py @@ -119,6 +119,7 @@ def forward( mesh, ): assert _templated_ring_attention is not None, "RingAttnFunc requires a newer version of PyTorch" + assert torch_ring_attention is not None, "RingAttnFunc requires a newer version of PyTorch" with unittest.mock.patch.object( torch_ring_attention,