Skip to content

Commit

Permalink
Initial implementeation of shape checking method
Browse files Browse the repository at this point in the history
  • Loading branch information
ksunden committed Nov 8, 2023
1 parent 2e70f58 commit fca3448
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions data_prototype/containers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from dataclasses import dataclass
from typing import Protocol, Dict, Tuple, Optional, Any, Union, Callable, MutableMapping
from typing import (
Protocol,
Dict,
Tuple,
Optional,
Any,
Union,
Callable,
MutableMapping,
TypeAlias,
)
import uuid

from cachetools import LFUCache
Expand All @@ -16,6 +26,9 @@ def __sub__(self, other) -> "_MatplotlibTransform":
...


ShapeSpec: TypeAlias = Tuple[Union[str, int], ...]


@dataclass(frozen=True)
class Desc:
# TODO: sort out how to actually spell this. We need to know:
Expand All @@ -24,12 +37,41 @@ class Desc:
# - is this a variable size depending on the query (e.g. N)
# - what is the relative size to the other variable values (N vs N+1)
# We are probably going to have to implement a DSL for this (😞)
shape: Tuple[Union[str, int], ...]
shape: ShapeSpec
# TODO: is using a string better?
dtype: np.dtype
# TODO: do we want to include this at this level? "naive" means unit-unaware.
units: str = "naive"

@staticmethod
def check_shapes(*args: tuple[ShapeSpec, "Desc"], broadcast=False) -> bool:
specvars: dict[str, int | tuple[str, int]] = {}
for spec, desc in args:
if not broadcast:
if len(spec) != len(desc.shape):
return False
elif len(desc.shape) > len(spec):
return False
for speccomp, desccomp in zip(spec[::-1], desc.shape[::-1]):
if broadcast and desccomp == 1:
continue
if isinstance(speccomp, str):
specv, specoff = speccomp[0], int(speccomp[1:] or 0)

if isinstance(desccomp, str):
descv, descoff = speccomp[0], int(speccomp[1:] or 0)
entry = (descv, descoff - specoff)
else:
entry = desccomp - specoff

if specv in specvars and entry != specvars[specv]:
return False

specvars[specv] = entry
elif speccomp != desccomp:
return False
return True


class DataContainer(Protocol):
def query(
Expand Down

0 comments on commit fca3448

Please sign in to comment.