Skip to content

Commit

Permalink
Update to tqdm.autonotebook, improve example
Browse files Browse the repository at this point in the history
  • Loading branch information
Dref360 committed Jun 6, 2024
1 parent 46d935b commit d6d6512
Show file tree
Hide file tree
Showing 9 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion baal/active/active_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from . import heuristics
from .dataset import ActiveLearningDataset

log = structlog.get_logger(__name__)
log = structlog.get_logger('baal')
pjoin = os.path.join


Expand Down
2 changes: 1 addition & 1 deletion baal/active/heuristics/stochastics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from baal.active.heuristics import AbstractHeuristic, Sequence

log = structlog.get_logger(__name__)
log = structlog.get_logger('baal')
EPSILON = 1e-8


Expand Down
2 changes: 1 addition & 1 deletion baal/calibration/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from baal.modelwrapper import TrainingArgs
from baal.utils.metrics import ECE, ECE_PerCLs

log = structlog.get_logger("Calibrating...")
log = structlog.get_logger("baal")


class DirichletCalibrator(object):
Expand Down
9 changes: 6 additions & 3 deletions baal/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import structlog
from torch.utils.data import Subset
from tqdm import tqdm
from tqdm.autonotebook import tqdm

from baal import ModelWrapper, ActiveLearningDataset
from baal.active.dataset.base import Dataset
Expand All @@ -24,7 +24,7 @@
TransformersAdapter = None # type: ignore
TRANSFORMERS_AVAILABLE = False

log = structlog.get_logger(__name__)
log = structlog.get_logger('baal')


class ActiveLearningExperiment:
Expand Down Expand Up @@ -78,7 +78,10 @@ def start(self):
"No item labelled in the training set."
" Did you run `ActiveLearningDataset.label_randomly`?"
)
for _ in tqdm(itertools.count(start=0)):
for _ in tqdm(itertools.count(start=0), # Infinite counter to rely on Criterion
desc="Active Experiment",
# Upper bound estimation.
total=np.round(self.al_dataset.n_unlabelled // self.query_size)):
self.adapter.reset_weights()
train_metrics = self.adapter.train(self.al_dataset)
eval_metrics = self.adapter.evaluate(
Expand Down
4 changes: 2 additions & 2 deletions baal/modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm
from tqdm.autonotebook import tqdm

from baal.active.dataset.base import Dataset
from baal.metrics.mixin import MetricMixin
Expand All @@ -23,7 +23,7 @@
from baal.utils.metrics import Loss
from baal.utils.warnings import raise_warnings_cache_replicated

log = structlog.get_logger("ModelWrapper")
log = structlog.get_logger("baal")


def _stack_preds(out):
Expand Down
2 changes: 1 addition & 1 deletion baal/transformers_trainer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from baal.utils.array_utils import stack_in_memory
from baal.utils.iterutils import map_on_tensor

log = structlog.get_logger("BaalTransformersTrainer")
log = structlog.get_logger("baal")


def _stack_preds(out):
Expand Down
2 changes: 1 addition & 1 deletion baal/utils/pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from baal.utils.cuda_utils import to_cuda
from baal.utils.iterutils import map_on_tensor

log = structlog.get_logger("PL testing")
log = structlog.get_logger("baal")

warnings.warn(
"baal.utils.pytorch_lightning is deprecated. BaaL is now integrated into Lightning Flash!"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
BaaLDataModule,
)

log = structlog.get_logger("PL testing")
log = structlog.get_logger("baal")

warnings.warn(
"baal.utils.pytorch_lightning is deprecated. BaaL is now integrated into Lightning Flash!"
Expand Down
2 changes: 1 addition & 1 deletion experiments/pytorch_lightning/lightning_flash_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from baal.active import get_heuristic

log = structlog.get_logger()
log = structlog.get_logger('baal')

IMG_SIZE = 128

Expand Down

0 comments on commit d6d6512

Please sign in to comment.