-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve typing of jax.jit
#23720
base: main
Are you sure you want to change the base?
Improve typing of jax.jit
#23720
Conversation
Thanks for the contribution! See #14688 for a past attempt at this, the things that went wrong, and some discussion of the considerations around a change like this. In particular, PyType was a blocker in the past, and we'll have to check whether that's still the case. |
a79de59
to
a58adb8
Compare
Thanks @jakevdp for the prompt response. |
Yeah, newly discovered issues need to be fixed or silenced, because otherwise merging the PR will break our CI. Could you given an example of the issue you discovered, please? |
I just rebased the PR to only improve the typing of Here are the typing issues that the MyPy pre-commit hook now sees within the codebase, now that the signature of the callable isn't dropped by
What is the typical way that such typing errors are silenced in this project? Do you prefer |
a58adb8
to
49bd69c
Compare
I am pretty sure this will break a lot of other targets internally too making this change very difficult to land. |
Yeah, #14688 was eventually blocked by the fact that pytype doesn't properly suppot |
Some of the errors reflect that the
It basically comes from something that looks like this: @jit
def _lu_solve(x: Array):
...
def lu_solve(x: ArrayLike):
return _lu_solve(x) # <- type error, because ArrayLike is not Array However, when you wrap a function with What do you think? |
re: @superbobry, @yashk2810 - I silenced the new typing errors that mypy raised in the pre-commit hook. re: @jakevdp
I agree. In my view, this only encourages internal jax source to be more explicit, by not depending on this implicit conversion from ArrayLike to Array. |
re: @yashk2810
Are you saying that this tiny little PR would also improve other downstream projects at Google? 🤩 😛 |
Haha, I wouldn't say improve but it will break a lot of stuff and I don't think we have the bandwidth to fix all those projects. Hence landing this is very hard IRL. |
I think, setting aside caveats about how hard this might be to land, this is generally a change we want, and one we've been hoping to add for a long time. Initially we were blocked by the lack of The relevant issue is google/pytype#1471, which is still open. |
re: @yashk2810 re: @yashk2810 @jakevdp |
Just to be clear, the |
With respect to google/pytype#1471, would this change be easier to merge if the output TypeVar were not marked as covariant? 🤔 Edit: I'll double-check, but I think that the output var being invariant might cause other issues (which would then be due to the the annotation not being 100% correct) |
Yes, if this didn't use covariant typevars, it would be easier to merge. But my understanding from #14688 was that covariant typevars are required in order to correctly annotate |
@jakevdp pytype treats all type variables as covariant IIRC, so maybe we can just suppress the warning for that particular type var? |
Re @superbobry @jakevdp : |
I'll start with Flax, since that seems like the most obvious downstream project from my perspective. |
e3faefb
to
be5fef9
Compare
Pulling in to run internal pytype tests |
jax/_src/api.py
Outdated
@@ -151,7 +153,7 @@ def jit( | |||
backend: str | None = None, | |||
inline: bool = False, | |||
abstracted_axes: Any | None = None, | |||
) -> pjit.JitWrapped: | |||
) -> pjit.JitWrapped[_P, _OutT]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm seeing an error here in pytype that's blocking any further testing:
/jax/_src/api.py:156: error: in <module>: class JitWrapped is not indexable [not-indexable]
('JitWrapped' does not subclass Generic)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
JitWrapped
does in fact inherit from Generic
, so I suspect there's something deeper going on that's preventing pytype
from seeing that within the package structure of bazel-based builds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One potential workaround (although ugly) could be to make Wrapped
inherit directly from Generic
, in addition to Protocol
(which itself inherits from Generic
).
Should we test this out?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@superbobry do you have thoughts here?
One thing we could do is just ignore this error in PyType, and then within pytype jit
would essentially be treated as Any
until we fix the bazel build depth that prevents pytype from resolving this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think inheriting from Generic
and listing type parameters makes sense. This is always fine to do for a generic class.
Note sure about inheriting from Protocol
, because we don't want JitWrapped
to be a protocol itself (it does implement the stages.Wrapped
protocol, though).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I'll give this a shot. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I pushed a new version of this with JitWrapped
inheriting from Generic directly (in addition to the existing JitWrapped
-> Wrapped
-> Protocol
-> Generic
chain): 5d9f762
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh and btw @superbobry, I meant to say that JitWrapped already subclasses Generic through Protocol, not that we should make it subclass both Generic and Protocol, if that makes sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think @superbobry was saying that JitWrapped
should not inherit from Wrapped
, because Wrapped
is a protocol and JitWrapped
is a class that implements that protocol.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't say that, but I agree with @jakevdp in principle :)
5c11738
to
ee069b3
Compare
jax/_src/api.py
Outdated
@@ -137,9 +137,11 @@ def _update_debug_special_thread_local(_): | |||
|
|||
float0 = dtypes.float0 | |||
|
|||
_P = ParamSpec("_P") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reuse this from stages instead of redefining in 3 files.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Addressed in 32de804 (rebased to keep everything in a single commit)
jax/_src/pjit.py
Outdated
@@ -807,8 +807,10 @@ def ax_leaf(l): | |||
isinstance(l, tuple) and all_leaves(l, lambda x: x is None)) | |||
return broadcast_prefix(abstracted_axes, args, ax_leaf) | |||
|
|||
_P = ParamSpec("_P") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reuse this from stages instead of redefining in 3 files.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Addressed in 32de804 (rebased to keep everything in a single commit)
32de804
to
1e706c2
Compare
I'm curious: If this ends up getting merged, will the Google-ML-Automation bot include my github username in the final commit? Or would someone on the inside need to add a |
If this is merged, your actual unmodified commit would be added to the JAX source tree. |
- Fix for jax-ml#23719 Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
1e706c2
to
5d9f762
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving to test on the internal CI.
HI @superbobry , would it help if I rebased this off of the most recent version of master? |
Apologies for the silence. We can try, but it looks like pytype might need some fixing to accept the declarations in this PR. |
jax.jit
#23719