Skip to content

Commit

Permalink
Mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
Dref360 committed Jun 6, 2024
1 parent e0fdd49 commit 466fe61
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 20 deletions.
4 changes: 2 additions & 2 deletions baal/active/heuristics/heuristics_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def predict_on_dataset(
verbose=True,
):
return (
super().predict_on_dataset(dataset, iterations, half, verbose).reshape([-1])
) # type: ignore
super().predict_on_dataset(dataset, iterations, half, verbose).reshape([-1]) # type: ignore
)

def predict_on_batch(self, data, iterations=1):
"""Rank the predictions according to their uncertainties."""
Expand Down
27 changes: 14 additions & 13 deletions baal/experiments/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Union, Optional
from typing import Union, Optional, Any

import numpy as np
import structlog
Expand All @@ -20,8 +20,8 @@

TRANSFORMERS_AVAILABLE = True
except ImportError:
BaalTransformersTrainer = None
TransformersAdapter = None
BaalTransformersTrainer = Any # type: ignore
TransformersAdapter = Any # type: ignore
TRANSFORMERS_AVAILABLE = False

log = structlog.get_logger(__name__)
Expand All @@ -47,16 +47,17 @@ class ActiveLearningExperiment:
pool_size: Optionally limit the size of the unlabelled pool.
criterion: Stopping criterion for the experiment.
"""

def __init__(
self,
trainer: Union[ModelWrapper, "BaalTransformersTrainer"],
al_dataset: ActiveLearningDataset,
eval_dataset: Dataset,
heuristic: AbstractHeuristic,
query_size: int = 100,
iterations: int = 20,
pool_size: Optional[int] = None,
criterion: Optional[StoppingCriterion] = None,
self,
trainer: Union[ModelWrapper, "BaalTransformersTrainer"],
al_dataset: ActiveLearningDataset,
eval_dataset: Dataset,
heuristic: AbstractHeuristic,
query_size: int = 100,
iterations: int = 20,
pool_size: Optional[int] = None,
criterion: Optional[StoppingCriterion] = None,
):
self.al_dataset = al_dataset
self.eval_dataset = eval_dataset
Expand Down Expand Up @@ -89,7 +90,7 @@ def start(self):
return records

def _get_adapter(
self, trainer: Union[ModelWrapper, "BaalTransformersTrainer"]
self, trainer: Union[ModelWrapper, "BaalTransformersTrainer"]
) -> FrameworkAdapter:
if isinstance(trainer, ModelWrapper):
return ModelWrapperAdapter(trainer)
Expand Down
2 changes: 1 addition & 1 deletion baal/experiments/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def reset_weights(self):

def train(self, al_dataset: Dataset) -> Dict[str, float]:
self.wrapper.train_dataset = al_dataset
return self.wrapper.train().metrics
return self.wrapper.train().metrics # type: ignore

def predict(self, dataset: Dataset, iterations: int) -> Union[NDArray, List[NDArray]]:
return self.wrapper.predict_on_dataset(dataset, iterations=iterations)
Expand Down
6 changes: 2 additions & 4 deletions baal/utils/pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx: Optional[int] = None):
# Get the input only.
x, _ = batch
# Perform Monte-Carlo inference for I iterations.
out = mc_inference(
self, x, self.hparams.iterations, self.hparams.replicate_in_memory # type: ignore
)
out = mc_inference(self, x, self.hparams.iterations, self.hparams.replicate_in_memory)
return out


Expand Down Expand Up @@ -185,7 +183,7 @@ def step(self, model=None, datamodule: Optional[BaaLDataModule] = None) -> bool:
"""
# High to low
if datamodule is None:
pool_dataloader = self.lightning_module.pool_dataloader() # type: ignore
pool_dataloader = self.lightning_module.pool_dataloader()
else:
pool_dataloader = datamodule.pool_dataloader()
model = model if model is not None else self.lightning_module
Expand Down

0 comments on commit 466fe61

Please sign in to comment.