Skip to content

Commit

Permalink
Add custom logging option
Browse files Browse the repository at this point in the history
  • Loading branch information
m-wojnar committed Jul 15, 2023
1 parent 0800391 commit 2af81e5
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 70 deletions.
6 changes: 4 additions & 2 deletions examples/cart-pole/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from reinforced_lib import RLib
from reinforced_lib.agents.deep import QLearning
from reinforced_lib.exts import Gymnasium
from reinforced_lib.logs import StdoutLogger, TensorboardLogger


@hk.transform_with_state
Expand Down Expand Up @@ -40,7 +41,8 @@ def run(num_epochs: int, render_every: int, seed: int) -> None:
'epsilon_decay': 0.9975
},
ext_type=Gymnasium,
ext_params={'env_id': 'CartPole-v1'}
ext_params={'env_id': 'CartPole-v1'},
logger_types=[StdoutLogger, TensorboardLogger]
)

for epoch in range(num_epochs):
Expand All @@ -60,7 +62,7 @@ def run(num_epochs: int, render_every: int, seed: int) -> None:
terminal = env_state[2] or env_state[3]
epoch_len += 1

print(f'Epoch {epoch} finished with {epoch_len} steps')
rl.log('epoch_len', epoch_len)


if __name__ == '__main__':
Expand Down
6 changes: 4 additions & 2 deletions examples/pendulum/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from reinforced_lib import RLib
from reinforced_lib.agents.deep import DDPG
from reinforced_lib.exts import Gymnasium
from reinforced_lib.logs import StdoutLogger, TensorboardLogger


@hk.transform_with_state
Expand Down Expand Up @@ -54,7 +55,8 @@ def run(num_epochs: int, render_every: int, seed: int) -> None:
'tau': 0.005,
},
ext_type=Gymnasium,
ext_params={'env_id': 'Pendulum-v1'}
ext_params={'env_id': 'Pendulum-v1'},
logger_types=[StdoutLogger, TensorboardLogger]
)

for epoch in range(num_epochs):
Expand All @@ -74,7 +76,7 @@ def run(num_epochs: int, render_every: int, seed: int) -> None:
terminal = env_state[2] or env_state[3]
rewards_sum += env_state[1]

print(f'Epoch: {epoch}, rewards sum: {rewards_sum}')
rl.log('rewards_sum', rewards_sum)


if __name__ == '__main__':
Expand Down
18 changes: 13 additions & 5 deletions reinforced_lib/logs/base_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SourceType(Enum):
METRIC = 2


Source = Union[Tuple[str, SourceType], str]
Source = Union[Tuple[str, SourceType], str, None]


class BaseLogger(ABC):
Expand Down Expand Up @@ -43,7 +43,7 @@ def finish(self) -> None:

pass

def log_scalar(self, source: Source, value: Scalar) -> None:
def log_scalar(self, source: Source, value: Scalar, custom: bool) -> None:
"""
Method of the logger interface used for logging scalar values.
Expand All @@ -53,11 +53,13 @@ def log_scalar(self, source: Source, value: Scalar) -> None:
Source of the logged value.
value : float
Scalar to log.
custom : bool
Whether the source is a custom source.
"""

raise UnsupportedLogTypeError(type(self), type(value))

def log_array(self, source: Source, value: Array) -> None:
def log_array(self, source: Source, value: Array, custom: bool) -> None:
"""
Method of the logger interface used for logging arrays.
Expand All @@ -67,11 +69,13 @@ def log_array(self, source: Source, value: Array) -> None:
Source of the logged value.
value : array_like
Array to log.
custom : bool
Whether the source is a custom source.
"""

raise UnsupportedLogTypeError(type(self), type(value))

def log_dict(self, source: Source, value: Dict) -> None:
def log_dict(self, source: Source, value: Dict, custom: bool) -> None:
"""
Method of the logger interface used for logging dictionaries.
Expand All @@ -81,11 +85,13 @@ def log_dict(self, source: Source, value: Dict) -> None:
Source of the logged value.
value : dict
Dictionary to log.
custom : bool
Whether the source is a custom source.
"""

raise UnsupportedLogTypeError(type(self), type(value))

def log_other(self, source: Source, value: Any) -> None:
def log_other(self, source: Source, value: Any, custom: bool) -> None:
"""
Method of the logger interface used for logging other values.
Expand All @@ -95,6 +101,8 @@ def log_other(self, source: Source, value: Any) -> None:
Source of the logged value.
value : any
Value of any type to log.
custom : bool
Whether the source is a custom source.
"""

raise UnsupportedLogTypeError(type(self), type(value))
Expand Down
12 changes: 8 additions & 4 deletions reinforced_lib/logs/csv_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from chex import Array, Scalar

