Skip to content

Commit

Permalink
Improvement and consistency of stdout and csv loggers
Browse files Browse the repository at this point in the history
  • Loading branch information
m-wojnar committed Jul 15, 2023
1 parent bdc4d0c commit c80b84c
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 138 deletions.
98 changes: 47 additions & 51 deletions reinforced_lib/logs/csv_logger.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,24 @@
import json
import os.path
from typing import Any, Dict, List
from collections import defaultdict
from typing import Any, Dict

import jax.numpy as jnp
import numpy as np
from chex import Array, Scalar

from reinforced_lib.logs import BaseLogger, Source
from reinforced_lib.utils.exceptions import UnsupportedCustomLogsError
from utils import timestamp
from reinforced_lib.utils import timestamp


class CsvLogger(BaseLogger):
"""
Logger that saves values in CSV [1]_ format.
Logger that saves values in CSV format.
Parameters
----------
csv_path : str, default="~/rlib-logs-[date]-[time].csv"
Path to the output file.
References
----------
.. [1] Shafranovich, Y. (2005). Common Format and MIME Type for Comma-Separated Values (CSV) Files
(RFC No. 4180). RFC Editor. https://www.rfc-editor.org/rfc/rfc4180.txt
"""

def __init__(self, csv_path: str = None, **kwargs) -> None:
Expand All @@ -33,34 +28,32 @@ def __init__(self, csv_path: str = None, **kwargs) -> None:
csv_path = f'rlib-logs-{timestamp()}.csv'
csv_path = os.path.join(os.path.expanduser("~"), csv_path)

self._file = open(csv_path, 'w')
self._csv_path = csv_path
self._current_values = set()
self._step = 0

self._columns_values = {}
self._columns_names = []
self._values = defaultdict(list)
self._steps = defaultdict(list)

def init(self, sources: List[Source]) -> None:
def finish(self) -> None:
"""
Creates a list of all source names and writes the header to the output file.
Parameters
----------
sources : list[Source]
List containing all sources to log.
Saves the logged values to the CSV file.
"""

if None in sources:
raise UnsupportedCustomLogsError(type(self))
file = open(self._csv_path, 'w')
file.write(','.join(self._values.keys()) + '\n')

self._columns_names = list(map(self.source_to_name, sources))
header = ','.join(self._columns_names)
self._file.write(f'{header}\n')
rows, cols = self._step + 1, len(self._values)
csv_array = np.full((rows, cols), fill_value='', dtype=object)

def finish(self) -> None:
"""
Closes the output file.
"""
for j, (name, values) in enumerate(self._values.items()):
for i, v in enumerate(values):
csv_array[self._steps[name][i], j] = v

for row in csv_array:
file.write(','.join(map(str, row)) + '\n')

self._file.close()
file.close()

def log_scalar(self, source: Source, value: Scalar, *_) -> None:
"""
Expand All @@ -74,12 +67,11 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None:
Scalar to log.
"""

self._columns_values[self.source_to_name(source)] = value
self._save()
self._log(source, value)

