From 164f3f8d3feb213db700c72a8b8c9ee5ff2dc6e7 Mon Sep 17 00:00:00 2001 From: YoshitakaMo Date: Tue, 23 Jan 2024 12:00:42 +0900 Subject: [PATCH] suppress warnings for jax initialize backend --- colabfold/batch.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/colabfold/batch.py b/colabfold/batch.py index b279d616..a1fb6593 100644 --- a/colabfold/batch.py +++ b/colabfold/batch.py @@ -75,7 +75,11 @@ logger = logging.getLogger(__name__) import jax import jax.numpy as jnp -logging.getLogger('jax._src.lib.xla_bridge').addFilter(lambda _: False) + +# from jax 0.4.6, jax._src.lib.xla_bridge moved to jax._src.xla_bridge +# suppress warnings: Unable to initialize backend 'rocm' or 'tpu' +logging.getLogger('jax._src.xla_bridge').addFilter(lambda _: False) # before jax 0.4.5 +logging.getLogger('jax._src.lib.xla_bridge').addFilter(lambda _: False) # from jax 0.4.6 def mk_mock_template( query_sequence: Union[List[str], str], num_temp: int = 1