Skip to content

Commit

Permalink
Adapt documentation to inform that the specification of return_shapes…
Browse files Browse the repository at this point in the history
… isn't necerry anymore and adapt the getting started example.ipynb
  • Loading branch information
jdehning committed Sep 25, 2024
1 parent 28f9036 commit c5caed0
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 61 deletions.
165 changes: 122 additions & 43 deletions docs/example.ipynb

Large diffs are not rendered by default.

15 changes: 4 additions & 11 deletions icomo/comp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,17 +525,11 @@ def get_op(
more details. The return value of the function has to be a
pytree/list/dict with the same structure as `y`.
return_shapes : tuple of tuples, default is ((),)
The shapes (except the time dimension) of the variables of the system of
differential equations that are returned by the integrator. If
`list_keys_to_return` is `None`, the shapes have to be given in the same
order as the variables are returned by the integrator. If
`list_keys_to_return` is not `None`, the shapes have to be given in the
same order as the keys in `list_keys_to_return`. The default `((), )` means
a single variable with only a time dimension is returned.
Depreceated, the return shape had to be specified before, now it is inferred
automatically. This argument isn't used anymore.
list_keys_to_return : list of str or None, default is None
The keys of the variables of the system of differential equations that will
be chosen to be returned by the integrator. Necessary if the ODE returns a
`dict`, as :mod:`pytensor` only accepts single outputs or a list of outputs.
be chosen to be returned by the integrator.
If `None`, the output is returned as is.
name :
The name under which the operator is registered in pymc.
Expand All @@ -558,8 +552,7 @@ def get_op(
)

pytensor_op = jax2pytensor(
integrator,
output_shape_def=output_shape_def,
integrator, output_shape_def=output_shape_def, name=name
)

return pytensor_op
Expand Down
2 changes: 0 additions & 2 deletions icomo/jax2pytensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ def new_func(inputs_list_flat):
input_types_dic = {
arg: shapes for arg, shapes in zip(inputnames_list, input_shapes_list)
}
print(input_types_dic)
output_shape = output_shape_def(**input_types_dic)
print(output_shape)

# For flattening the output shapes, we need to redefine what is a leaf, so
# that the shape tuples don't get also flattened.
Expand Down
8 changes: 3 additions & 5 deletions tests/test_example_notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,16 +277,14 @@ def test_bayes():

SEIR_integrator_op = integrator_object_bayes.get_op(
Erlang_SEIR,
return_shapes=[() for _ in range(2)],
list_keys_to_return=["S", "I"],
)

S, I = SEIR_integrator_op(
output = SEIR_integrator_op(
y0=y0_var, arg_t=beta_t_var, constant_args=const_args_var
)

pm.Deterministic("I", I)
new_cases = -pt.diff(S)
pm.Deterministic("I", output["I"])
new_cases = -pt.diff(output["S"])
pm.Deterministic("new_cases", new_cases)

sigma_error = pm.HalfCauchy("sigma_error", beta=1)
Expand Down

0 comments on commit c5caed0

Please sign in to comment.