From bd16f9b993e412a359e649eee28c6add64ac7ffe Mon Sep 17 00:00:00 2001
From: Paul Balanca
Date: Fri, 9 Feb 2024 16:54:09 +0000
Subject: [PATCH] Improve MyPy type annotations.

---
 jax_scaled_arithmetics/core/datatype.py | 2 +-
 jax_scaled_arithmetics/core/debug.py | 8 ++++----
 jax_scaled_arithmetics/core/interpreters.py | 8 +++++---
 jax_scaled_arithmetics/core/pow2.py | 2 +-
 jax_scaled_arithmetics/lax/base_scaling_primitives.py | 4 ++--
 jax_scaled_arithmetics/lax/scaled_ops_common.py | 2 +-
 jax_scaled_arithmetics/lax/scaled_ops_l2.py | 4 ++--
 jax_scaled_arithmetics/ops/debug.py | 7 ++++---
 pyproject.toml | 2 +-
 9 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py
index 23bcb5c..bd2e056 100644
--- a/jax_scaled_arithmetics/core/datatype.py
+++ b/jax_scaled_arithmetics/core/datatype.py
@@ -97,7 +97,7 @@ def aval(self) -> ShapedArray:
         """Abstract value of the scaled array, i.e. shape and dtype."""
         return ShapedArray(self.data.shape, self.data.dtype)
 
-    def astype(self, dtype) -> "ScaledArray":
+    def astype(self, dtype: DTypeLike) -> "ScaledArray":
         """Convert the ScaledArray to a dtype.
         NOTE: only impacting `data` field, not the `scale` tensor.
         """
diff --git a/jax_scaled_arithmetics/core/debug.py b/jax_scaled_arithmetics/core/debug.py
index c06f7a0..6807557 100644
--- a/jax_scaled_arithmetics/core/debug.py
+++ b/jax_scaled_arithmetics/core/debug.py
@@ -1,11 +1,11 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
-from typing import Any, Callable
+from typing import Any, Callable, Dict
 
 from jax import tree_util
 from jax._src.debugging import debug_callback as debug_callback_orig
 from jax._src.debugging import debug_callback_p
 
-from .interpreters import register_scaled_op
+from .interpreters import ScaledArray, register_scaled_op
 
 
 def get_debug_callback_effect(ordered: bool) -> Any:
@@ -42,7 +42,7 @@ def _flat_callback(*flat_args):
 debug_callback.__doc__ = debug_callback_orig.__doc__
 
 
-def scaled_debug_callback(*args, **params) -> Any:
+def scaled_debug_callback(*args: ScaledArray, **params: Dict[str, Any]) -> Any:
     """Scaled `debug_callback`: properly forwarding ScaledArrays
     to host callback.
     """
@@ -50,7 +50,7 @@ def scaled_debug_callback(*args, **params) -> Any:
     if not hasattr(flat_callback_fn, "__callback_fn"):
         raise NotImplementedError("Please use `jsa.debug_callback` function instead of original JAX function.")
     callback_fn = flat_callback_fn.__callback_fn
-    in_pytree = flat_callback_fn.__callback_in_tree
+    in_pytree = flat_callback_fn.__callback_in_tree  # type:ignore
     # Re-build original input, with scaled arrays.
     scaled_args, scaled_kwargs = tree_util.tree_unflatten(in_pytree, args)
     # Re-build ordered boolean, in a backward compatible way.
diff --git a/jax_scaled_arithmetics/core/interpreters.py b/jax_scaled_arithmetics/core/interpreters.py
index 84e414a..acf15ce 100644
--- a/jax_scaled_arithmetics/core/interpreters.py
+++ b/jax_scaled_arithmetics/core/interpreters.py
@@ -345,20 +345,22 @@ def jaxpr_eqn_bind(eqn: core.JaxprEqn, invals: Sequence[core.ShapedArray]) -> Se
     return outvals
 
 
-def autoscale_jaxpr(jaxpr: core.Jaxpr, consts: Sequence[ScalifyTracerArray], *args: ScalifyTracerArray):
+def autoscale_jaxpr(
+    jaxpr: core.Jaxpr, consts: Sequence[ScalifyTracerArray], *args: ScalifyTracerArray
+) -> Sequence[ScalifyTracerArray]:
     env: Dict[core.Var, ScalifyTracerArray] = {}
     # Check dtype consistency between normal and scaled modes.
     safe_check_dtypes: bool = False
     # AutoScale config to use.
     autoscale_cfg = get_autoscale_config()
 
-    def read(var) -> ScalifyTracerArray:
+    def read(var: core.Var) -> ScalifyTracerArray:
         if type(var) is core.Literal:
             # Wrap the constant in tracer array.
             return ScalifyTracerArray(var.val)
         return env[var]
 
-    def write(var, val: ScalifyTracerArray):
+    def write(var: core.Var, val: ScalifyTracerArray) -> None:
         env[var] = val
 
     # A few initial checks to make sure there is consistency.
diff --git a/jax_scaled_arithmetics/core/pow2.py b/jax_scaled_arithmetics/core/pow2.py
index 0000450..a9f1edc 100644
--- a/jax_scaled_arithmetics/core/pow2.py
+++ b/jax_scaled_arithmetics/core/pow2.py
@@ -108,7 +108,7 @@ def pow2_decompose_abstract_eval(
 
 
 def pow2_decompose_mlir_lowering(
-    ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params
+    ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params: Dict[str, Any]
 ) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
     scale_dtype = params["scale_dtype"]
     mode = params["mode"]
diff --git a/jax_scaled_arithmetics/lax/base_scaling_primitives.py b/jax_scaled_arithmetics/lax/base_scaling_primitives.py
index defa590..e564640 100644
--- a/jax_scaled_arithmetics/lax/base_scaling_primitives.py
+++ b/jax_scaled_arithmetics/lax/base_scaling_primitives.py
@@ -1,6 +1,6 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 import logging
-from typing import Optional, Sequence, Union
+from typing import Any, Dict, Optional, Sequence, Union
 
 import numpy as np
 from jax import core
@@ -127,7 +127,7 @@ def stop_scaling_abstract_eval(values: core.ShapedArray, dtype: Optional[DTypeLi
 
 
 def stop_scaling_mlir_lowering(
-    ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params
+    ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params: Dict[str, Any]
 ) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
     dtype = params.get("dtype", None)
     if dtype is not None:
diff --git a/jax_scaled_arithmetics/lax/scaled_ops_common.py b/jax_scaled_arithmetics/lax/scaled_ops_common.py
index bd5ad32..49799fb 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops_common.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops_common.py
@@ -28,7 +28,7 @@ def _get_data(val: Any) -> Array:
     return val
 
 
-def check_scalar_scales(*args: ScaledArray):
+def check_scalar_scales(*args: ScaledArray) -> Array:
     """Check all ScaledArrays have scalar scaling."""
     for val in args:
         assert np.ndim(val.scale) == 0
diff --git a/jax_scaled_arithmetics/lax/scaled_ops_l2.py b/jax_scaled_arithmetics/lax/scaled_ops_l2.py
index 5d1bda1..ca92247 100644
--- a/jax_scaled_arithmetics/lax/scaled_ops_l2.py
+++ b/jax_scaled_arithmetics/lax/scaled_ops_l2.py
@@ -1,5 +1,5 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import jax.numpy as jnp
 import numpy as np
@@ -96,7 +96,7 @@ def scaled_dot_general(
 
 
 @core.register_scaled_lax_op
-def scaled_conv_general_dilated(lhs: ScaledArray, rhs: ScaledArray, **params) -> ScaledArray:
+def scaled_conv_general_dilated(lhs: ScaledArray, rhs: ScaledArray, **params: Dict[str, Any]) -> ScaledArray:
     assert isinstance(lhs, ScaledArray)
     assert isinstance(rhs, ScaledArray)
     data = lax.conv_general_dilated_p.bind(lhs.data, rhs.data, **params)
diff --git a/jax_scaled_arithmetics/ops/debug.py b/jax_scaled_arithmetics/ops/debug.py
index ba6dbd4..1db09ab 100644
--- a/jax_scaled_arithmetics/ops/debug.py
+++ b/jax_scaled_arithmetics/ops/debug.py
@@ -1,9 +1,10 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
 from functools import partial
+from typing import Sequence
 
 import jax
 
-from jax_scaled_arithmetics.core import debug_callback
+from jax_scaled_arithmetics.core import Array, debug_callback
 
 
 @partial(jax.custom_vjp, nondiff_argnums=(0,))
@@ -24,12 +25,12 @@ def debug_callback_grad_bwd(f, _, args_grad):
 debug_callback_grad.defvjp(debug_callback_grad_fwd, debug_callback_grad_bwd)
 
 
-def debug_print(fmt: str, *args):
+def debug_print(fmt: str, *args: Array) -> Sequence[Array]:
     """Debug print of a collection of tensors."""
     debug_callback(lambda *args: print(fmt.format(*args)), *args)
     return args
 
 
-def debug_print_grad(fmt: str, *args):
+def debug_print_grad(fmt: str, *args: Array) -> Sequence[Array]:
     """Debug print of gradients of a collection of tensors."""
     return debug_callback_grad(lambda *args: print(fmt.format(*args)), *args)
diff --git a/pyproject.toml b/pyproject.toml
index 2e859eb..9b171d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,7 +63,7 @@ warn_unused_configs = true
 check_untyped_defs = true
 disallow_any_generics = true
 no_implicit_optional = false
-# disallow_incomplete_defs = true
+disallow_incomplete_defs = true
 # disallow_untyped_decorators = true
 # disallow_untyped_calls = true
 # # disallow_subclassing_any = true
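
Note (reviewer sketch, not part of the commit above): the signatures in this patch rely on the PEP 484 convention for variadic parameters, where `*args: T` types each positional argument as T (so `args` is seen as Tuple[T, ...]) and `**params: T` types each keyword value as T (so `params` is seen as Dict[str, T]). The standalone snippet below illustrates that convention and the effect of turning on `disallow_incomplete_defs`; the `Scaled` class, the function names, and the file name are hypothetical placeholders, not identifiers from jax_scaled_arithmetics.

# sketch.py: minimal, self-contained illustration.
# Check with `mypy --disallow-incomplete-defs sketch.py`.
from typing import Any, Dict, Tuple


class Scaled:
    """Hypothetical stand-in for a scaled tensor type."""


def forward_scaled(*args: Scaled, **params: Any) -> Tuple[Scaled, ...]:
    # Inside the body, mypy sees `args` as Tuple[Scaled, ...] and `params` as Dict[str, Any].
    return args


def forward_dict_params(**params: Dict[str, Any]) -> None:
    # `**params: Dict[str, Any]` types each keyword *value* as Dict[str, Any],
    # i.e. `params` itself is seen as Dict[str, Dict[str, Any]].
    for name, value in params.items():
        print(name, value)


def partially_annotated(x: int, y):
    # With `disallow_incomplete_defs = true`, mypy flags this signature as incomplete
    # (no annotation for `y` and no return type), which is what the pyproject.toml
    # change above now enforces project-wide.
    return x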