Skip to content

Commit

Permalink
improve oo-to-function transformation speed
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed May 14, 2022
1 parent 8ef3a56 commit b23fb4c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
16 changes: 8 additions & 8 deletions brainpy/math/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def call(xs=None, length=None):
turn_off_global_jit()
except UnexpectedTracerError as e:
turn_off_global_jit()
for v, d in zip(dyn_vars, init_values): v.value = d
for v, d in zip(dyn_vars, init_values): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d
return tree_unflatten(tree, out_values), results

else:
Expand All @@ -189,7 +189,7 @@ def call(xs):
turn_off_global_jit()
for v, d in zip(dyn_vars, init_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d
return tree_unflatten(tree, out_values)

return call
Expand Down Expand Up @@ -271,7 +271,7 @@ def call(x=None):
turn_off_global_jit()
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d

return call

Expand Down Expand Up @@ -359,7 +359,7 @@ def call(pred, x=None):
turn_off_global_jit()
for v, d in zip(dyn_vars, old_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d
return res

else:
Expand Down Expand Up @@ -477,7 +477,7 @@ def _false_fun(op):
turn_off_global_jit()
for v, d in zip(dyn_vars, old_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d
else:
turn_on_global_jit()
res = lax.cond(pred, true_fun, false_fun, operands)
Expand Down Expand Up @@ -663,7 +663,7 @@ def fun2scan(dyn_vals, x):
turn_off_global_jit()
for v, d in zip(dyn_vars, init_vals): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_vals): v.value = d
for v, d in zip(dyn_vars, dyn_vals): v._value = d
return out_vals


Expand Down Expand Up @@ -729,4 +729,4 @@ def _cond_fun(op):
turn_off_global_jit()
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v.value = d
for v, d in zip(dyn_vars, dyn_values): v._value = d
2 changes: 1 addition & 1 deletion brainpy/math/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def call(*args, **kwargs):
turn_off_global_jit()
for key, v in vars.items(): v._value = variable_data[key]
raise e
vars.assign(changes)
for key, v in vars.items(): v._value = changes[key]
return out

return change_func_name(name=f_name, f=call) if f_name else call
Expand Down

0 comments on commit b23fb4c

Please sign in to comment.