diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 52b08cd..f7c4c37 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort args: [--profile, black] @@ -23,16 +23,16 @@ repos: - id: pyupgrade args: [--py38-plus] - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 + rev: 7.0.0 hooks: - id: flake8 args: ['--ignore=E501,E203,E731,W503'] - repo: https://github.com/psf/black - rev: 23.11.0 + rev: 24.1.1 hooks: - id: black - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 + rev: v1.8.0 hooks: - id: mypy additional_dependencies: [types-dataclasses, numpy] diff --git a/jax_scaled_arithmetics/core/datatype.py b/jax_scaled_arithmetics/core/datatype.py index 8d82693..23bcb5c 100644 --- a/jax_scaled_arithmetics/core/datatype.py +++ b/jax_scaled_arithmetics/core/datatype.py @@ -13,7 +13,7 @@ from .pow2 import Pow2RoundMode, pow2_decompose from .typing import Array, ArrayTypes -GenericArray = Union[Array, np.ndarray] +GenericArray = Union[Array, np.ndarray[Any, Any]] @register_pytree_node_class