From 4444d885d37544e3eeb7c6a29c583fd708fc4857 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 5 May 2022 17:29:21 +0800 Subject: [PATCH] fix bug --- brainpy/__init__.py | 2 +- brainpy/dyn/base.py | 16 ++++++++-------- brainpy/nn/base.py | 1 - 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 4e3f7b6d4..b754e75e8 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.1.9" +__version__ = "2.1.10" try: diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index 7fdb080d4..d48c773fb 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -172,7 +172,7 @@ def get_delay_data( self, name: str, delay_step: Union[int, bm.JaxArray, jnp.DeviceArray], - indices: Union[int, bm.JaxArray, jnp.DeviceArray] = None, + *indices: Union[int, bm.JaxArray, jnp.DeviceArray], ): """Get delay data according to the provided delay steps. @@ -192,18 +192,18 @@ def get_delay_data( """ if name in self.global_delay_vars: if isinstance(delay_step, int): - return self.global_delay_vars[name](delay_step, indices) + return self.global_delay_vars[name](delay_step, *indices) else: - if indices is None: - indices = jnp.arange(delay_step.size) - return self.global_delay_vars[name](delay_step, indices) + if len(indices) == 0: + indices = (jnp.arange(delay_step.size), ) + return self.global_delay_vars[name](delay_step, *indices) elif name in self.local_delay_vars: if isinstance(delay_step, int): return self.local_delay_vars[name](delay_step) else: - if indices is None: - indices = jnp.arange(delay_step.size) - return self.local_delay_vars[name](delay_step, indices) + if len(indices) == 0: + indices = (jnp.arange(delay_step.size), ) + return self.local_delay_vars[name](delay_step, *indices) else: raise ValueError(f'{name} is not defined in delay variables.') diff --git a/brainpy/nn/base.py b/brainpy/nn/base.py index 38cb6b564..f6e0529f6 100644 --- a/brainpy/nn/base.py +++ b/brainpy/nn/base.py @@ -1478,7 +1478,6 @@ def plot_node_graph(self, legends = dict() if legends is None else legends ax.legend(proxie, labels, scatterpoints=1, markerscale=2, loc='best', **legends) if show: - plt.tight_layout() plt.show()