Skip to content

Commit

Permalink
Add step synchronization in PlotsLogger and TensorboardLogger
Browse files Browse the repository at this point in the history
  • Loading branch information
m-wojnar committed Jul 18, 2023
1 parent f6e4ac1 commit 6a38384
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 27 deletions.
69 changes: 48 additions & 21 deletions reinforced_lib/logs/plots_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax.numpy as jnp
import matplotlib.pyplot as plt
from chex import Array, Scalar
from chex import Array, Scalar, Numeric

from reinforced_lib.logs import BaseLogger, Source
from reinforced_lib.utils import timestamp
Expand Down Expand Up @@ -32,12 +32,13 @@ class PlotsLogger(BaseLogger):
"""

def __init__(
self,
plots_dir: str = None,
plots_ext: str = 'pdf',
plots_smoothing: Scalar = 0.6,
plots_scatter: bool = False,
**kwargs
self,
plots_dir: str = None,
plots_ext: str = 'pdf',
plots_smoothing: Scalar = 0.6,
plots_scatter: bool = False,
plots_sync_steps: bool = False,
**kwargs
) -> None:
assert 1 > plots_smoothing >= 0

Expand All @@ -47,6 +48,7 @@ def __init__(
self._ext = plots_ext
self._smoothing = plots_smoothing
self._scatter = plots_scatter
self._sync_steps = plots_sync_steps

self._current_values = set()
self._step = 0
Expand All @@ -68,38 +70,36 @@ def exponential_moving_average(values: list, weight: Scalar) -> list:

return smoothed

def lineplot(values: list, alpha: Scalar = 1.0, label: bool = False) -> None:
def lineplot(values: list, steps: list, alpha: Scalar = 1.0, label: bool = False) -> None:
values = jnp.array(values)
values = jnp.squeeze(values)

if values.ndim == 1:
plt.plot(values, alpha=alpha, c='C0')
plt.plot(steps, values, alpha=alpha, c='C0')
elif values.ndim == 2:
for i, val in enumerate(jnp.array(values).T):
plt.plot(val, alpha=alpha, c=f'C{i % 10}', label=i if label else '')
plt.plot(steps, val, alpha=alpha, c=f'C{i % 10}', label=i if label else '')
plt.legend()

def scatterplot(values: list, label: bool = False) -> None:
values = jnp.array(values)
values = jnp.squeeze(values)
xs = range(1, len(values) + 1)
def scatterplot(values: list, steps: list, label: bool = False) -> None:
values = jnp.array(values).squeeze()

if values.ndim == 1:
plt.scatter(xs, values, c='C0', marker='.', s=4)
plt.scatter(steps, values, c='C0', marker='.', s=4)
elif values.ndim == 2:
for i, val in enumerate(jnp.array(values).T):
plt.scatter(xs, val, c=f'C{i % 10}', label=i if label else '', marker='.', s=4)
plt.scatter(steps, val, c=f'C{i % 10}', label=i if label else '', marker='.', s=4)
plt.legend()

for name, values in self._values.items():
filename = f'rlib-plot-{name}-{timestamp()}.{self._ext}'

if self._scatter:
scatterplot(values, True)
scatterplot(values, self._steps[name], True)
else:
smoothed = exponential_moving_average(values, self._smoothing)
lineplot(values, alpha=0.3)
lineplot(smoothed, label=True)
lineplot(values, self._steps[name], alpha=0.3)
lineplot(smoothed, self._steps[name], label=True)

plt.title(name)
plt.xlabel('step')
Expand All @@ -118,7 +118,7 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None:
Scalar to log.
"""

self._values[self.source_to_name(source)].append(value)
self._log(source, value)

def log_array(self, source: Source, value: Array, *_) -> None:
"""
Expand All @@ -132,4 +132,31 @@ def log_array(self, source: Source, value: Array, *_) -> None:
Array to log.
"""

self._values[self.source_to_name(source)].append(value)
self._log(source, value)

def _log(self, source: Source, value: Numeric) -> None:
"""
Adds a given scalar to the plot values.
Parameters
----------
source : Source
Source of the logged value.
value : Numeric
Value to log.
"""

name = self.source_to_name(source)

if self._sync_steps:
if name in self._current_values:
self._step += 1
self._current_values.clear()

self._current_values.add(name)
step = self._step
else:
step = self._steps[name][-1] + 1 if self._steps[name] else 0

self._values[name].append(value)
self._steps[name].append(step)
48 changes: 42 additions & 6 deletions reinforced_lib/logs/tb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@ def __init__(
self,
tb_log_dir: str = None,
tb_comet_config: dict[str, any] = None,
tb_sync_steps: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)

if tb_comet_config is None:
tb_comet_config = {'disabled': True}

self._sync_steps = tb_sync_steps
self._current_values = set()
self._step = 0

self._writer = SummaryWriter(log_dir=tb_log_dir, comet_config=tb_comet_config)
self._steps = defaultdict(int)

Expand All @@ -57,8 +62,9 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None:
Scalar to log.
"""

self._writer.add_scalar(self.source_to_name(source), value, self._steps[source])
self._steps[source] += 1
name = self.source_to_name(source)
step = self._get_step(name)
self._writer.add_scalar(name, value, step)

def log_array(self, source: Source, value: Array, *_) -> None:
"""
Expand All @@ -72,8 +78,9 @@ def log_array(self, source: Source, value: Array, *_) -> None:
Array to log.
"""

self._writer.add_histogram(self.source_to_name(source), value, self._steps[source])
self._steps[source] += 1
name = self.source_to_name(source)
step = self._get_step(name)
self._writer.add_histogram(name, value, step)

def log_dict(self, source: Source, value: dict, *_) -> None:
"""
Expand Down Expand Up @@ -101,5 +108,34 @@ def log_other(self, source: Source, value: any, *_) -> None:
Dictionary to log.
"""

self._writer.add_text(self.source_to_name(source), json.dumps(value), self._steps[source])
self._steps[source] += 1
name = self.source_to_name(source)
step = self._get_step(name)
self._writer.add_text(name, json.dumps(value), step)

def _get_step(self, name: str) -> int:
"""
Returns the current step for a given source.
Parameters
----------
name : str
Name of the source.
Returns
-------
int
Current step for the given source.
"""

if self._sync_steps:
if name in self._current_values:
self._step += 1
self._current_values.clear()

self._current_values.add(name)
step = self._step
else:
step = self._steps[name] + 1

self._steps[name] = step
return step

0 comments on commit 6a38384

Please sign in to comment.