diff --git a/core/metric_mixin.py b/core/metric_mixin.py
index a716ca7..b9a133d 100644
--- a/core/metric_mixin.py
+++ b/core/metric_mixin.py
@@ -34,17 +34,17 @@ def transform(self, outputs):
 
 
 class MetricMixin:
-  @abstractmethod
-  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict:
-    ...
-
-  def update(self, outputs: Dict[str, torch.Tensor]):
-    results = self.transform(outputs)
-    # Do not try to update if any tensor is empty as a result of stratification.
-    for value in results.values():
-      if torch.is_tensor(value) and not value.nelement():
-        return
-    super().update(**results)
+  @abstractmethod
+  def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    ...
+
+  def update(self, outputs: Dict[str, torch.Tensor]):
+    results = self.transform(outputs)
+    # Do not try to update if any tensor is empty as a result of stratification.
+    if any((torch.is_tensor(value) and not value.nelement()) for value in results.values()):
+      return
+    super().update(**results)
+
 
 
 class TaskMixin:
diff --git a/ml_logging/torch_logging.py b/ml_logging/torch_logging.py
index e791c46..fedabcb 100644
--- a/ml_logging/torch_logging.py
+++ b/ml_logging/torch_logging.py
@@ -17,45 +17,45 @@
 import torch.distributed as dist
 
 
 def rank_specific(logger):
-  """Ensures that we only override a given logger once."""
-  if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
-    return logger
-
-  def _if_rank(logger_method, limit: Optional[int] = None):
-    if limit:
-      # If we are limiting redundant logs, wrap logging call with a cache
-      # to not execute if already cached.
-      def _wrap(_call):
-        @functools.lru_cache(limit)
-        def _logger_method(*args, **kwargs):
-          _call(*args, **kwargs)
-
-        return _logger_method
-
-      logger_method = _wrap(logger_method)
-
-    def _inner(msg, *args, rank: int = 0, **kwargs):
-      if not dist.is_initialized():
-        logger_method(msg, *args, **kwargs)
-      elif dist.get_rank() == rank:
-        logger_method(msg, *args, **kwargs)
-      elif rank < 0:
-        logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs)
-
-    # Register this stack frame with absl logging so that it doesn't trample logging lines.
-    absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)
-
-    return _inner
-
-  logger.fatal = _if_rank(logger.fatal)
-  logger.error = _if_rank(logger.error)
-  logger.warning = _if_rank(logger.warning, limit=1)
-  logger.info = _if_rank(logger.info)
-  logger.debug = _if_rank(logger.debug)
-  logger.exception = _if_rank(logger.exception)
-
-  logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True
+  """Ensures that we only override a given logger once."""
+  if hasattr(logger, "_ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC"):
+    return logger
+
+  def _if_rank(logger_method, limit: Optional[int] = None):
+    """Wraps logger_method so that it only executes on the requested rank."""
+    if limit:
+      # If we are limiting redundant logs, wrap the logging call with a cache
+      # so it does not execute again if the message is already cached.
+      original_method = logger_method
+
+      @functools.lru_cache(limit)
+      def _cached_logger_method(*args, **kwargs):
+        original_method(*args, **kwargs)
+
+      logger_method = _cached_logger_method
+
+    def _inner(msg, *args, rank: int = 0, **kwargs):
+      """Logs only on the requested rank; rank < 0 logs on every rank with a rank prefix."""
+      if not dist.is_initialized() or dist.get_rank() == rank:
+        logger_method(msg, *args, **kwargs)
+      elif rank < 0:
+        logger_method(f"Rank{dist.get_rank()}: {msg}", *args, **kwargs)
+
+    # Register this stack frame with absl logging so that it doesn't trample logging lines.
+    absl_logging.ABSLLogger.register_frame_to_skip(__file__, _inner.__name__)
+
+    return _inner
+
+  logger.fatal = _if_rank(logger.fatal)
+  logger.error = _if_rank(logger.error)
+  logger.warning = _if_rank(logger.warning, limit=1)
+  logger.info = _if_rank(logger.info)
+  logger.debug = _if_rank(logger.debug)
+  logger.exception = _if_rank(logger.exception)
+
+  logger._ALREADY_OVERWRITTEN_TO_BE_RANK_SPECIFIC = True
+  return logger
 
 
 rank_specific(logging)
diff --git a/optimizers/config.py b/optimizers/config.py
index f5011f0..a1df58e 100644
--- a/optimizers/config.py
+++ b/optimizers/config.py
@@ -72,11 +72,7 @@ class OptimizerConfig(base_config.BaseConfig):
 
 
 def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig):
-  if optimizer_config.adam is not None:
-    return optimizer_config.adam
-  elif optimizer_config.sgd is not None:
-    return optimizer_config.sgd
-  elif optimizer_config.adagrad is not None:
-    return optimizer_config.adagrad
-  else:
-    raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")
+  for optz in (optimizer_config.adam, optimizer_config.sgd, optimizer_config.adagrad):
+    if optz is not None:
+      return optz
+  raise ValueError(f"No optimizer selected in optimizer_config, passed {optimizer_config}")
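Reviewer note, not part of the patch: a minimal usage sketch of the rank-aware logger after this change. The import path tml.ml_logging.torch_logging is an assumption for illustration; the module applies rank_specific(logging) at import time, as shown above.

    from tml.ml_logging.torch_logging import logging  # assumed import path

    # With torch.distributed initialized, this logs only on rank 0 (the default rank=0).
    logging.info("Initialized embedding tables.")

    # rank=-1 logs on every rank, prefixed with the originating rank, e.g. "Rank3: ...".
    logging.info("Finished local shard.", rank=-1)

    # warning() is wrapped with functools.lru_cache(1), so an identical repeated
    # warning is suppressed while it remains in the cache.
    logging.warning("Empty batch encountered.")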