Skip to content

Commit

Permalink
Merge pull request #15 from kirkegaardlab/fix_jax_deprecation
Browse files Browse the repository at this point in the history
fix: replace deprecated jax.numpy.trapz
  • Loading branch information
alonfnt authored Jun 1, 2024
2 parents 80c1c06 + 16a6c38 commit 484bc3b
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/integration_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ jobs:
steps:
- uses: actions/checkout@v2

- name: Set up Python 3.8
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: '3.8'
python-version: '3.10'
architecture: 'x64'

- name: apt-get
Expand Down
7 changes: 4 additions & 3 deletions celegans/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import jax
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid


def _theta(t, s, params):
Expand Down Expand Up @@ -123,9 +124,9 @@ def solve(t, u, X, ds, alpha):
fx = Ut * tx[jnp.newaxis] + alpha * Un * nx[jnp.newaxis]
fy = Ut * ty[jnp.newaxis] + alpha * Un * ny[jnp.newaxis]

Fx = jnp.trapz(fx, dx=ds)
Fy = jnp.trapz(fy, dx=ds)
Tau = jnp.trapz(x * fy - y * fx, dx=ds)
Fx = trapezoid(fx, dx=ds)
Fy = trapezoid(fy, dx=ds)
Tau = trapezoid(x * fy - y * fx, dx=ds)

b = -jnp.array([Fx[0], Fy[0], Tau[0]])
A = jnp.array([Fx[1:], Fy[1:], Tau[1:]])
Expand Down
6 changes: 6 additions & 0 deletions examples/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
import deeptangle as dt
import matplotlib.pyplot as plt
from skimage.exposure import equalize_adapthist
import numpy

# scikit-video uses deprecated numpy.float, numpy.int
# hacky fix: https://github.com/scikit-video/scikit-video/issues/154
numpy.float = numpy.float64
numpy.int = numpy.int_
import skvideo.io


Expand Down
6 changes: 6 additions & 0 deletions examples/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
import matplotlib.pyplot as plt
import numpy as np
from skimage.exposure import equalize_adapthist

# scikit-video uses deprecated numpy.float, numpy.int
# hacky fix: https://github.com/scikit-video/scikit-video/issues/154
import numpy
numpy.float = numpy.float64
numpy.int = numpy.int_
import skvideo.io

import deeptangle as dt
Expand Down
9 changes: 5 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ scikit-image
scikit-video
optax
chex
jax
jax>=0.4.16
jaxlib>=0.4.20
dm-pix
scikit-learn
numpy==1.21.6
numba==0.55
numpy>=1.21.6
numba>=0.55
matplotlib
https://github.com/alonfnt/dm-haiku/archive/refs/heads/avg_pool_perf.zip
dm-haiku
trackpy

0 comments on commit 484bc3b

Please sign in to comment.