from reinforced_lib.logs import BaseLogger, Source
from reinforced_lib.utils.exceptions import UnsupportedCustomLogsError


class CsvLogger(BaseLogger):
Expand Down Expand Up @@ -48,6 +49,9 @@ def init(self, sources: List[Source]) -> None:
List containing all sources to log.
"""

if None in sources:
raise UnsupportedCustomLogsError(type(self))

self._columns_names = list(map(self.source_to_name, sources))
header = ','.join(self._columns_names)
self._file.write(f'{header}\n')
Expand All @@ -59,7 +63,7 @@ def finish(self) -> None:

self._file.close()

def log_scalar(self, source: Source, value: Scalar) -> None:
def log_scalar(self, source: Source, value: Scalar, *_) -> None:
"""
Logs a scalar as a standard value in a column.
Expand All @@ -74,7 +78,7 @@ def log_scalar(self, source: Source, value: Scalar) -> None:
self._columns_values[self.source_to_name(source)] = value
self._save()

def log_array(self, source: Source, value: Array) -> None:
def log_array(self, source: Source, value: Array, *_) -> None:
"""
Logs an array as a JSON [2]_ string.
Expand All @@ -91,7 +95,7 @@ def log_array(self, source: Source, value: Array) -> None:

self.log_other(source, value)

def log_dict(self, source: Source, value: Dict) -> None:
def log_dict(self, source: Source, value: Dict, *_) -> None:
"""
Logs a dictionary as a JSON [2]_ string.
Expand All @@ -105,7 +109,7 @@ def log_dict(self, source: Source, value: Dict) -> None:

self.log_other(source, value)

