diff --git a/reinforced_lib/logs/plots_logger.py b/reinforced_lib/logs/plots_logger.py index 31d4f3b..f15817d 100644 --- a/reinforced_lib/logs/plots_logger.py +++ b/reinforced_lib/logs/plots_logger.py @@ -3,7 +3,7 @@ import jax.numpy as jnp import matplotlib.pyplot as plt -from chex import Array, Scalar +from chex import Array, Scalar, Numeric from reinforced_lib.logs import BaseLogger, Source from reinforced_lib.utils import timestamp @@ -32,12 +32,13 @@ class PlotsLogger(BaseLogger): """ def __init__( - self, - plots_dir: str = None, - plots_ext: str = 'pdf', - plots_smoothing: Scalar = 0.6, - plots_scatter: bool = False, - **kwargs + self, + plots_dir: str = None, + plots_ext: str = 'pdf', + plots_smoothing: Scalar = 0.6, + plots_scatter: bool = False, + plots_sync_steps: bool = False, + **kwargs ) -> None: assert 1 > plots_smoothing >= 0 @@ -47,6 +48,7 @@ def __init__( self._ext = plots_ext self._smoothing = plots_smoothing self._scatter = plots_scatter + self._sync_steps = plots_sync_steps self._current_values = set() self._step = 0 @@ -68,38 +70,36 @@ def exponential_moving_average(values: list, weight: Scalar) -> list: return smoothed - def lineplot(values: list, alpha: Scalar = 1.0, label: bool = False) -> None: + def lineplot(values: list, steps: list, alpha: Scalar = 1.0, label: bool = False) -> None: values = jnp.array(values) values = jnp.squeeze(values) if values.ndim == 1: - plt.plot(values, alpha=alpha, c='C0') + plt.plot(steps, values, alpha=alpha, c='C0') elif values.ndim == 2: for i, val in enumerate(jnp.array(values).T): - plt.plot(val, alpha=alpha, c=f'C{i % 10}', label=i if label else '') + plt.plot(steps, val, alpha=alpha, c=f'C{i % 10}', label=i if label else '') plt.legend() - def scatterplot(values: list, label: bool = False) -> None: - values = jnp.array(values) - values = jnp.squeeze(values) - xs = range(1, len(values) + 1) + def scatterplot(values: list, steps: list, label: bool = False) -> None: + values = jnp.array(values).squeeze() if values.ndim == 1: - plt.scatter(xs, values, c='C0', marker='.', s=4) + plt.scatter(steps, values, c='C0', marker='.', s=4) elif values.ndim == 2: for i, val in enumerate(jnp.array(values).T): - plt.scatter(xs, val, c=f'C{i % 10}', label=i if label else '', marker='.', s=4) + plt.scatter(steps, val, c=f'C{i % 10}', label=i if label else '', marker='.', s=4) plt.legend() for name, values in self._values.items(): filename = f'rlib-plot-{name}-{timestamp()}.{self._ext}' if self._scatter: - scatterplot(values, True) + scatterplot(values, self._steps[name], True) else: smoothed = exponential_moving_average(values, self._smoothing) - lineplot(values, alpha=0.3) - lineplot(smoothed, label=True) + lineplot(values, self._steps[name], alpha=0.3) + lineplot(smoothed, self._steps[name], label=True) plt.title(name) plt.xlabel('step') @@ -118,7 +118,7 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None: Scalar to log. """ - self._values[self.source_to_name(source)].append(value) + self._log(source, value) def log_array(self, source: Source, value: Array, *_) -> None: """ @@ -132,4 +132,31 @@ def log_array(self, source: Source, value: Array, *_) -> None: Array to log. """ - self._values[self.source_to_name(source)].append(value) + self._log(source, value) + + def _log(self, source: Source, value: Numeric) -> None: + """ + Adds a given scalar to the plot values. + + Parameters + ---------- + source : Source + Source of the logged value. + value : Numeric + Value to log. + """ + + name = self.source_to_name(source) + + if self._sync_steps: + if name in self._current_values: + self._step += 1 + self._current_values.clear() + + self._current_values.add(name) + step = self._step + else: + step = self._steps[name][-1] + 1 if self._steps[name] else 0 + + self._values[name].append(value) + self._steps[name].append(step) diff --git a/reinforced_lib/logs/tb_logger.py b/reinforced_lib/logs/tb_logger.py index ce29c28..bc3fc86 100644 --- a/reinforced_lib/logs/tb_logger.py +++ b/reinforced_lib/logs/tb_logger.py @@ -28,6 +28,7 @@ def __init__( self, tb_log_dir: str = None, tb_comet_config: dict[str, any] = None, + tb_sync_steps: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -35,6 +36,10 @@ def __init__( if tb_comet_config is None: tb_comet_config = {'disabled': True} + self._sync_steps = tb_sync_steps + self._current_values = set() + self._step = 0 + self._writer = SummaryWriter(log_dir=tb_log_dir, comet_config=tb_comet_config) self._steps = defaultdict(int) @@ -57,8 +62,9 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None: Scalar to log. """ - self._writer.add_scalar(self.source_to_name(source), value, self._steps[source]) - self._steps[source] += 1 + name = self.source_to_name(source) + step = self._get_step(name) + self._writer.add_scalar(name, value, step) def log_array(self, source: Source, value: Array, *_) -> None: """ @@ -72,8 +78,9 @@ def log_array(self, source: Source, value: Array, *_) -> None: Array to log. """ - self._writer.add_histogram(self.source_to_name(source), value, self._steps[source]) - self._steps[source] += 1 + name = self.source_to_name(source) + step = self._get_step(name) + self._writer.add_histogram(name, value, step) def log_dict(self, source: Source, value: dict, *_) -> None: """ @@ -101,5 +108,34 @@ def log_other(self, source: Source, value: any, *_) -> None: Dictionary to log. """ - self._writer.add_text(self.source_to_name(source), json.dumps(value), self._steps[source]) - self._steps[source] += 1 + name = self.source_to_name(source) + step = self._get_step(name) + self._writer.add_text(name, json.dumps(value), step) + + def _get_step(self, name: str) -> int: + """ + Returns the current step for a given source. + + Parameters + ---------- + name : str + Name of the source. + + Returns + ------- + int + Current step for the given source. + """ + + if self._sync_steps: + if name in self._current_values: + self._step += 1 + self._current_values.clear() + + self._current_values.add(name) + step = self._step + else: + step = self._steps[name] + 1 + + self._steps[name] = step + return step