Skip to content

Commit

Permalink
Merge pull request #303 from arogozhnikov/dev
Browse files Browse the repository at this point in the history
Allow anonymous axes in parse_shape, fix #302
  • Loading branch information
arogozhnikov authored Jan 11, 2024
2 parents d495e7c + 5b41d90 commit 5655ce1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 149 deletions.
16 changes: 13 additions & 3 deletions einops/einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,9 +687,19 @@ def parse_shape(x, pattern: str) -> dict:
else:
composition = exp.composition
result = {}
for (axis_name,), axis_length in zip(composition, shape): # type: ignore
if axis_name != "_":
result[axis_name] = axis_length
for axes, axis_length in zip(composition, shape): # type: ignore
# axes either [], or [AnonymousAxis] or ['axis_name']
if len(axes) == 0:
if axis_length != 1:
raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}")
else:
[axis] = axes
if isinstance(axis, str):
if axis != "_":
result[axis] = axis_length
else:
if axis.value != axis_length:
raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}")
return result


Expand Down
137 changes: 0 additions & 137 deletions einops/experimental/data_api_packing.py

This file was deleted.

43 changes: 34 additions & 9 deletions tests/test_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ def test_repeating(self):
with pytest.raises(einops.EinopsError):
parse_shape(self.backend.from_numpy(self.x), "a a b b")

@parameterized.expand(
[

def test_ellipsis(self):
for shape, pattern, expected in [
([10, 20], "...", dict()),
([10], "... a", dict(a=10)),
([10, 20], "... a", dict(a=20)),
Expand All @@ -134,13 +135,37 @@ def test_repeating(self):
([10, 20, 30, 40], "a ...", dict(a=10)),
([10, 20, 30, 40], " a ... b", dict(a=10, b=40)),
([10, 40], " a ... b", dict(a=10, b=40)),
]
)
def test_ellipsis(self, shape: List[int], pattern: str, expected: Dict[str, int]):
x = numpy.ones(shape)
parsed1 = parse_shape(x, pattern)
parsed2 = parse_shape(self.backend.from_numpy(x), pattern)
assert parsed1 == parsed2 == expected
]:
x = numpy.ones(shape)
parsed1 = parse_shape(x, pattern)
parsed2 = parse_shape(self.backend.from_numpy(x), pattern)
assert parsed1 == parsed2 == expected

def test_parse_with_anonymous_axes(self):
for shape, pattern, expected in [
([1, 2, 3, 4], "1 2 3 a", dict(a=4)),
([10, 1, 2], "a 1 2", dict(a=10)),
([10, 1, 2], "a () 2", dict(a=10)),
]:
x = numpy.ones(shape)
parsed1 = parse_shape(x, pattern)
parsed2 = parse_shape(self.backend.from_numpy(x), pattern)
assert parsed1 == parsed2 == expected


def test_failures(self):
# every test should fail
for shape, pattern in [
([1, 2, 3, 4], "a b c"),
([1, 2, 3, 4], "2 a b c"),
([1, 2, 3, 4], "a b c ()"),
([1, 2, 3, 4], "a b c d e"),
([1, 2, 3, 4], "a b c d e ..."),
([1, 2, 3, 4], "a b c ()"),
]:
with pytest.raises(RuntimeError):
x = numpy.ones(shape)
parse_shape(self.backend.from_numpy(x), pattern)


_SYMBOLIC_BACKENDS = [
Expand Down

0 comments on commit 5655ce1

Please sign in to comment.