-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
basic interpeter for scaled ops and scaled arrays #8
Conversation
…metics into scaledarray-basic-interpreter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @lyprince Just a couple of comments on my side, to get even closer to the pattern used in JAX for registry + interpreter.
But other than that, nice stuff!
@@ -0,0 +1,52 @@ | |||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure how many interpreters we will have! I would except a main one, with config. When we know more, we can maybe consolidate in a single interpreter.py
file in core
from jax._src.util import safe_map | ||
|
||
from ..core import ScaledArray | ||
from ..lax import scaled_ops_registry |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To follow JAX way: I would actually declare the global dict
scaled_ops_registry
here, and import it in the lax
module. Feels weird to do it the other way around.
And I would add a function register_scaled_op
to mutate it, so we avoid importing a global variable everywhere (and at some point we can add additional checks)
def autoscale(fun): | ||
@wraps(fun) | ||
def wrapped(*args, **kwargs): | ||
unscaled_args = safe_map(lambda x: x.to_array() if hasattr(x, "to_array") else x, args) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not blocking: it is probably sub-optimal to call to_array
here. Just getting .aval
abstract array should be enough for tracing. Then you avoid the potential cost of the allocation of arrays.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Meaning we need to add .aval
property to ScaledArray
|
||
def read(var): | ||
if type(var) is core.Literal: | ||
return var.val |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not blocking: we may need to think how we want to handle literals and constants? I believe we should also convert them to ScaledArray
|
||
from .scaled_ops import scaled_mul | ||
|
||
scaled_ops_registry = {lax.mul_p: scaled_mul} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see above, I would move that to autoscale.py
return ScaledArray(A.data * B.data, A.scale * B.scale) | ||
|
||
|
||
__all__ = ["scaled_mul"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add the call to register_scaled_op(mul_p, scaled_mul)
here. It's the pattern used everywhere in the JAX codebase
expected = jnp.array([1.0, 2.0]) | ||
|
||
scaled_inputs = ScaledArray(inputs, scale) | ||
scaled_outputs = asfunc(scaled_inputs)[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to do [0]
? Feels we are changing the semantic of the function here?
|
||
jaxpr = jax.make_jaxpr(asfunc)(scaled_inputs).jaxpr | ||
|
||
assert jaxpr.invars[0].aval.shape == inputs.shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe here a comment to explain why the jaxpr
has doubled amount of inputs/outputs, due to scale tensor. Not that obvious if you're not a JAX expert!
|
||
|
||
class AutoScaleInterpreterTests(chex.TestCase): | ||
def test__identity(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a blocker: we can at some point use @chex.variants(with_jit=True, without_jit=True)
to run the same unit test with jax.jit
added or not. Helps having consistent checks that our logic works in eager/interpreter mode or jitting mode.
Thanks @lyprince, all looking good! |
Basic interpreter inspired by the jax tutorial: https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
Traces jaxpr of function without scaled arrays, then swaps traced primitives with ops in scaled_ops_registry.