Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pallas GPU] Disable implicit type conversion during type matching #23193

Merged
merged 1 commit into from
Aug 29, 2024

Conversation

ayaka14732
Copy link
Collaborator

This fixes #23191

@ayaka14732 ayaka14732 self-assigned this Aug 22, 2024
@ayaka14732 ayaka14732 added the pull ready Ready for copybara import and testing label Aug 22, 2024
@ayaka14732 ayaka14732 marked this pull request as draft August 27, 2024 21:54
@ayaka14732 ayaka14732 removed the pull ready Ready for copybara import and testing label Aug 27, 2024
@ayaka14732 ayaka14732 force-pushed the pallas-gpu-type-conversion branch 3 times, most recently from b3615a9 to 252fc07 Compare August 28, 2024 22:51
@ayaka14732 ayaka14732 added the pull ready Ready for copybara import and testing label Aug 28, 2024
@ayaka14732 ayaka14732 marked this pull request as ready for review August 28, 2024 23:39
@ayaka14732 ayaka14732 added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Aug 28, 2024
@ayaka14732
Copy link
Collaborator Author

The added test case fails in internal CI as expected:

Traceback (most recent call last):
  File "/build/work/c743021dcb5cf3119cd6c9148fe8f544061a/google3/runfiles/google3/third_party/py/jax/tests/pallas/ops_test.py", [line 751](https://cs.corp.google.com/piper///depot/google3/third_party/py/jax/tests/pallas/ops_test.py?l=751&ws=tap-presubmit-server/238884448&snapshot=2), in test_abs_weak_type
    np.testing.assert_allclose(kernel(x), jnp.abs(x), rtol=1e-6)
  File "/build/work/c743021dcb5cf3119cd6c9148fe8f544061a/google3/runfiles/google3/third_party/py/numpy/testing/_private/utils.py", [line 1504](https://cs.corp.google.com/piper///depot/google3/third_party/py/numpy/testing/_private/utils.py?l=1504&ws=tap-presubmit-server/238884448&snapshot=2), in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/<embedded stdlib>/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/build/work/c743021dcb5cf3119cd6c9148fe8f544061a/google3/runfiles/google3/third_party/py/numpy/testing/_private/utils.py", [line 797](https://cs.corp.google.com/piper///depot/google3/third_party/py/numpy/testing/_private/utils.py?l=797&ws=tap-presubmit-server/238884448&snapshot=2), in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-06, atol=0

Mismatched elements: 16 / 16 (100%)
Max absolute difference: 3.2
Max relative difference: 1.
 x: array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
 y: array([[3.2, 3.2, 3.2, 3.2],
       [3.2, 3.2, 3.2, 3.2],
       [3.2, 3.2, 3.2, 3.2],
       [3.2, 3.2, 3.2, 3.2]], dtype=float32)

@copybara-service copybara-service bot merged commit 48a9159 into jax-ml:main Aug 29, 2024
13 of 15 checks passed
@ayaka14732 ayaka14732 deleted the pallas-gpu-type-conversion branch August 29, 2024 00:23
copybara-service bot pushed a commit that referenced this pull request Aug 29, 2024
Previously,

* weakly typed avals matched the wrong externs;
* this was addressed by #23193, which disallowed weakly typed avals entirely.

Here we check if a weakly typed aval can be casted to the extern input dtype
when selecting an extern.

PiperOrigin-RevId: 669067725
copybara-service bot pushed a commit that referenced this pull request Aug 30, 2024
Previously,

* weakly typed avals matched the wrong externs;
* this was addressed by #23193, which disallowed weakly typed avals entirely.

Here we check if a weakly typed aval can be casted to the extern input dtype
when selecting an extern.

PiperOrigin-RevId: 669067725
copybara-service bot pushed a commit that referenced this pull request Aug 30, 2024
Previously,

* weakly typed avals matched the wrong externs;
* this was addressed by #23193, which disallowed weakly typed avals entirely.

Here we check if a weakly typed aval can be casted to the extern input dtype
when selecting an extern.

PiperOrigin-RevId: 669067725
copybara-service bot pushed a commit that referenced this pull request Aug 30, 2024
Previously,

* weakly typed avals matched the wrong externs;
* this was addressed by #23193, which disallowed weakly typed avals entirely.

Here we check if a weakly typed aval can be casted to the extern input dtype
when selecting an extern.

PiperOrigin-RevId: 669067725
copybara-service bot pushed a commit that referenced this pull request Aug 30, 2024
Previously,

* weakly typed avals matched the wrong externs;
* this was addressed by #23193, which disallowed weakly typed avals entirely.

Here we check if a weakly typed aval can be casted to the extern input dtype
when selecting an extern.

PiperOrigin-RevId: 669378582
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Pallas GPU] jnp.abs() gives the wrong result
3 participants