From 090c3a2488297f6bd5f78aa36a7bd61e4fa54fcd Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 25 Sep 2024 14:17:26 +0200 Subject: [PATCH] Updated the `make_array()` function. Before the `order` argument was a `Literal` but this caused more truble now it is a string. --- .../primitive_translators/test_primitive_reshape.py | 3 +-- tests/util.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/integration_tests/primitive_translators/test_primitive_reshape.py b/tests/integration_tests/primitive_translators/test_primitive_reshape.py index bd1d6ab..4a504c5 100644 --- a/tests/integration_tests/primitive_translators/test_primitive_reshape.py +++ b/tests/integration_tests/primitive_translators/test_primitive_reshape.py @@ -25,8 +25,7 @@ def _test_impl_reshaping( src_shape: Sequence[int], dst_shape: Sequence[int], order: str = "C" ) -> None: """Performs a reshaping from `src_shape` to `dst_shape`.""" - a = testutil.make_array(src_shape) - a = np.array(a, order=order) # type: ignore[call-overload] # MyPy wants a literal as order. + a = testutil.make_array(src_shape, order=order) def testee(a: np.ndarray) -> jax.Array: return jnp.reshape(a, dst_shape) diff --git a/tests/util.py b/tests/util.py index 1fa2e07..3b38bfe 100644 --- a/tests/util.py +++ b/tests/util.py @@ -10,7 +10,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any import numpy as np @@ -23,7 +23,7 @@ def make_array( shape: Sequence[int] | int, dtype: type = np.float64, - order: Literal[None, "K", "A", "C", "F"] = "C", + order: str = "C", low: Any = None, high: Any = None, ) -> np.ndarray: @@ -69,7 +69,7 @@ def make_array( res = low + (high - low) * res assert (low is None) == (high is None) - return np.array(res, order=order, dtype=dtype) + return np.array(res, order=order, dtype=dtype) # type: ignore[call-overload] # Because we use `str` as `order`. def set_active_primitive_translators_to(