Basic types support?
#5877
I'm running:

    # main.py
    from jax import numpy as jnp

    def _a() -> jnp.ndarray:
        return jnp.array([1, 2, 3])

    a = _a()

and getting the following errors:

    main.py:4: error: Name 'jnp.ndarray' is not defined
    main.py:5: error: Module has no attribute "array"; maybe "ndarray" or "asarray"?
    Found 2 errors in 1 file (checked 1 source file)

I saw that #5275 added a py.typed file last month. Should I drop the --strict flag? Thanks.
Answered by jakevdp, Feb 28, 2021
Answer selected by bhchiang
Typing support is still a work in progress: most JAX API functions (including jnp.array) do not yet specify return types. Until JAX supports typing more completely, type-checking code like your example will not work.
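If you want to check whether a particular function declares a return type (the gap described above), you can inspect its signature at runtime. The following is a minimal sketch using only the standard-library inspect module. The helper name declares_return_type and the two sample functions are hypothetical stand-ins, not part of JAX. Runtime annotations are only a rough proxy for what mypy sees, since a package can also ship separate stub files.

```python
import inspect
from typing import Any, Callable


def declares_return_type(func: Callable[..., Any]) -> bool:
    """Return True if func's signature carries a return annotation."""
    # inspect.signature follows __wrapped__, so functions decorated with
    # functools.wraps report the signature of the function they wrap.
    return inspect.signature(func).return_annotation is not inspect.Signature.empty


# Hypothetical stand-ins: one annotated function and one unannotated one
# (the latter mirrors the unannotated API functions described above).
def annotated(x: int) -> int:
    return x * 2


def unannotated(x):
    return x * 2


print(declares_return_type(annotated))    # True
print(declares_return_type(unannotated))  # False
```

With JAX installed, calling declares_return_type(jnp.array) reports whether your installed version annotates jnp.array. The answer depends on the JAX release.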