def log_array(self, source: Source, value: Array, *_) -> None:
"""
Logs an array as a JSON [2]_ string.
Logs an array as a JSON string.
Parameters
----------
Expand All @@ -92,11 +84,11 @@ def log_array(self, source: Source, value: Array, *_) -> None:
if isinstance(value, (np.ndarray, jnp.ndarray)):
value = value.tolist()

self.log_other(source, value)
self._log(source, f"\"{json.dumps(value)}\"")

def log_dict(self, source: Source, value: Dict, *_) -> None:
"""
Logs a dictionary as a JSON [2]_ string.
Logs a dictionary as a JSON string.
Parameters
----------
Expand All @@ -106,36 +98,40 @@ def log_dict(self, source: Source, value: Dict, *_) -> None:
Dictionary to log.
"""

self.log_other(source, value)
self._log(source, f"\"{json.dumps(value)}\"")

def log_other(self, source: Source, value: Any, *_) -> None:
"""
Logs an object as a JSON [2]_ string.
Logs an object as a JSON string.
Parameters
----------
source : Source
Source of the logged value.
value : any
Value of any type to log.
References
----------
.. [2] Felipe Pezoa, Juan L. Reutter, Fernando Suarez, Martin Ugarte, and Domagoj Vrgoc. 2016.
Foundations of JSON Schema. In Proceedings of the 25th International Conference on World Wide Web
(WWW '16). International World Wide Web Conferences Steering Committee, Republic and Canton of Geneva,
CHE, 263–273. https://doi.org/10.1145/2872427.2883029
"""

self._columns_values[self.source_to_name(source)] = f"\"{json.dumps(value)}\""
self._save()
self._log(source, f"\"{json.dumps(value)}\"")

def _save(self) -> None:
def _log(self, source: Source, value: Any) -> None:
"""
Writes a new row to the output file if the values for all columns has already been filled.
Saves the logged value and controls the current step.
Parameters
----------
source : Source
Source of the logged value.
value : any
Value to log.
"""

if len(self._columns_values) == len(self._columns_names):
line = ','.join(str(self._columns_values[name]) for name in self._columns_names)
self._file.write(f'{line}\n')
self._columns_values = {}
name = self.source_to_name(source)

if name in self._current_values:
self._step += 1
self._current_values.clear()

self._current_values.add(name)
self._values[name].append(value)
self._steps[name].append(self._step)
71 changes: 29 additions & 42 deletions reinforced_lib/logs/plots_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from chex import Array, Scalar

from reinforced_lib.logs import BaseLogger, Source
from utils import timestamp
from reinforced_lib.utils import timestamp


class PlotsLogger(BaseLogger):
Expand All @@ -22,41 +22,53 @@ class PlotsLogger(BaseLogger):
plots_ext : str, default="svg"
Extension of the saved plots.
plots_smoothing : float, default=0.6
Weight of the exponential moving average (EMA/EWMA) [3]_ used for smoothing. :math:`\alpha \in [0, 1)`.
Weight of the exponential moving average (EMA/EWMA) [1]_ used for smoothing. :math:`\alpha \in [0, 1)`.
plots_scatter : bool, default=False
Set to ``True`` if you want to generate a scatter plot instead of a line plot.
``plots_smoothing`` parameter does not apply to the scatter plots.
References
----------
.. [3] https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average
.. [1] https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average
"""

def __init__(
self,
plots_dir: str = None,
plots_ext: str = 'svg',
plots_ext: str = 'pdf',
plots_smoothing: Scalar = 0.6,
plots_scatter: bool = False,
**kwargs
) -> None:
assert 1 > plots_smoothing >= 0

super().__init__(**kwargs)

self._plots_dir = plots_dir if plots_dir else os.path.expanduser("~")
self._plots_ext = plots_ext
self._plots_smoothing = plots_smoothing
self._plots_scatter = plots_scatter
self._dir = plots_dir if plots_dir else os.path.expanduser("~")
self._ext = plots_ext
self._smoothing = plots_smoothing
self._scatter = plots_scatter

assert 1 > self._plots_smoothing >= 0
self._current_values = set()
self._step = 0

self._plots_values = defaultdict(list)
self._values = defaultdict(list)
self._steps = defaultdict(list)

def finish(self) -> None:
"""
Shows the generated plots and saves them to the output directory with the specified extension
(the names of the files follow the pattern ``"rlib-plot-[source]-[date]-[time].[ext]"``).
"""

def exponential_moving_average(values: List, weight: Scalar) -> List:
    """
    Smooths a sequence with an exponential moving average (EMA/EWMA).

    The first element is kept unchanged; every following element is
    computed as ``(1 - weight) * value + weight * previous_smoothed``.

    Parameters
    ----------
    values : list
        Original values, in the order they were logged.
    weight : float
        Weight given to the previous smoothed value.

    Returns
    -------
    smoothed : list
        Smoothed values, the same length as ``values``.
    """

    # Nothing to smooth: indexing ``values[0]`` below would raise IndexError.
    if len(values) == 0:
        return []

    smoothed = [values[0]]

    for value in values[1:]:
        smoothed.append((1 - weight) * value + weight * smoothed[-1])

    return smoothed

def lineplot(values: List, alpha: Scalar = 1.0, label: bool = False) -> None:
values = jnp.array(values)
values = jnp.squeeze(values)
Expand All @@ -80,46 +92,21 @@ def scatterplot(values: List, label: bool = False) -> None:
plt.scatter(xs, val, c=f'C{i % 10}', label=i if label else '', marker='.', s=4)
plt.legend()

for name, values in self._plots_values.items():
filename = f'rlib-plot-{name}-{timestamp()}.{self._plots_ext}'
for name, values in self._values.items():
filename = f'rlib-plot-{name}-{timestamp()}.{self._ext}'

if self._plots_scatter:
if self._scatter:
scatterplot(values, True)
else:
smoothed = self._exponential_moving_average(values, self._plots_smoothing)
smoothed = exponential_moving_average(values, self._smoothing)
lineplot(values, alpha=0.3)
lineplot(smoothed, label=True)

plt.title(name)
plt.xlabel('step')
plt.savefig(os.path.join(self._plots_dir, filename), bbox_inches='tight')
plt.savefig(os.path.join(self._dir, filename), bbox_inches='tight')
plt.show()

@staticmethod
def _exponential_moving_average(values: List, weight: Scalar) -> List:
"""
Calculates the exponential moving average (EMA/EWMA) [3]_ to smooth the plotted values.
Parameters
----------
values : array_like
Original values.
weight : float
Weight of the EMA.
Returns
-------
smoothed : array_like
Smoothed values.
"""

smoothed = [values[0]]

for value in values[1:]:
smoothed.append((1 - weight) * value + weight * smoothed[-1])

return smoothed

def log_scalar(self, source: Source, value: Scalar, *_) -> None:
"""
Adds a given scalar to the plot values.
Expand All @@ -132,7 +119,7 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None:
Scalar to log.
"""

self._plots_values[self.source_to_name(source)].append(value)
self._values[self.source_to_name(source)].append(value)

def log_array(self, source: Source, value: Array, *_) -> None:
"""
Expand All @@ -146,4 +133,4 @@ def log_array(self, source: Source, value: Array, *_) -> None:
Array to log.
"""

self._plots_values[self.source_to_name(source)].append(value)
self._values[self.source_to_name(source)].append(value)
Loading

0 comments on commit c80b84c

Please sign in to comment.