diff --git a/data_prototype/containers.py b/data_prototype/containers.py index 4d87446..c278879 100644 --- a/data_prototype/containers.py +++ b/data_prototype/containers.py @@ -1,5 +1,15 @@ from dataclasses import dataclass -from typing import Protocol, Dict, Tuple, Optional, Any, Union, Callable, MutableMapping +from typing import ( + Protocol, + Dict, + Tuple, + Optional, + Any, + Union, + Callable, + MutableMapping, + TypeAlias, +) import uuid from cachetools import LFUCache @@ -16,6 +26,9 @@ def __sub__(self, other) -> "_MatplotlibTransform": ... +ShapeSpec: TypeAlias = Tuple[Union[str, int], ...] + + @dataclass(frozen=True) class Desc: # TODO: sort out how to actually spell this. We need to know: @@ -24,12 +37,41 @@ class Desc: # - is this a variable size depending on the query (e.g. N) # - what is the relative size to the other variable values (N vs N+1) # We are probably going to have to implement a DSL for this (😞) - shape: Tuple[Union[str, int], ...] + shape: ShapeSpec # TODO: is using a string better? dtype: np.dtype # TODO: do we want to include this at this level? "naive" means unit-unaware. units: str = "naive" + @staticmethod + def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool: + specvars: dict[str, int | tuple[str, int]] = {} + for spec, desc in args: + if not broadcast: + if len(spec) != len(desc.shape): + return False + elif len(desc.shape) > len(spec): + return False + for speccomp, desccomp in zip(spec[::-1], desc.shape[::-1]): + if broadcast and desccomp == 1: + continue + if isinstance(speccomp, str): + specv, specoff = speccomp[0], int(speccomp[1:] or 0) + + if isinstance(desccomp, str): + descv, descoff = speccomp[0], int(speccomp[1:] or 0) + entry = (descv, descoff - specoff) + else: + entry = desccomp - specoff + + if specv in specvars and entry != specvars[specv]: + return False + + specvars[specv] = entry + elif speccomp != desccomp: + return False + return True + class DataContainer(Protocol): def query(