From 9a9b30432aaceac7b35077a5e19bfa26c8c12ef4 Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Wed, 10 Jan 2024 10:23:35 -0800 Subject: [PATCH 1/2] delete experimental data-api_packing as production version is available now --- einops/experimental/data_api_packing.py | 137 ------------------------ 1 file changed, 137 deletions(-) delete mode 100644 einops/experimental/data_api_packing.py diff --git a/einops/experimental/data_api_packing.py b/einops/experimental/data_api_packing.py deleted file mode 100644 index 5e3e04c5..00000000 --- a/einops/experimental/data_api_packing.py +++ /dev/null @@ -1,137 +0,0 @@ -from typing import List, TypeVar, Tuple, Sequence - -from einops import EinopsError - -T = TypeVar('T') - -Shape = Tuple[int, ...] - - -def pack(pattern: str, tensors: Sequence[T]) -> Tuple[T, List[Shape]]: - axes = pattern.split() - if len(axes) != len(set(axes)): - raise EinopsError(f'Duplicates in axes names in pack("{pattern}", ...)') - if '*' not in axes: - raise EinopsError(f'No *-axis in pack("{pattern}", ...)') - - # need some validation of identifiers - - n_axes_before = axes.index('*') - n_axes_after = len(axes) - n_axes_before - 1 - min_axes = n_axes_before + n_axes_after - - xp = tensors[0].__array_namespace__() - - reshaped_tensors: List[T] = [] - packed_shapes: List[Shape] = [] - for i, tensor in enumerate(tensors): - shape = tensor.shape - if len(shape) < min_axes: - raise EinopsError(f'packed tensor #{i} (enumeration starts with 0) has shape {shape}, ' - f'while pattern {pattern} assumes at least {min_axes} axes') - axis_after_packed_axes = len(shape) - n_axes_after - packed_shapes.append(shape[n_axes_before:]) - reshaped_tensors.append( - xp.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])) - ) - - return xp.concat(reshaped_tensors, axis=n_axes_before), packed_shapes - - -def prod(x: Shape) -> int: - result = 1 - for i in x: - result *= i - return result - - -def unpack(pattern: str, tensor: T, packed_shapes: List[Shape]) -> List[T]: - axes = pattern.split() - if len(axes) != len(set(axes)): - raise EinopsError(f'Duplicates in axes names in unpack("{pattern}", ...)') - if '*' not in axes: - raise EinopsError(f'No *-axis in unpack("{pattern}", ...)') - - # need some validation of identifiers - - input_shape = tensor.shape - if len(input_shape) != len(axes): - raise EinopsError(f'unpack({pattern}, ...) received input of wrong dim with shape {input_shape}') - - unpacked_axis = axes.index('*') - - lengths_of_composed_axes: List[int] = [ - -1 if -1 in p_shape else prod(p_shape) - for p_shape in packed_shapes - ] - - n_unknown_composed_axes = sum(x == -1 for x in lengths_of_composed_axes) - if n_unknown_composed_axes > 1: - raise EinopsError( - f"unpack({pattern}, ...) received more than one -1 in {packed_shapes} and can't infer dimensions" - ) - - # following manipulations allow to skip some shape verifications - # and leave them to backends - - # [[], [2, 3], [4], [-1, 5], [6]] < examples of packed_axis - # split positions when computed should be - # [0, 1, 7, 11, N-6 , N ], where N = length of axis - split_positions = [0] * len(packed_shapes) + [input_shape[unpacked_axis]] - if n_unknown_composed_axes == 0: - for i, x in enumerate(lengths_of_composed_axes[:-1]): - split_positions[i + 1] = split_positions[i] + x - else: - unknown_composed_axis: int = lengths_of_composed_axes.index(-1) - for i in range(unknown_composed_axis): - split_positions[i + 1] = split_positions[i] + lengths_of_composed_axes[i] - for j in range(unknown_composed_axis + 1, len(lengths_of_composed_axes))[::-1]: - split_positions[j] = split_positions[j + 1] + lengths_of_composed_axes[j] - - xp = tensor.__array_namespace__() - shape_start = input_shape[:unpacked_axis] - shape_end = input_shape[unpacked_axis + 1:] - slice_filler = (slice(None, None),) * unpacked_axis - return [ - xp.reshape( - # shortest way slice arbitrary axis - tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]))], - (*shape_start, *element_shape, *shape_end) - ) - for i, element_shape in enumerate(packed_shapes) - ] - - -if __name__ == '__main__': - import numpy.array_api as np - - H = 100 - W = 101 - C = 3 - - r = np.zeros((H, W)) - g = np.zeros((H, W)) - b = np.zeros((H, W)) - embeddings = np.zeros((H, W, 32)) - - im = np.stack([r, g, b], axis=-1) - print(im.shape) - - image, shapes = pack('h w *', [r, g, b]) - print(image.shape, shapes) - - print(type(image)) - print(type(im)) - assert np.all(np.equal(image, im)) - - images_and_embedding, shapes = pack('h w *', [r, g, b, embeddings]) - print(images_and_embedding.shape, shapes) - r2, g2, b2, embeddings2 = unpack('h w *', images_and_embedding, shapes) - assert np.all(np.equal(r, r2)) - assert np.all(np.equal(g, g2)) - assert np.all(np.equal(b, b2)) - assert np.all(np.equal(embeddings, embeddings2)) - - print([x.shape for x in unpack('h w *', images_and_embedding, shapes[1:])]) - - print('all is fine') From 5b41d90142917744172e308d5bf25a606413cb43 Mon Sep 17 00:00:00 2001 From: Alex Rogozhnikov Date: Wed, 10 Jan 2024 23:20:58 -0800 Subject: [PATCH 2/2] allow anonymous axes in parse_shape, fix #302 --- einops/einops.py | 16 +++++++++++++--- tests/test_other.py | 43 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/einops/einops.py b/einops/einops.py index f4ae207a..2b20f481 100644 --- a/einops/einops.py +++ b/einops/einops.py @@ -687,9 +687,19 @@ def parse_shape(x, pattern: str) -> dict: else: composition = exp.composition result = {} - for (axis_name,), axis_length in zip(composition, shape): # type: ignore - if axis_name != "_": - result[axis_name] = axis_length + for axes, axis_length in zip(composition, shape): # type: ignore + # axes either [], or [AnonymousAxis] or ['axis_name'] + if len(axes) == 0: + if axis_length != 1: + raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}") + else: + [axis] = axes + if isinstance(axis, str): + if axis != "_": + result[axis] = axis_length + else: + if axis.value != axis_length: + raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}") return result diff --git a/tests/test_other.py b/tests/test_other.py index 87b08157..313d5d2b 100644 --- a/tests/test_other.py +++ b/tests/test_other.py @@ -121,8 +121,9 @@ def test_repeating(self): with pytest.raises(einops.EinopsError): parse_shape(self.backend.from_numpy(self.x), "a a b b") - @parameterized.expand( - [ + + def test_ellipsis(self): + for shape, pattern, expected in [ ([10, 20], "...", dict()), ([10], "... a", dict(a=10)), ([10, 20], "... a", dict(a=20)), @@ -134,13 +135,37 @@ def test_repeating(self): ([10, 20, 30, 40], "a ...", dict(a=10)), ([10, 20, 30, 40], " a ... b", dict(a=10, b=40)), ([10, 40], " a ... b", dict(a=10, b=40)), - ] - ) - def test_ellipsis(self, shape: List[int], pattern: str, expected: Dict[str, int]): - x = numpy.ones(shape) - parsed1 = parse_shape(x, pattern) - parsed2 = parse_shape(self.backend.from_numpy(x), pattern) - assert parsed1 == parsed2 == expected + ]: + x = numpy.ones(shape) + parsed1 = parse_shape(x, pattern) + parsed2 = parse_shape(self.backend.from_numpy(x), pattern) + assert parsed1 == parsed2 == expected + + def test_parse_with_anonymous_axes(self): + for shape, pattern, expected in [ + ([1, 2, 3, 4], "1 2 3 a", dict(a=4)), + ([10, 1, 2], "a 1 2", dict(a=10)), + ([10, 1, 2], "a () 2", dict(a=10)), + ]: + x = numpy.ones(shape) + parsed1 = parse_shape(x, pattern) + parsed2 = parse_shape(self.backend.from_numpy(x), pattern) + assert parsed1 == parsed2 == expected + + + def test_failures(self): + # every test should fail + for shape, pattern in [ + ([1, 2, 3, 4], "a b c"), + ([1, 2, 3, 4], "2 a b c"), + ([1, 2, 3, 4], "a b c ()"), + ([1, 2, 3, 4], "a b c d e"), + ([1, 2, 3, 4], "a b c d e ..."), + ([1, 2, 3, 4], "a b c ()"), + ]: + with pytest.raises(RuntimeError): + x = numpy.ones(shape) + parse_shape(self.backend.from_numpy(x), pattern) _SYMBOLIC_BACKENDS = [