Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

basic interpeter for scaled ops and scaled arrays #8

Merged
merged 12 commits into from
Nov 8, 2023

Conversation

lyprince
Copy link
Contributor

@lyprince lyprince commented Nov 8, 2023

Basic interpreter inspired by the jax tutorial: https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html

Traces jaxpr of function without scaled arrays, then swaps traced primitives with ops in scaled_ops_registry.

Copy link
Contributor

@balancap balancap left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @lyprince Just a couple of comments on my side, to get even closer to the pattern used in JAX for registry + interpreter.
But other than that, nice stuff!

@@ -0,0 +1,52 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how many interpreters we will have! I would except a main one, with config. When we know more, we can maybe consolidate in a single interpreter.py file in core

from jax._src.util import safe_map

from ..core import ScaledArray
from ..lax import scaled_ops_registry
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To follow JAX way: I would actually declare the global dict scaled_ops_registry here, and import it in the lax module. Feels weird to do it the other way around.

And I would add a function register_scaled_op to mutate it, so we avoid importing a global variable everywhere (and at some point we can add additional checks)

def autoscale(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
unscaled_args = safe_map(lambda x: x.to_array() if hasattr(x, "to_array") else x, args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not blocking: it is probably sub-optimal to call to_array here. Just getting .aval abstract array should be enough for tracing. Then you avoid the potential cost of the allocation of arrays.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Meaning we need to add .aval property to ScaledArray


def read(var):
if type(var) is core.Literal:
return var.val
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not blocking: we may need to think how we want to handle literals and constants? I believe we should also convert them to ScaledArray


from .scaled_ops import scaled_mul

scaled_ops_registry = {lax.mul_p: scaled_mul}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above, I would move that to autoscale.py

return ScaledArray(A.data * B.data, A.scale * B.scale)


__all__ = ["scaled_mul"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the call to register_scaled_op(mul_p, scaled_mul) here. It's the pattern used everywhere in the JAX codebase

expected = jnp.array([1.0, 2.0])

scaled_inputs = ScaledArray(inputs, scale)
scaled_outputs = asfunc(scaled_inputs)[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to do [0]? Feels we are changing the semantic of the function here?


jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr

assert jaxpr.invars[0].aval.shape == inputs.shape
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe here a comment to explain why the jaxpr has doubled amount of inputs/outputs, due to scale tensor. Not that obvious if you're not a JAX expert!



class AutoScaleInterpreterTests(chex.TestCase):
def test__identity(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker: we can at some point use @chex.variants(with_jit=True, without_jit=True) to run the same unit test with jax.jit added or not. Helps having consistent checks that our logic works in eager/interpreter mode or jitting mode.

@balancap
Copy link
Contributor

balancap commented Nov 8, 2023

Thanks @lyprince, all looking good!

@lyprince lyprince merged commit 8352e80 into main Nov 8, 2023
2 checks passed
@lyprince lyprince deleted the scaledarray-basic-interpreter branch November 8, 2023 19:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants