Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Low-level-prototype #232

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,6 @@ cython_debug/
/docs/saved-models/
/docs/getting-started/saved-models/
/TODO
.git
.git/
.python-version
.devcontainer
.devcontainer
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
repos:
- repo: https://github.com/psf/black
rev: 20.8b1
rev: 22.3.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.9.3
rev: 5.10.1
hooks:
- id: isort
args: ["--profile", "black"]
36 changes: 23 additions & 13 deletions elegy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,38 @@

__version__ = "0.8.6"

from treex import *

import elegy.types as types
import elegy.utils as utils

from treeo import Hashable, compact
from treex import Optimizer


from . import (
callbacks,
data,
model,
# nets,
modules,
strategies,
)

from .model.model import Model
from .model.model_base import ModelBase, load
from .model.model_core import (
GradStepOutput,
PredStepOutput,
TestStepOutput,
TrainStepOutput,
LossStepOutput,
ModelCore,
)
from .model import Model
from .strategies import Strategy

# from .model.model_base import ModelBase, load
# from .model.model_core import (
# GradStepOutput,
# PredStepOutput,
# TestStepOutput,
# TrainStepOutput,
# LossStepOutput,
# ModelCore,
# )
from .types import KeySeq
from .utils import inject_dependencies
from .modules.high_level.high_level_module import HighLevelModule
from .modules.managed.managed_module import ManagedModule
from .modules.module import Module
from .modules.high_level.flax_module import FlaxModule
from .modules.managed.managed_flax_module import ManagedFlaxModule
from .pytree import PytreeObject, field, static_field
3 changes: 3 additions & 0 deletions elegy/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .history import History
from .lambda_callback import LambdaCallback
from .model_checkpoint import ModelCheckpoint
from .module_callback import ModuleCallback
from .remote_monitor import RemoteMonitor
from .sigint import SigInt
from .tensorboard import TensorBoard
Expand All @@ -23,4 +24,6 @@
"CSVLogger",
"TensorBoard",
"WandbCallback",
"SigInt",
"ModuleCallback",
]
67 changes: 25 additions & 42 deletions elegy/callbacks/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import typing as tp

import numpy as np
import jax.numpy as jnp

import elegy


def default(method):
Expand Down Expand Up @@ -45,6 +47,8 @@ class Callback(object):
model (elegy.model.Model): Reference of the model being trained.
"""

model: tp.Optional["elegy.model_full.Model"]

__all__ = [
"on_epoch_begin",
"on_epoch_end",
Expand Down Expand Up @@ -78,41 +82,20 @@ def set_model(self, model):

# @doc_controls.for_subclass_implementers
def on_epoch_begin(
self, epoch: int, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None
self, epoch: int, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None
):
"""Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only
be called during TRAIN mode.

Arguments:
epoch: integer, index of epoch.
logs: dict. Currently no data is passed to this argument for this method but
that may change in the future.
"""Calls the `on_epoch_begin` methods of its callbacks.
This function should only be called during TRAIN mode.
Args:
epoch: Integer, index of epoch.
logs: Dict. Currently no data is passed to this argument for this method
but that may change in the future.
"""
pass

# @doc_controls.for_subclass_implementers
def on_epoch_end(
self, epoch: int, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None
self, epoch: int, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None
):
"""Called at the end of an epoch.

Expand All @@ -130,7 +113,7 @@ def on_epoch_end(
# @doc_controls.for_subclass_implementers
@default
def on_train_batch_begin(
self, batch: int, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None
self, batch: int, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None
):
"""Called at the beginning of a training batch in `fit` methods.

