diff --git a/python/tests/TestTorchForce.py b/python/tests/TestTorchForce.py index c8bcd23..6eefed6 100644 --- a/python/tests/TestTorchForce.py +++ b/python/tests/TestTorchForce.py @@ -87,9 +87,7 @@ def forward(self, positions): assert self.positions.device == self.device assert positions.device == self.device assert positions.dtype == self.dtype - print(positions) - print(self.positions) - assert pt.all(positions == self.positions) + assert pt.allclose(positions, self.positions) return pt.sum(positions) with NamedTemporaryFile() as fd: