Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewritten jnp.squeeze(self._value_layer(inputs), axis=-1) #296

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
11 changes: 2 additions & 9 deletions acme/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,14 @@
# Internal import.

# Expose specs and types modules.
from acme import specs
from acme import types
from acme import specs, types

# Make __version__ accessible.
from acme._metadata import __version__

# Expose core interfaces.
from acme.core import Actor
from acme.core import Learner
from acme.core import Saveable
from acme.core import VariableSource
from acme.core import Worker
from acme.core import Actor, Learner, Saveable, VariableSource, Worker

# Expose the environment loop.
from acme.environment_loop import EnvironmentLoop

from acme.specs import make_environment_spec

8 changes: 4 additions & 4 deletions acme/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
"""

# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = '0'
_MINOR_VERSION = '4'
_PATCH_VERSION = '1'
_MAJOR_VERSION = "0"
_MINOR_VERSION = "4"
_PATCH_VERSION = "1"

# Example: '0.4.2'
__version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION])
__version__ = ".".join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION])
3 changes: 1 addition & 2 deletions acme/adders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@
# pylint: disable=unused-import

from acme.adders.base import Adder
from acme.adders.wrappers import ForkingAdder
from acme.adders.wrappers import IgnoreExtrasAdder
from acme.adders.wrappers import ForkingAdder, IgnoreExtrasAdder
34 changes: 17 additions & 17 deletions acme/adders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

import abc

from acme import types
import dm_env

from acme import types


class Adder(abc.ABC):
"""The Adder interface.
"""The Adder interface.

An adder packs together data to send to the replay buffer, and potentially
performs some reduction/transformation to this data in the process.
Expand Down Expand Up @@ -49,9 +50,9 @@ class Adder(abc.ABC):
timestep is named `next_timestep` precisely to emphasize this point.
"""

@abc.abstractmethod
def add_first(self, timestep: dm_env.TimeStep):
"""Defines the interface for an adder's `add_first` method.
@abc.abstractmethod
def add_first(self, timestep: dm_env.TimeStep):
"""Defines the interface for an adder's `add_first` method.

We expect this to be called at the beginning of each episode and it will
start a trajectory to be added to replay with an initial observation.
Expand All @@ -60,14 +61,14 @@ def add_first(self, timestep: dm_env.TimeStep):
timestep: a dm_env TimeStep corresponding to the first step.
"""

@abc.abstractmethod
def add(
self,
action: types.NestedArray,
next_timestep: dm_env.TimeStep,
extras: types.NestedArray = (),
):
"""Defines the adder `add` interface.
@abc.abstractmethod
def add(
self,
action: types.NestedArray,
next_timestep: dm_env.TimeStep,
extras: types.NestedArray = (),
):
"""Defines the adder `add` interface.

Args:
action: A possibly nested structure corresponding to a_t.
Expand All @@ -76,7 +77,6 @@ def add(
extras: A possibly nested structure of extra data to add to replay.
"""

@abc.abstractmethod
def reset(self):
"""Resets the adder's buffer."""

@abc.abstractmethod
def reset(self):
"""Resets the adder's buffer."""
26 changes: 14 additions & 12 deletions acme/adders/reverb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,19 @@

# pylint: disable=unused-import

from acme.adders.reverb.base import DEFAULT_PRIORITY_TABLE
from acme.adders.reverb.base import PriorityFn
from acme.adders.reverb.base import PriorityFnInput
from acme.adders.reverb.base import ReverbAdder
from acme.adders.reverb.base import Step
from acme.adders.reverb.base import Trajectory

from acme.adders.reverb.base import (
DEFAULT_PRIORITY_TABLE,
PriorityFn,
PriorityFnInput,
ReverbAdder,
Step,
Trajectory,
)
from acme.adders.reverb.episode import EpisodeAdder
from acme.adders.reverb.sequence import EndBehavior
from acme.adders.reverb.sequence import SequenceAdder
from acme.adders.reverb.structured import create_n_step_transition_config
from acme.adders.reverb.structured import create_step_spec
from acme.adders.reverb.structured import StructuredAdder
from acme.adders.reverb.sequence import EndBehavior, SequenceAdder
from acme.adders.reverb.structured import (
StructuredAdder,
create_n_step_transition_config,
create_step_spec,
)
from acme.adders.reverb.transition import NStepTransitionAdder
Loading