Skip to content

Commit

Permalink
Parser cache (#13)
Browse files Browse the repository at this point in the history
* Cache parser

* Update context docs

* Update API docs

* Mention data objects in specifying shapes docs

* Add docs on performance

* Bump version to 0.3.0
  • Loading branch information
EPronovost authored Dec 5, 2023
1 parent 490259b commit 7c983b3
Show file tree
Hide file tree
Showing 11 changed files with 333 additions and 26 deletions.
55 changes: 53 additions & 2 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ API
import numpy as np
import numpy.typing as npt
from numpy.random import randn
from eincheck import check_func, check_data, check_shapes
from eincheck import (
check_func, check_data, check_shapes,
disable_checks, enable_checks,
parser_cache_clear, parser_cache_info, parser_resize_cache,
)
from typing import NamedTuple
import attrs

Expand Down Expand Up @@ -316,6 +320,53 @@ You can use ``check_shapes`` to re-check a data object.
arg0.p: got (4,) expected [i 2]


.. autofunction:: eincheck.enable_checks
.. autofunction:: eincheck.disable_checks

Context manager to disable eincheck.
This can be used to make code run faster once you're confident the shapes are correct.
`check_shapes` will return an empty dictionary.

.. doctest::

>>> with disable_checks():
... # Eincheck is a no-op inside this context.
... print(check_shapes((randn(2, 3), "i")))
...
{}

.. autofunction:: eincheck.enable_checks

Context manager to enable eincheck (e.g. if inside a `disable_checks` context).

.. doctest::

>>> with disable_checks():
... with enable_checks():
... check_shapes((randn(2, 3), "i"))
...
Traceback (most recent call last):
...
ValueError: arg0: expected rank 1, got shape (2, 3)
arg0: got (2, 3) expected [i]

.. autofunction:: eincheck.parser_cache_clear

Clear the ``lru_cache`` for parsing shape strings.

.. autofunction:: eincheck.parser_cache_info

Get the ``lru_cache`` cache info for the parser cache.

.. doctest::

>>> parser_cache_clear()
>>> check_shapes((randn(2, 3), "a b"), (randn(3, 4), "b c"), (randn(2, 3), "a b"))
{'a': 2, 'b': 3, 'c': 4}
>>> parser_cache_info()
CacheInfo(hits=1, misses=2, maxsize=128, currsize=2)

.. autofunction:: eincheck.parser_resize_cache

Reset the parser cache to a ``lru_cache`` with the given size.
This will clear the cache and change the ``maxsize`` field in ``CacheInfo``.

2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Welcome to eincheck's documentation!

specifying_shapes
api
performance

Getting Started
---------------
Expand Down Expand Up @@ -85,6 +86,7 @@ Resources

* :ref:`Specifying Shapes` contains information on how to format shape specifications (e.g. ``"... i j"``)
* :ref:`API` contains information on the ``check_*`` functions
* :ref:`Performance` contains information on making code with shape checks run faster


Indices and tables
Expand Down
107 changes: 107 additions & 0 deletions docs/source/performance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
Performance
===========


.. testsetup::

import numpy as np
import numpy.typing as npt
from numpy.random import randn
from eincheck import (
check_func, check_data, check_shapes,
disable_checks, enable_checks,
parser_cache_clear, parser_cache_info, parser_resize_cache,
)
from typing import List

Adding eincheck to a project introduces extra computations, which will have some latency impact.
This impact can be significantly mitigated by following best practices.

There are two major parts to performing shape checks with eincheck: parsing the user input into internal data structures and checking the tensor shapes against those data structures.
In most cases, parsing will take significantly more time than actually doing the checks.
Writing performant code with eincheck thus requires us to minimize the amount of parsing necessary.

Doing Less Parsing
------------------

There are several ways to reduce the amount of parsing.

Use Decorators
^^^^^^^^^^^^^^

The decorators ``check_data`` and ``check_func`` will parse the inputs once and then reuse them each time the data object/function is called.
These decorators should be used whenever possible.

Cached Parsing
^^^^^^^^^^^^^^
There are cases where the abovementioned decorators cannot be used.
``check_shapes`` uses an ``lru_cache`` to cache parsing, initialized with a default size of 128.
To achieve good cache utilization, prefer to use constant shape specs.
For example:

.. doctest::

>>> parser_cache_clear()
>>> parser_cache_info()
CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)
>>> def bad(x: npt.NDArray[np.float64], inds: List[int]) -> npt.NDArray[np.float64]:
... y = x[..., inds]
... # Bad! The shape spec for y will change for different length inds.
... check_shapes((x, "*i _"), (y, f"*i {len(inds)}"))
... return y
>>> _ = bad(randn(5, 10), [1])
>>> _ = bad(randn(5, 10), [4, 2])
>>> _ = bad(randn(5, 10), [0, 1, 2])
>>> _ = bad(randn(5, 10), [7, 7, 7, 7, 7])
>>> _ = bad(randn(5, 10), [3, 1, 4, 1, 5])
>>> parser_cache_info()
CacheInfo(hits=5, misses=5, maxsize=128, currsize=5)
>>>
>>> parser_cache_clear()
>>> def good(x: npt.NDArray[np.float64], inds: List[int]) -> npt.NDArray[np.float64]:
... y = x[..., inds]
... # Good! The shape specs are constant.
... check_shapes((x, "*i _"), (y, "*i n"), n=len(inds))
... return y
>>> _ = good(randn(5, 10), [1])
>>> _ = good(randn(5, 10), [4, 2])
>>> _ = good(randn(5, 10), [0, 1, 2])
>>> _ = good(randn(5, 10), [7, 7, 7, 7, 7])
>>> _ = good(randn(5, 10), [3, 1, 4, 1, 5])
>>> parser_cache_info()
CacheInfo(hits=8, misses=2, maxsize=128, currsize=2)

