Skip to content

Commit

Permalink
Merge branch 'main' into patch/py38-120
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Mar 31, 2023
2 parents 92408f1 + 9cb5c90 commit ae29d14
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 53 deletions.
63 changes: 55 additions & 8 deletions graphicle/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from numpy.lib import recfunctions as rfn
from rich.console import Console
from rich.tree import Tree
from scipy.sparse import coo_array
from scipy.sparse import coo_array, csr_array
from typicle import Types

from . import base, calculate
Expand Down Expand Up @@ -352,7 +352,7 @@ def converter(values: npt.ArrayLike) -> base.AnyVector:
array = np.asarray(values, dtype=dtype)
shape = array.shape
len_shape = len(shape)
is_flat = (len_shape == 1) or (len_shape == 0)
is_flat = (len_shape == 1) or (len_shape == 0) or (shape[1] == 1)
if is_flat and (num_cols == 1):
return array.reshape(-1)
cols = shape[0] if is_flat else shape[1]
Expand Down Expand Up @@ -640,7 +640,7 @@ def from_numpy_structured(cls, arr: base.VoidVector) -> "MaskGroup":
)

def __repr__(self) -> str:
keys = ", ".join(map(lambda name: '"' + name + '"', self.names))
keys = ", ".join(map(lambda name: '"' + name + '"', self.keys()))
return f"MaskGroup(masks=[{keys}], agg_op={self.agg_op.name})"

def __rich__(self) -> Tree:
Expand Down Expand Up @@ -712,7 +712,7 @@ def __bool__(self) -> bool:
return True

def __len__(self) -> int:
return len(self.names)
return len(self._mask_arrays)

def __delitem__(self, key) -> None:
"""Remove a MaskArray from the group, using given key."""
Expand Down Expand Up @@ -818,6 +818,39 @@ def to_dict(self) -> ty.Dict[str, base.BoolVector]:
"""Masks nested in a dictionary instead of a ``MaskGroup``."""
return {key: val.data for key, val in self._mask_arrays.items()}

def recursive_drop(
self, key: str = "latent", inplace: bool = False
) -> "MaskGroup":
"""Removes masks indexed by ``key`` at all levels of nesting.
Operation is performed inplace.
.. versionadded:: 0.2.8
Parameters
----------
key : str
String key to be dropped from the ``MaskGroup``. Default is
``"latent"``.
inplace : bool
If ``True``, will mutate the ``MaskGroup`` instance inplace.
If ``False``, will first copy the instance, leaving the
original object unchanged. Default is ``False``.
Returns
-------
MaskGroup
``MaskGroup`` object with ``key`` recursively removed.
"""
mask_group = self
if inplace is False:
mask_group = mask_group.copy()
if key in mask_group:
del mask_group[key]
for value in mask_group.values():
if isinstance(value, type(mask_group)):
value.recursive_drop(key, inplace=True)
return mask_group

def flatten(
self, how: ty.Literal["rise", "agg"] = "rise"
) -> "MaskGroup[MaskArray]":
Expand Down Expand Up @@ -1801,7 +1834,7 @@ class AdjacencyList(base.AdjacencyBase):
Scalar value embedded on each edge.
"""

_data: base.AnyVector = _array_field("<i4", 2)
_data: base.IntVector = _array_field("<i4", 2)
weights: base.DoubleVector = _array_field("<f8")
dtype: np.dtype = field(init=False, repr=False)
_HANDLED_TYPES: ty.Tuple[ty.Type, ...] = field(init=False, repr=False)
Expand Down Expand Up @@ -1835,21 +1868,31 @@ def __len__(self) -> int:
def __bool__(self) -> bool:
return _truthy(self)

def __array__(self) -> base.VoidVector:
def __array__(self) -> base.IntVector:
return self._data

def __eq__(
self, other: ty.Union[base.ArrayBase, base.AnyVector]
) -> MaskArray:
return _array_eq(self, other)

def __ne__(
self, other: ty.Union[base.ArrayBase, base.AnyVector]
) -> MaskArray:
return _array_ne(self, other)

def __getitem__(self, key) -> "AdjacencyList":
if isinstance(key, base.MaskBase):
key = key.data
return self.__class__(self._data[key])
return self.__class__(self._data[key].reshape(-1, 2))

def copy(self) -> "AdjacencyList":
return self.__class__(self._data.copy())

@fn.cached_property
def _edge_relabel(self) -> base.IntVector:
_, inv = np.unique(self._data, return_inverse=True)
return inv.reshape(-1, 2)
return inv.reshape(-1, 2).astype("<i4")

@fn.cached_property
def _sparse_signed(self) -> coo_array:
Expand All @@ -1861,6 +1904,10 @@ def _sparse_unsigned(self) -> coo_array:
sparse_arr.data[...] = True
return sparse_arr

@fn.cached_property
def _sparse_csr(self) -> csr_array:
return csr_array(self._sparse_signed)

@property
def _sparse_weighted(self) -> coo_array:
sparse_arr = self._sparse_signed.copy()
Expand Down
Loading

0 comments on commit ae29d14

Please sign in to comment.