Skip to content

Commit

Permalink
Improve MyPy type code annotation. (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap authored Feb 9, 2024
1 parent 94c44c8 commit 5e5ccc0
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 18 deletions.
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def aval(self) -> ShapedArray:
"""Abstract value of the scaled array, i.e. shape and dtype."""
return ShapedArray(self.data.shape, self.data.dtype)

def astype(self, dtype) -> "ScaledArray":
def astype(self, dtype: DTypeLike) -> "ScaledArray":
"""Convert the ScaledArray to a dtype.
NOTE: only impacting `data` field, not the `scale` tensor.
"""
Expand Down
8 changes: 4 additions & 4 deletions jax_scaled_arithmetics/core/debug.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Callable
from typing import Any, Callable, Dict

from jax import tree_util
from jax._src.debugging import debug_callback as debug_callback_orig
from jax._src.debugging import debug_callback_p

from .interpreters import register_scaled_op
from .interpreters import ScaledArray, register_scaled_op


def get_debug_callback_effect(ordered: bool) -> Any:
Expand Down Expand Up @@ -42,15 +42,15 @@ def _flat_callback(*flat_args):
debug_callback.__doc__ = debug_callback_orig.__doc__


def scaled_debug_callback(*args, **params) -> Any:
def scaled_debug_callback(*args: ScaledArray, **params: Dict[str, Any]) -> Any:
"""Scaled `debug_callback`: properly forwarding ScaledArrays
to host callback.
"""
flat_callback_fn = params["callback"]
if not hasattr(flat_callback_fn, "__callback_fn"):
raise NotImplementedError("Please use `jsa.debug_callback` function instead of original JAX function.")
callback_fn = flat_callback_fn.__callback_fn
in_pytree = flat_callback_fn.__callback_in_tree
in_pytree = flat_callback_fn.__callback_in_tree # type:ignore
# Re-build original input, with scaled arrays.
scaled_args, scaled_kwargs = tree_util.tree_unflatten(in_pytree, args)
# Re-build ordered boolean, in a backward compatible way.
Expand Down
8 changes: 5 additions & 3 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,20 +345,22 @@ def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Se
return outvals


def autoscale_jaxpr(jaxpr: core.Jaxpr, consts: Sequence[ScalifyTracerArray], *args: ScalifyTracerArray):
def autoscale_jaxpr(
jaxpr: core.Jaxpr, consts: Sequence[ScalifyTracerArray], *args: ScalifyTracerArray
) -> Sequence[ScalifyTracerArray]:
env: Dict[core.Var, ScalifyTracerArray] = {}
# Check dtype consistency between normal and scaled modes.
safe_check_dtypes: bool = False
# AutoScale config to use.
autoscale_cfg = get_autoscale_config()

def read(var) -> ScalifyTracerArray:
def read(var: core.Var) -> ScalifyTracerArray:
if type(var) is core.Literal:
# Wrap the constant in tracer array.
return ScalifyTracerArray(var.val)
return env[var]

def write(var, val: ScalifyTracerArray):
def write(var: core.Var, val: ScalifyTracerArray) -> None:
env[var] = val

# A few initial checks to make sure there is consistency.
Expand Down
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/pow2.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def pow2_decompose_abstract_eval(


def pow2_decompose_mlir_lowering(
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params: Dict[str, Any]
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
scale_dtype = params["scale_dtype"]
mode = params["mode"]
Expand Down
4 changes: 2 additions & 2 deletions jax_scaled_arithmetics/lax/base_scaling_primitives.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import logging
from typing import Optional, Sequence, Union
from typing import Any, Dict, Optional, Sequence, Union

import numpy as np
from jax import core
Expand Down Expand Up @@ -127,7 +127,7 @@ def stop_scaling_abstract_eval(values: core.ShapedArray, dtype: Optional[DTypeLi


def stop_scaling_mlir_lowering(
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params
ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params: Dict[str, Any]
) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
dtype = params.get("dtype", None)
if dtype is not None:
Expand Down
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/lax/scaled_ops_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _get_data(val: Any) -> Array:
return val


def check_scalar_scales(*args: ScaledArray):
def check_scalar_scales(*args: ScaledArray) -> Array:
"""Check all ScaledArrays have scalar scaling."""
for val in args:
assert np.ndim(val.scale) == 0
Expand Down
4 changes: 2 additions & 2 deletions jax_scaled_arithmetics/lax/scaled_ops_l2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple

import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -96,7 +96,7 @@ def scaled_dot_general(


@core.register_scaled_lax_op
def scaled_conv_general_dilated(lhs: ScaledArray, rhs: ScaledArray, **params) -> ScaledArray:
def scaled_conv_general_dilated(lhs: ScaledArray, rhs: ScaledArray, **params: Dict[str, Any]) -> ScaledArray:
assert isinstance(lhs, ScaledArray)
assert isinstance(rhs, ScaledArray)
data = lax.conv_general_dilated_p.bind(lhs.data, rhs.data, **params)
Expand Down
7 changes: 4 additions & 3 deletions jax_scaled_arithmetics/ops/debug.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import partial
from typing import Sequence

import jax

from jax_scaled_arithmetics.core import debug_callback
from jax_scaled_arithmetics.core import Array, debug_callback


@partial(jax.custom_vjp, nondiff_argnums=(0,))
Expand All @@ -24,12 +25,12 @@ def debug_callback_grad_bwd(f, _, args_grad):
debug_callback_grad.defvjp(debug_callback_grad_fwd, debug_callback_grad_bwd)


def debug_print(fmt: str, *args):
def debug_print(fmt: str, *args: Array) -> Sequence[Array]:
"""Debug print of a collection of tensors."""
debug_callback(lambda *args: print(fmt.format(*args)), *args)
return args


def debug_print_grad(fmt: str, *args):
def debug_print_grad(fmt: str, *args: Array) -> Sequence[Array]:
"""Debug print of gradients of a collection of tensors."""
return debug_callback_grad(lambda *args: print(fmt.format(*args)), *args)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ warn_unused_configs = true
check_untyped_defs = true
disallow_any_generics = true
no_implicit_optional = false
# disallow_incomplete_defs = true
disallow_incomplete_defs = true
# disallow_untyped_decorators = true
# disallow_untyped_calls = true
# # disallow_subclassing_any = true
Expand Down

0 comments on commit 5e5ccc0

Please sign in to comment.