The functions ``parser_cache_info``, ``parser_cache_clear``, and ``parser_resize_cache`` can be used to monitor and adjust the caching behavior.

Skip Parsing
^^^^^^^^^^^^

If caching is not a viable option (e.g. due to memory constraints), another option to improve performance is to pass lists instead of strings as the shape specs to ``check_shapes``.
The ``ShapeArg`` type used as input to ``check_shapes`` includes ``Sequence[Union[DimSpec, str, int, None]]``.
If a sequence is provided, eincheck will skip the parser and build the internal data structures directly from this list.
As such, advanced parsing features are not supported with this method.
The strings in the sequence must be single variable names, with no parentheses or binary operators.

.. doctest::

>>> def foo(x: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
... check_shapes((x, [None, "i", "i"]))
... return np.diagonal(x, axis1=1, axis2=2)
>>> foo(randn(3, 4, 4)).shape
(3, 4)
>>> def bad(x: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
... # Bad! Can't use binary operator + with list ShapeArg.
... check_shapes((x, ["i", "i+1"]))
... return x
>>> bad(randn(3, 4)).shape
Traceback (most recent call last):
...
ValueError: Variable name should be made of only ascii letters, got i+1

Disabling Checks
----------------

The most powerful tool to make code using eincheck run faster is to disable eincheck altogether.
For example, eincheck can be used while initially developing code and then disabled in optimized production environments.
The ``disable_checks`` and ``enable_checks`` context managers can be used to disable and re-enable eincheck within certain scopes.

7 changes: 7 additions & 0 deletions docs/source/specifying_shapes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,13 @@ Repeated underscores (``_*``) is equivalent to an ellipse.
>>> check_shapes((x, "_* 5"))
{}


Data Objects
------------

A dollar sign (``$``) can be used with data objects decorated with ``check_data``.
See the API section on this decorator for more info.

Limitations
-----------

Expand Down
10 changes: 9 additions & 1 deletion eincheck/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@
from .checks.func import check_func
from .checks.shapes import check_shapes
from .contexts import disable_checks, enable_checks
from .parser.parse_cache import (
parser_cache_clear,
parser_cache_info,
parser_resize_cache,
)

__all__ = [
"check_data",
"check_func",
"check_shapes",
"enable_checks",
"disable_checks",
"enable_checks",
"parser_cache_clear",
"parser_cache_info",
"parser_resize_cache",
]
32 changes: 32 additions & 0 deletions eincheck/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from functools import _CacheInfo, lru_cache
from typing import Callable, Generic, Optional, TypeVar

from typing_extensions import ParamSpec

_T = TypeVar("_T")
_P = ParamSpec("_P")


class ResizeableLruCache(Generic[_P, _T]):
def __init__(self, func: Callable[_P, _T], maxsize: Optional[int] = 128):
self._func = func
self._cached_func = lru_cache(maxsize=maxsize)(func)

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
return self._cached_func(*args, **kwargs) # type: ignore[arg-type]

def reset_maxsize(self, maxsize: Optional[int]) -> None:
self._cached_func = lru_cache(maxsize=maxsize)(self._func)

def cache_info(self) -> _CacheInfo:
return self._cached_func.cache_info()

def cache_clear(self) -> None:
self._cached_func.cache_clear()


def resizeable_lru_cache(
maxsize: Optional[int] = 128,
) -> Callable[[Callable[_P, _T]], ResizeableLruCache[_P, _T]]:
"""Resizable version of functools.lru_cache."""
return lambda f: ResizeableLruCache(f, maxsize=maxsize)
24 changes: 2 additions & 22 deletions eincheck/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,10 @@ def _set_enable_checks(value: bool) -> Generator[None, None, None]:


def enable_checks() -> ContextManager[None]:
"""Enable eincheck to do shape checks.
Example:
```
with disable_checks():
# check_shapes is a no-op inside this context
with enable_checks():
# now check_shapes does something again
```
"""
"""Enable eincheck to do shape checks."""
return _set_enable_checks(True)


def disable_checks() -> ContextManager[None]:
"""Disable eincheck to do shape checks.
Example:
```
with disable_checks():
# check_shapes is a no-op inside this context
with enable_checks():
# now check_shapes does something again
```
"""
"""Disable eincheck from doing shape checks."""
return _set_enable_checks(False)
2 changes: 2 additions & 0 deletions eincheck/parser/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from lark.lark import Lark
from lark.visitors import Transformer

from eincheck.cache import resizeable_lru_cache
from eincheck.parser.dim_spec import DimSpec, DimType
from eincheck.parser.expressions import (
AddOp,
Expand Down Expand Up @@ -111,6 +112,7 @@ def _get_expr(x: Any) -> Expr:
_transformer = TreeToSpec()


@resizeable_lru_cache()
def parse_shape_spec(s: str) -> ShapeSpec:
"""Parse a string into a ShapeSpec."""
out = _transformer.transform(_parser.parse(s.strip(" ")))
Expand Down
5 changes: 5 additions & 0 deletions eincheck/parser/parse_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from eincheck.parser.grammar import parse_shape_spec

parser_cache_info = parse_shape_spec.cache_info
parser_cache_clear = parse_shape_spec.cache_clear
parser_resize_cache = parse_shape_spec.reset_maxsize
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "eincheck"
version = "0.2.1"
version = "0.3.0"
description = "Tensor shape checks inspired by einstein notation"
authors = ["Ethan Pronovost <epronovo1@gmail.com>"]
readme = "README.md"
Expand Down
Loading

0 comments on commit 7c983b3

Please sign in to comment.