diff --git a/baal/active/heuristics/heuristics_gpu.py b/baal/active/heuristics/heuristics_gpu.py index a43d43d..1a5be07 100644 --- a/baal/active/heuristics/heuristics_gpu.py +++ b/baal/active/heuristics/heuristics_gpu.py @@ -106,7 +106,9 @@ def predict_on_dataset( verbose=True, ): return ( - super().predict_on_dataset(dataset, iterations, half, verbose).reshape([-1]) # type: ignore + super() + .predict_on_dataset(dataset, iterations, half, verbose) + .reshape([-1]) # type: ignore ) def predict_on_batch(self, data, iterations=1): diff --git a/baal/experiments/base.py b/baal/experiments/base.py index 430c672..b96549d 100644 --- a/baal/experiments/base.py +++ b/baal/experiments/base.py @@ -20,8 +20,8 @@ TRANSFORMERS_AVAILABLE = True except ImportError: - BaalTransformersTrainer = Any # type: ignore - TransformersAdapter = Any # type: ignore + BaalTransformersTrainer = None # type: ignore + TransformersAdapter = None # type: ignore TRANSFORMERS_AVAILABLE = False log = structlog.get_logger(__name__)