Skip to content

Commit

Permalink
improve oo-to-function transformation speed (#208)
Browse files Browse the repository at this point in the history
improve oo-to-function transformation speed
  • Loading branch information
chaoming0625 authored May 15, 2022
2 parents 683e5cb + 44eac4a commit 231b566
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 572 deletions.
10 changes: 6 additions & 4 deletions brainpy/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ class DynamicalSystem(Base):

"""Global delay variables. Useful when the same target
variable is used in multiple mappings."""
global_delay_vars: Dict[str, bm.LengthDelay] = dict()
global_delay_vars: Dict[str, bm.LengthDelay] = Collector()

def __init__(self, name=None):
super(DynamicalSystem, self).__init__(name=name)

# local delay variables
self.local_delay_vars: Dict[str, bm.LengthDelay] = dict()
self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()

def __repr__(self):
return f'{self.__class__.__name__}(name={self.name})'
Expand Down Expand Up @@ -334,15 +334,17 @@ def reset(self):

@classmethod
def has(cls, **children_cls):
"""
"""The aggressive operation to gather master and children classes.
Parameters
----------
children_cls
The children classes.
Returns
-------
wrapper: ContainerWrapper
A wrapper which has master and its children classes.
"""
return ContainerWrapper(master=cls, **children_cls)

Expand Down
16 changes: 8 additions & 8 deletions brainpy/math/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def call(xs=None, length=None):
turn_off_global_jit()
except UnexpectedTracerError as e:
turn_off_global_jit()
for v, d in zip(dyn_vars, init_values): v.value = d
for v, d in zip(dyn_vars, init_values): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d
return tree_unflatten(tree, out_values), results

else:
Expand All @@ -189,7 +189,7 @@ def call(xs):
turn_off_global_jit()
for v, d in zip(dyn_vars, init_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d
return tree_unflatten(tree, out_values)

return call
Expand Down Expand Up @@ -271,7 +271,7 @@ def call(x=None):
turn_off_global_jit()
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d

return call

Expand Down Expand Up @@ -359,7 +359,7 @@ def call(pred, x=None):
turn_off_global_jit()
for v, d in zip(dyn_vars, old_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d
return res

else:
Expand Down Expand Up @@ -477,7 +477,7 @@ def _false_fun(op):
turn_off_global_jit()
for v, d in zip(dyn_vars, old_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d
else:
turn_on_global_jit()
res = lax.cond(pred, true_fun, false_fun, operands)
Expand Down Expand Up @@ -663,7 +663,7 @@ def fun2scan(dyn_vals, x):
turn_off_global_jit()
for v, d in zip(dyn_vars, init_vals): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_vals): v.value = d
for v, d in zip(dyn_vars, dyn_vals): v._value = d
return out_vals


Expand Down Expand Up @@ -729,4 +729,4 @@ def _cond_fun(op):
turn_off_global_jit()
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d
2 changes: 1 addition & 1 deletion brainpy/math/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def call(*args, **kwargs):
turn_off_global_jit()
for key, v in vars.items(): v._value = variable_data[key]
raise e
vars.assign(changes)
for key, v in vars.items(): v._value = changes[key]
return out

return change_func_name(name=f_name, f=call) if f_name else call
Expand Down
1 change: 1 addition & 0 deletions docs/apis/dyn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
auto/dyn/neurons
auto/dyn/synapses
auto/dyn/rates
auto/dyn/others
auto/dyn/runners
12 changes: 10 additions & 2 deletions docs/auto_generater.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ def generate_dyn_docs(path='apis/auto/dyn/'):
module_and_name = [
('biological_models', 'Biological Models'),
('fractional_models', 'Fractional-order Models'),
('input_models', 'Input Models'),
('reduced_models', 'Reduced Models'),
]
write_submodules(module_name='brainpy.dyn.neurons',
Expand All @@ -278,14 +277,23 @@ def generate_dyn_docs(path='apis/auto/dyn/'):
module_and_name = [
('populations', 'Population Models'),
('couplings', 'Coupling Models'),
('noises', 'Noise Models'),
]
write_submodules(module_name='brainpy.dyn.rates',
filename=os.path.join(path, 'rates.rst'),
header='Rate Models',
submodule_names=[a[0] for a in module_and_name],
section_names=[a[1] for a in module_and_name])

module_and_name = [
('noises', 'Noise Models'),
('inputs', 'Input Models'),
]
write_submodules(module_name='brainpy.dyn.others',
filename=os.path.join(path, 'others.rst'),
header='Helper Models',
submodule_names=[a[0] for a in module_and_name],
section_names=[a[1] for a in module_and_name])

write_module(module_name='brainpy.dyn.runners',
filename=os.path.join(path, 'runners.rst'),
header='Runners')
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ The code of BrainPy is open-sourced at GitHub:

tutorial_toolbox/ode_numerical_solvers
tutorial_toolbox/sde_numerical_solvers
tutorial_toolbox/dde_numerical_solvers
tutorial_toolbox/fde_numerical_solvers
tutorial_toolbox/dde_numerical_solvers
tutorial_toolbox/joint_equations
tutorial_toolbox/synaptic_connections
tutorial_toolbox/synaptic_weights
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorial_toolbox/fde_numerical_solvers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"source": [
"Factional differential equations have several definitions. It can be defined in a variety of different ways that do often do not all lead to the same result even for smooth functions. In neuroscience, we usually use the following two definitions:\n",
"\n",
"- Riemann–Liouville fractional derivative\n",
"- Grünwald-Letnikov derivative\n",
"- Caputo fractional derivative\n",
"\n",
"See [Fractional calculus - Wikipedia](https://en.wikipedia.org/wiki/Fractional_calculus) for more details."
Expand Down Expand Up @@ -421,7 +421,7 @@
{
"cell_type": "markdown",
"source": [
"## Methods for Riemann–Liouville FDEs"
"## Methods for Grünwald-Letnikov FDEs"
],
"metadata": {
"collapsed": false,
Expand Down
Loading

0 comments on commit 231b566

Please sign in to comment.