Skip to content

Commit

Permalink
set jax to <0.4.20
Browse files Browse the repository at this point in the history
  • Loading branch information
LegrandNico committed Nov 6, 2023
1 parent 5829ad4 commit d80538f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 61 deletions.
100 changes: 46 additions & 54 deletions docs/source/notebooks/1.1-Binary_HGF.ipynb

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions docs/source/notebooks/1.1-Binary_HGF.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,7 @@ In the previous section, we assumed we knew the parameters of the HGF models tha

Because the HGF classes are built on the top of [JAX](https://github.com/google/jax), they are natively differentiable and compatible with optimisation libraries or can be embedded as regular distributions in the context of a Bayesian network. Here, we are using this approach, and we are going to use [PyMC](https://www.pymc.io/welcome.html) to perform this step. PyMC can use any log probability function (here the negative surprise of the model) as a building block for a new distribution by wrapping it in its underlying tensor library [Aesara](https://aesara.readthedocs.io/en/latest/), now forked as [PyTensor](https://pytensor.readthedocs.io/en/latest/). This PyMC-compatible distribution can be found in the {py:obj}`pyhgf.distribution` sub-module.

```{code-cell} ipython3
```
+++

### Two-level model
#### Creating the model
Expand Down Expand Up @@ -230,7 +228,7 @@ pm.model_to_graphviz(two_levels_binary_hgf)

```{code-cell} ipython3
with two_levels_binary_hgf:
two_level_hgf_idata = pm.sample(chains=2, cores=1)
two_level_hgf_idata = pm.sample(chains=2)
```

```{code-cell} ipython3
Expand Down Expand Up @@ -320,7 +318,7 @@ slideshow:
slide_type: ''
---
with three_levels_binary_hgf:
three_level_hgf_idata = pm.sample(chains=2, cores=1)
three_level_hgf_idata = pm.sample(chains=2)
```

```{code-cell} ipython3
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ matplotlib>=3.0.2
seaborn>=0.9.0
arviz>=0.12.0
pymc>=5.0.0
jax>=0.4.1
jaxlib>=0.4.1
jax>=0.4.1, <0.4.20
jaxlib>=0.4.1, <0.4.20
setuptools>=38.4
packaging

0 comments on commit d80538f

Please sign in to comment.