diff --git a/reinforced_lib/logs/csv_logger.py b/reinforced_lib/logs/csv_logger.py index d8f3e58..63637c6 100644 --- a/reinforced_lib/logs/csv_logger.py +++ b/reinforced_lib/logs/csv_logger.py @@ -1,29 +1,24 @@ import json import os.path -from typing import Any, Dict, List +from collections import defaultdict +from typing import Any, Dict import jax.numpy as jnp import numpy as np from chex import Array, Scalar from reinforced_lib.logs import BaseLogger, Source -from reinforced_lib.utils.exceptions import UnsupportedCustomLogsError -from utils import timestamp +from reinforced_lib.utils import timestamp class CsvLogger(BaseLogger): """ - Logger that saves values in CSV [1]_ format. + Logger that saves values in CSV format. Parameters ---------- csv_path : str, default="~/rlib-logs-[date]-[time].csv" Path to the output file. - - References - ---------- - .. [1] Shafranovich, Y. (2005). Common Format and MIME Type for Comma-Separated Values (CSV) Files - (RFC No. 4180). RFC Editor. https://www.rfc-editor.org/rfc/rfc4180.txt """ def __init__(self, csv_path: str = None, **kwargs) -> None: @@ -33,34 +28,32 @@ def __init__(self, csv_path: str = None, **kwargs) -> None: csv_path = f'rlib-logs-{timestamp()}.csv' csv_path = os.path.join(os.path.expanduser("~"), csv_path) - self._file = open(csv_path, 'w') + self._csv_path = csv_path + self._current_values = set() + self._step = 0 - self._columns_values = {} - self._columns_names = [] + self._values = defaultdict(list) + self._steps = defaultdict(list) - def init(self, sources: List[Source]) -> None: + def finish(self) -> None: """ - Creates a list of all source names and writes the header to the output file. - - Parameters - ---------- - sources : list[Source] - List containing all sources to log. + Saves the logged values to the CSV file. """ - if None in sources: - raise UnsupportedCustomLogsError(type(self)) + file = open(self._csv_path, 'w') + file.write(','.join(self._values.keys()) + '\n') - self._columns_names = list(map(self.source_to_name, sources)) - header = ','.join(self._columns_names) - self._file.write(f'{header}\n') + rows, cols = self._step + 1, len(self._values) + csv_array = np.full((rows, cols), fill_value='', dtype=object) - def finish(self) -> None: - """ - Closes the output file. - """ + for j, (name, values) in enumerate(self._values.items()): + for i, v in enumerate(values): + csv_array[self._steps[name][i], j] = v + + for row in csv_array: + file.write(','.join(map(str, row)) + '\n') - self._file.close() + file.close() def log_scalar(self, source: Source, value: Scalar, *_) -> None: """ @@ -74,12 +67,11 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None: Scalar to log. """ - self._columns_values[self.source_to_name(source)] = value - self._save() + self._log(source, value) def log_array(self, source: Source, value: Array, *_) -> None: """ - Logs an array as a JSON [2]_ string. + Logs an array as a JSON string. Parameters ---------- @@ -92,11 +84,11 @@ def log_array(self, source: Source, value: Array, *_) -> None: if isinstance(value, (np.ndarray, jnp.ndarray)): value = value.tolist() - self.log_other(source, value) + self._log(source, f"\"{json.dumps(value)}\"") def log_dict(self, source: Source, value: Dict, *_) -> None: """ - Logs a dictionary as a JSON [2]_ string. + Logs a dictionary as a JSON string. Parameters ---------- @@ -106,11 +98,11 @@ def log_dict(self, source: Source, value: Dict, *_) -> None: Dictionary to log. """ - self.log_other(source, value) + self._log(source, f"\"{json.dumps(value)}\"") def log_other(self, source: Source, value: Any, *_) -> None: """ - Logs an object as a JSON [2]_ string. + Logs an object as a JSON string. Parameters ---------- @@ -118,24 +110,28 @@ def log_other(self, source: Source, value: Any, *_) -> None: Source of the logged value. value : any Value of any type to log. - - References - ---------- - .. [2] Felipe Pezoa, Juan L. Reutter, Fernando Suarez, Martin Ugarte, and Domagoj Vrgoc. 2016. - Foundations of JSON Schema. In Proceedings of the 25th International Conference on World Wide Web - (WWW '16). International World Wide Web Conferences Steering Committee, Republic and Canton of Geneva, - CHE, 263–273. https://doi.org/10.1145/2872427.2883029 """ - self._columns_values[self.source_to_name(source)] = f"\"{json.dumps(value)}\"" - self._save() + self._log(source, f"\"{json.dumps(value)}\"") - def _save(self) -> None: + def _log(self, source: Source, value: Any) -> None: """ - Writes a new row to the output file if the values for all columns has already been filled. + Saves the logged value and controls the current step. + + Parameters + ---------- + source : Source + Source of the logged value. + value : any + Value to log. """ - if len(self._columns_values) == len(self._columns_names): - line = ','.join(str(self._columns_values[name]) for name in self._columns_names) - self._file.write(f'{line}\n') - self._columns_values = {} + name = self.source_to_name(source) + + if name in self._current_values: + self._step += 1 + self._current_values.clear() + + self._current_values.add(name) + self._values[name].append(value) + self._steps[name].append(self._step) diff --git a/reinforced_lib/logs/plots_logger.py b/reinforced_lib/logs/plots_logger.py index 76af2f0..54268df 100644 --- a/reinforced_lib/logs/plots_logger.py +++ b/reinforced_lib/logs/plots_logger.py @@ -7,7 +7,7 @@ from chex import Array, Scalar from reinforced_lib.logs import BaseLogger, Source -from utils import timestamp +from reinforced_lib.utils import timestamp class PlotsLogger(BaseLogger): @@ -22,34 +22,38 @@ class PlotsLogger(BaseLogger): plots_ext : str, default="svg" Extension of the saved plots. plots_smoothing : float, default=0.6 - Weight of the exponential moving average (EMA/EWMA) [3]_ used for smoothing. :math:`\alpha \in [0, 1)`. + Weight of the exponential moving average (EMA/EWMA) [1]_ used for smoothing. :math:`\alpha \in [0, 1)`. plots_scatter : bool, default=False Set to ``True`` if you want to generate a scatter plot instead of a line plot. ``plots_smoothing`` parameter does not apply to the scatter plots. References ---------- - .. [3] https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average + .. [1] https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average """ def __init__( self, plots_dir: str = None, - plots_ext: str = 'svg', + plots_ext: str = 'pdf', plots_smoothing: Scalar = 0.6, plots_scatter: bool = False, **kwargs ) -> None: + assert 1 > plots_smoothing >= 0 + super().__init__(**kwargs) - self._plots_dir = plots_dir if plots_dir else os.path.expanduser("~") - self._plots_ext = plots_ext - self._plots_smoothing = plots_smoothing - self._plots_scatter = plots_scatter + self._dir = plots_dir if plots_dir else os.path.expanduser("~") + self._ext = plots_ext + self._smoothing = plots_smoothing + self._scatter = plots_scatter - assert 1 > self._plots_smoothing >= 0 + self._current_values = set() + self._step = 0 - self._plots_values = defaultdict(list) + self._values = defaultdict(list) + self._steps = defaultdict(list) def finish(self) -> None: """ @@ -57,6 +61,14 @@ def finish(self) -> None: (the names of the files follow the pattern ``"rlib-plot-[source]-[date]-[time].[ext]"``). """ + def exponential_moving_average(values: List, weight: Scalar) -> List: + smoothed = [values[0]] + + for value in values[1:]: + smoothed.append((1 - weight) * value + weight * smoothed[-1]) + + return smoothed + def lineplot(values: List, alpha: Scalar = 1.0, label: bool = False) -> None: values = jnp.array(values) values = jnp.squeeze(values) @@ -80,46 +92,21 @@ def scatterplot(values: List, label: bool = False) -> None: plt.scatter(xs, val, c=f'C{i % 10}', label=i if label else '', marker='.', s=4) plt.legend() - for name, values in self._plots_values.items(): - filename = f'rlib-plot-{name}-{timestamp()}.{self._plots_ext}' + for name, values in self._values.items(): + filename = f'rlib-plot-{name}-{timestamp()}.{self._ext}' - if self._plots_scatter: + if self._scatter: scatterplot(values, True) else: - smoothed = self._exponential_moving_average(values, self._plots_smoothing) + smoothed = exponential_moving_average(values, self._smoothing) lineplot(values, alpha=0.3) lineplot(smoothed, label=True) plt.title(name) plt.xlabel('step') - plt.savefig(os.path.join(self._plots_dir, filename), bbox_inches='tight') + plt.savefig(os.path.join(self._dir, filename), bbox_inches='tight') plt.show() - @staticmethod - def _exponential_moving_average(values: List, weight: Scalar) -> List: - """ - Calculates the exponential moving average (EMA/EWMA) [3]_ to smooth the plotted values. - - Parameters - ---------- - values : array_like - Original values. - weight : float - Weight of the EMA. - - Returns - ------- - smoothed : array_like - Smoothed values. - """ - - smoothed = [values[0]] - - for value in values[1:]: - smoothed.append((1 - weight) * value + weight * smoothed[-1]) - - return smoothed - def log_scalar(self, source: Source, value: Scalar, *_) -> None: """ Adds a given scalar to the plot values. @@ -132,7 +119,7 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None: Scalar to log. """ - self._plots_values[self.source_to_name(source)].append(value) + self._values[self.source_to_name(source)].append(value) def log_array(self, source: Source, value: Array, *_) -> None: """ @@ -146,4 +133,4 @@ def log_array(self, source: Source, value: Array, *_) -> None: Array to log. """ - self._plots_values[self.source_to_name(source)].append(value) + self._values[self.source_to_name(source)].append(value) diff --git a/reinforced_lib/logs/stdout_logger.py b/reinforced_lib/logs/stdout_logger.py index 0334875..5af9a08 100644 --- a/reinforced_lib/logs/stdout_logger.py +++ b/reinforced_lib/logs/stdout_logger.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, List +from typing import Any, Dict import jax.numpy as jnp import numpy as np @@ -17,19 +17,6 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self._values = {} - self._names = [] - - def init(self, sources: List[Source]) -> None: - """ - Creates a list of all source names. - - Parameters - ---------- - sources : list[Source] - List containing all sources to log. - """ - - self._names = list(map(self.source_to_name, filter(lambda s: s is not None, sources))) def finish(self) -> None: """ @@ -37,7 +24,7 @@ def finish(self) -> None: """ if len(self._values) > 0: - print('\t'.join(f'{name}: {self._values[name]}' for name in self._names if name in self._values)) + print('\t'.join(f'{n}: {v}' for n, v in self._values.items())) def log_scalar(self, source: Source, value: Scalar, custom: bool) -> None: """ @@ -53,11 +40,11 @@ def log_scalar(self, source: Source, value: Scalar, custom: bool) -> None: Whether the source is a custom source. """ - self._print(source, value, custom) + self._log(source, value, custom) def log_array(self, source: Source, value: Array, custom: bool) -> None: """ - Logs an array as a JSON [2]_ string. + Logs an array as a JSON string. Parameters ---------- @@ -72,11 +59,11 @@ def log_array(self, source: Source, value: Array, custom: bool) -> None: if isinstance(value, (np.ndarray, jnp.ndarray)): value = value.tolist() - self.log_other(source, value, custom) + self._log(source, json.dumps(value), custom) def log_dict(self, source: Source, value: Dict, custom: bool) -> None: """ - Logs a dictionary as a JSON [2]_ string. + Logs a dictionary as a JSON string. Parameters ---------- @@ -88,11 +75,11 @@ def log_dict(self, source: Source, value: Dict, custom: bool) -> None: Whether the source is a custom source. """ - self.log_other(source, value, custom) + self._log(source, json.dumps(value), custom) def log_other(self, source: Source, value: Any, custom: bool) -> None: """ - Logs an object as a JSON [2]_ string. + Logs an object as a JSON string. Parameters ---------- @@ -104,11 +91,12 @@ def log_other(self, source: Source, value: Any, custom: bool) -> None: Whether the source is a custom source. """ - self._print(source, json.dumps(value), custom) + self._log(source, json.dumps(value), custom) - def _print(self, source: Source, value: Any, custom: bool) -> None: + def _log(self, source: Source, value: Any, custom: bool) -> None: """ - Prints a new row to the standard output if all values has already been filled. + Prints a new row to the standard output if there is a new value for a + standard source or the source is custom. Parameters ---------- @@ -120,11 +108,13 @@ def _print(self, source: Source, value: Any, custom: bool) -> None: Whether the source is a custom source. """ - if not custom: - self._values[self.source_to_name(source)] = value + name = self.source_to_name(source) - if len(self._values) == len(self._names): - print('\t'.join(f'{name}: {self._values[name]}' for name in self._names)) + if not custom: + if name in self._values: + print('\t'.join(f'{n}: {v}' for n, v in self._values.items())) self._values = {} + + self._values[name] = value else: - print(f'{self.source_to_name(source)}: {value}') + print(f'{name}: {value}') diff --git a/reinforced_lib/logs/tb_logger.py b/reinforced_lib/logs/tb_logger.py index b175daf..5601cf6 100644 --- a/reinforced_lib/logs/tb_logger.py +++ b/reinforced_lib/logs/tb_logger.py @@ -10,7 +10,7 @@ class TensorboardLogger(BaseLogger): """ - Logger that saves values in TensorBoard [4]_ format. Offers a possibility to log to Comet [5]_. + Logger that saves values in TensorBoard [2]_ format. Offers a possibility to log to Comet [3]_. Parameters ---------- @@ -21,8 +21,8 @@ class TensorboardLogger(BaseLogger): References ---------- - .. [4] TensorBoard. https://www.tensorflow.org/tensorboard - .. [5] Comet. https://www.comet.ml + .. [2] TensorBoard. https://www.tensorflow.org/tensorboard + .. [3] Comet. https://www.comet.ml """ def __init__( @@ -36,15 +36,15 @@ def __init__( if tb_comet_config is None: tb_comet_config = {'disabled': True} - self._summary_writer = SummaryWriter(log_dir=tb_log_dir, comet_config=tb_comet_config) - self._step = defaultdict(int) + self._writer = SummaryWriter(log_dir=tb_log_dir, comet_config=tb_comet_config) + self._steps = defaultdict(int) def finish(self) -> None: """ Closes the summary writer. """ - self._summary_writer.close() + self._writer.close() def log_scalar(self, source: Source, value: Scalar, *_) -> None: """ @@ -58,8 +58,8 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None: Scalar to log. """ - self._summary_writer.add_scalar(self.source_to_name(source), value, self._step[source]) - self._step[source] += 1 + self._writer.add_scalar(self.source_to_name(source), value, self._steps[source]) + self._steps[source] += 1 def log_array(self, source: Source, value: Array, *_) -> None: """ @@ -73,12 +73,12 @@ def log_array(self, source: Source, value: Array, *_) -> None: Array to log. """ - self._summary_writer.add_histogram(self.source_to_name(source), value, self._step[source]) - self._step[source] += 1 + self._writer.add_histogram(self.source_to_name(source), value, self._steps[source]) + self._steps[source] += 1 def log_dict(self, source: Source, value: Dict, *_) -> None: """ - Logs a dictionary as a JSON [2]_ string. + Logs a dictionary as a JSON string. Parameters ---------- @@ -92,7 +92,7 @@ def log_dict(self, source: Source, value: Dict, *_) -> None: def log_other(self, source: Source, value: Any, *_) -> None: """ - Logs an object as a JSON [2]_ string. + Logs an object as a JSON string. Parameters ---------- @@ -102,5 +102,5 @@ def log_other(self, source: Source, value: Any, *_) -> None: Dictionary to log. """ - self._summary_writer.add_text(self.source_to_name(source), json.dumps(value), self._step[source]) - self._step[source] += 1 + self._writer.add_text(self.source_to_name(source), json.dumps(value), self._steps[source]) + self._steps[source] += 1 diff --git a/reinforced_lib/rlib.py b/reinforced_lib/rlib.py index f77ce66..83ee3be 100644 --- a/reinforced_lib/rlib.py +++ b/reinforced_lib/rlib.py @@ -15,8 +15,7 @@ from reinforced_lib.logs import Source from reinforced_lib.logs.logs_observer import LogsObserver from reinforced_lib.utils.exceptions import * -from reinforced_lib.utils import is_scalar -from utils import timestamp +from reinforced_lib.utils import is_scalar, timestamp @dataclass