def log_other(self, source: Source, value: Any) -> None:
def log_other(self, source: Source, value: Any, *_) -> None:
"""
Logs an object as a JSON [2]_ string.
Expand Down
34 changes: 28 additions & 6 deletions reinforced_lib/logs/logs_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self) -> None:
self._observations_loggers = defaultdict(list)
self._agent_state_loggers = defaultdict(list)
self._metrics_loggers = defaultdict(list)
self._custom_loggers = defaultdict(list)

def add_logger(self, source: Source, logger_type: type, logger_params: Dict[str, Any]) -> None:
"""
Expand All @@ -42,7 +43,7 @@ def add_logger(self, source: Source, logger_type: type, logger_params: Dict[str,
if isinstance(source, tuple):
if len(source) != 2 or not isinstance(source[0], str) or not hasattr(source[1], 'name'):
raise IncorrectSourceTypeError(type(source))
elif not isinstance(source, str):
elif source is not None and not isinstance(source, str):
raise IncorrectSourceTypeError(type(source))

logger = self._logger_instances.get(logger_type, logger_type(**logger_params))
Expand All @@ -58,6 +59,8 @@ def add_logger(self, source: Source, logger_type: type, logger_params: Dict[str,
self._observations_loggers[logger].append((source, source))
self._agent_state_loggers[logger].append((source, source))
self._metrics_loggers[logger].append((source, source))
elif source is None:
self._custom_loggers[logger] = [(None, None)]

self._logger_sources[logger].append(source)
self._logger_instances[logger_type] = logger
Expand Down Expand Up @@ -119,8 +122,22 @@ def update_metrics(self, metric: Any, metric_name: str) -> None:

self._update(self._metrics_loggers, lambda name: metric if name == metric_name else None)

def update_custom(self, value: Any, name: str) -> None:
"""
Passes custom values to loggers.
Parameters
----------
value : any
Value to log.
name : str
Name of the value.
"""

self._update(self._custom_loggers, lambda _: (name, value))

@staticmethod
def _update(loggers: Dict[BaseLogger, List[str]], get_value: Callable) -> None:
def _update(loggers: Dict[BaseLogger, List[Source]], get_value: Callable) -> None:
"""
Passes values to the appropriate loggers and method based on the type and the source of the value.
Expand All @@ -135,11 +152,16 @@ def _update(loggers: Dict[BaseLogger, List[str]], get_value: Callable) -> None:
for logger, sources in loggers.items():
for source, name in sources:
if (value := get_value(name)) is not None:
if name is None:
source, value = value

custom = name is None

if jnp.isscalar(value) or (hasattr(value, 'ndim') and value.ndim == 0):
logger.log_scalar(source, value)
logger.log_scalar(source, value, custom)
elif isinstance(value, dict):
logger.log_dict(source, value)
logger.log_dict(source, value, custom)
elif isinstance(value, (list, tuple)) or (hasattr(value, 'ndim') and value.ndim == 1):
logger.log_array(source, value)
logger.log_array(source, value, custom)
else:
logger.log_other(source, value)
logger.log_other(source, value, custom)
17 changes: 2 additions & 15 deletions reinforced_lib/logs/plots_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,6 @@ def __init__(
assert 1 > self._plots_smoothing >= 0

self._plots_values = defaultdict(list)
self._plots_names = []

def init(self, sources: List[Source]) -> None:
"""
Creates a list of all source names.
Parameters
----------
sources : list[Source]
List containing all sources to log.
"""

self._plots_names = list(map(self.source_to_name, sources))

def finish(self) -> None:
"""
Expand Down Expand Up @@ -135,7 +122,7 @@ def _exponential_moving_average(values: List, weight: Scalar) -> List:

return smoothed

def log_scalar(self, source: Source, value: Scalar) -> None:
def log_scalar(self, source: Source, value: Scalar, *_) -> None:
"""
Adds a given scalar to the plot values.
Expand All @@ -149,7 +136,7 @@ def log_scalar(self, source: Source, value: Scalar) -> None:

self._plots_values[self.source_to_name(source)].append(value)

def log_array(self, source: Source, value: Array) -> None:
def log_array(self, source: Source, value: Array, *_) -> None:
"""
Adds a given array to the plot values.
Expand Down
52 changes: 36 additions & 16 deletions reinforced_lib/logs/stdout_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ def init(self, sources: List[Source]) -> None:
List containing all sources to log.
"""

self._names = list(map(self.source_to_name, sources))
self._names = list(map(self.source_to_name, filter(lambda s: s is not None, sources)))

def finish(self) -> None:
"""
Prints the last row if there are any unprinted values left.
"""

if len(self._values) > 0:
print('\t'.join(f'{name}: {value}' for name, value in self._values.items()))
print('\t'.join(f'{name}: {self._values[name]}' for name in self._names if name in self._values))

def log_scalar(self, source: Source, value: Scalar) -> None:
def log_scalar(self, source: Source, value: Scalar, custom: bool) -> None:
"""
Logs a scalar as the standard value.
Expand All @@ -49,12 +49,13 @@ def log_scalar(self, source: Source, value: Scalar) -> None:
Source of the logged value.
value : float
Scalar to log.
custom : bool
Whether the source is a custom source.
"""

self._values[self.source_to_name(source)] = value
self._print()
self._print(source, value, custom)

def log_array(self, source: Source, value: Array) -> None:
def log_array(self, source: Source, value: Array, custom: bool) -> None:
"""
Logs an array as a JSON [2]_ string.
Expand All @@ -64,14 +65,16 @@ def log_array(self, source: Source, value: Array) -> None:
Source of the logged value.
value : array_like
Array to log.
custom : bool
Whether the source is a custom source.
"""

if isinstance(value, (np.ndarray, jnp.ndarray)):
value = value.tolist()

self.log_other(source, value)
self.log_other(source, value, custom)

def log_dict(self, source: Source, value: Dict) -> None:
def log_dict(self, source: Source, value: Dict, custom: bool) -> None:
"""
Logs a dictionary as a JSON [2]_ string.
Expand All @@ -81,11 +84,13 @@ def log_dict(self, source: Source, value: Dict) -> None:
Source of the logged value.
value : dict
Dictionary to log.
custom : bool
Whether the source is a custom source.
"""

self.log_other(source, value)
self.log_other(source, value, custom)

def log_other(self, source: Source, value: Any) -> None:
def log_other(self, source: Source, value: Any, custom: bool) -> None:
"""
Logs an object as a JSON [2]_ string.
Expand All @@ -95,16 +100,31 @@ def log_other(self, source: Source, value: Any) -> None:
Source of the logged value.
value : any
Value of any type to log.
custom : bool
Whether the source is a custom source.
"""

self._values[self.source_to_name(source)] = json.dumps(value)
self._print()
self._print(source, json.dumps(value), custom)

def _print(self) -> None:
def _print(self, source: Source, value: Any, custom: bool) -> None:
"""
Prints a new row to the standard output if all values has already been filled.
Parameters
----------
source : Source
Source of the logged value.
value : any
Value of any type to log.
custom : bool
Whether the source is a custom source.
"""

if len(self._values) == len(self._names):
print('\t'.join(f'{name}: {self._values[name]}' for name in self._names))
self._values = {}
if not custom:
self._values[self.source_to_name(source)] = value

if len(self._values) == len(self._names):
print('\t'.join(f'{name}: {self._values[name]}' for name in self._names))
self._values = {}
else:
print(f'{self.source_to_name(source)}: {value}')
Loading

0 comments on commit 2af81e5

Please sign in to comment.