Expand All @@ -146,7 +129,7 @@ def on_train_batch_begin(
# @doc_controls.for_subclass_implementers
@default
def on_train_batch_end(
self, batch: int, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None
self, batch: int, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None
):
"""Called at the end of a training batch in `fit` methods.

Expand All @@ -161,7 +144,7 @@ def on_train_batch_end(
# @doc_controls.for_subclass_implementers
@default
def on_test_batch_begin(
self, batch: int, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None
self, batch: int, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None
):
"""Called at the beginning of a batch in `evaluate` methods.

Expand All @@ -180,7 +163,7 @@ def on_test_batch_begin(
# @doc_controls.for_subclass_implementers
@default
def on_test_batch_end(
self, batch: int, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None
self, batch: int, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None
):
"""Called at the end of a batch in `evaluate` methods.

Expand All @@ -198,7 +181,7 @@ def on_test_batch_end(
# @doc_controls.for_subclass_implementers
@default
def on_predict_batch_begin(
self, batch: int, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None
self, batch: int, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None
):
"""Called at the beginning of a batch in `predict` methods.

Expand All @@ -214,7 +197,7 @@ def on_predict_batch_begin(
# @doc_controls.for_subclass_implementers
@default
def on_predict_batch_end(
self, batch: int, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None
self, batch: int, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None
):
"""Called at the end of a batch in `predict` methods.

Expand All @@ -227,7 +210,7 @@ def on_predict_batch_end(
pass

# @doc_controls.for_subclass_implementers
def on_train_begin(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
def on_train_begin(self, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None):
"""Called at the beginning of training.

Subclasses should override for any actions to run.
Expand All @@ -239,7 +222,7 @@ def on_train_begin(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
pass

# @doc_controls.for_subclass_implementers
def on_train_end(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
def on_train_end(self, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None):
"""Called at the end of training.

Subclasses should override for any actions to run.
Expand All @@ -251,7 +234,7 @@ def on_train_end(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
pass

# @doc_controls.for_subclass_implementers
def on_test_begin(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
def on_test_begin(self, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None):
"""Called at the beginning of evaluation or validation.

Subclasses should override for any actions to run.
Expand All @@ -263,7 +246,7 @@ def on_test_begin(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
pass

# @doc_controls.for_subclass_implementers
def on_test_end(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
def on_test_end(self, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None):
"""Called at the end of evaluation or validation.

Subclasses should override for any actions to run.
Expand All @@ -275,7 +258,7 @@ def on_test_end(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
pass

# @doc_controls.for_subclass_implementers
def on_predict_begin(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
def on_predict_begin(self, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None):
"""Called at the beginning of prediction.

Subclasses should override for any actions to run.
Expand All @@ -287,7 +270,7 @@ def on_predict_begin(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
pass

# @doc_controls.for_subclass_implementers
def on_predict_end(self, logs: tp.Optional[tp.Dict[str, np.ndarray]] = None):
def on_predict_end(self, logs: tp.Optional[tp.Dict[str, jnp.ndarray]] = None):
"""Called at the end of prediction.

Subclasses should override for any actions to run.
Expand Down
13 changes: 11 additions & 2 deletions elegy/callbacks/callback_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .callback import Callback
from .history import History
from .module_callback import ModuleCallback
from .progbar_logger import ProgbarLogger


Expand Down Expand Up @@ -39,6 +40,7 @@ def __init__(
add_history: bool = False,
add_progbar: bool = False,
sigint_mode: tp.Optional[SigIntMode] = None,
add_module: bool = False,
model: tp.Optional[tp.Any] = None,
**params
):
Expand All @@ -55,7 +57,7 @@ def __init__(
`Callback.set_params`.
"""
self.callbacks = callbacks if callbacks else []
self._add_default_callbacks(add_history, add_progbar, sigint_mode)
self._add_default_callbacks(add_history, add_progbar, sigint_mode, add_module)

if model:
self.set_model(model)
Expand All @@ -81,7 +83,11 @@ def __init__(
# pylint: enable=protected-access

def _add_default_callbacks(
self, add_history, add_progbar, sigint_mode: tp.Optional[SigIntMode]
self,
add_history,
add_progbar,
sigint_mode: tp.Optional[SigIntMode],
add_module: bool,
):
"""Adds `Callback`s that are always present."""
self._progbar = None
Expand All @@ -104,6 +110,9 @@ def _add_default_callbacks(
if sigint_mode is not None:
self.callbacks.append(SigInt(sigint_mode))

if add_module:
self.callbacks.append(ModuleCallback())

def _reset_batch_timing(self):
self._delta_t_batch = 0.0
self._delta_ts = collections.defaultdict(
Expand Down
Loading