Skip to content

Commit

Permalink
Closes Bears-R-Us#3206 MultiIndex.levels (Bears-R-Us#3207)
Browse files Browse the repository at this point in the history
Co-authored-by: Amanda Potts <ajpotts@users.noreply.github.com>
  • Loading branch information
ajpotts and ajpotts authored May 20, 2024
1 parent bc2697f commit b435ed4
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 244 deletions.
12 changes: 10 additions & 2 deletions PROTO_tests/tests/index_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def test_multiindex_creation(self, size):
# test list generation
idx = ak.MultiIndex([ak.arange(size), ak.arange(size)])
assert isinstance(idx, ak.MultiIndex)
assert idx.levels == 2
assert idx.nlevels == 2
assert idx.size == size

# test tuple generation
idx = ak.MultiIndex((ak.arange(size), ak.arange(size)))
assert isinstance(idx, ak.MultiIndex)
assert idx.levels == 2
assert idx.nlevels == 2
assert idx.size == size

with pytest.raises(TypeError):
Expand All @@ -50,6 +50,14 @@ def test_multiindex_creation(self, size):
with pytest.raises(ValueError):
idx = ak.MultiIndex([ak.arange(size), ak.arange(size - 1)])

def test_nlevels(self):
i = ak.Index([1, 2, 3], name="test")
assert i.nlevels == 1

size = 10
m = ak.MultiIndex([ak.arange(size), ak.arange(size) * -1])
assert m.nlevels == 2

@pytest.mark.parametrize("size", pytest.prob_size)
def test_memory_usage(self, size):
from arkouda.dtypes import BigInt
Expand Down
2 changes: 1 addition & 1 deletion PROTO_tests/tests/symbol_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def test_multi_index_registration(self, size):
# assert that the object is registered
assert reg_name in reg["Objects"]
# assert that the sym entry name is recorded
for x in i.values:
for x in i.levels:
if x.objType == ak.Categorical.objType:
assert x.codes.name in reg["Components"]
assert x.categories.name in reg["Components"]
Expand Down
72 changes: 46 additions & 26 deletions arkouda/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,17 @@ def _dtype_of_list_values(self, lst):
else:
raise TypeError("Index Types must match")

@property
def nlevels(self):
"""
Integer number of levels in this Index.
An Index will always have 1 level.
See Also
--------
MultiIndex.nlevels
"""
return 1

@property
def index(self):
"""
Expand Down Expand Up @@ -901,36 +912,35 @@ class MultiIndex(Index):

def __init__(
self,
values: Union[list, pdarray, Strings, Categorical],
levels: Union[list, pdarray, Strings, Categorical],
name: Optional[str] = None,
names: Optional[list[str]] = None,
):
self.registered_name: Optional[str] = None
if not (isinstance(values, list) or isinstance(values, tuple)):
if not (isinstance(levels, list) or isinstance(levels, tuple)):
raise TypeError("MultiIndex should be an iterable")
self.values = values
self.levels = levels
first = True
self.names = names
self.name = name
for col in self.values:
for col in self.levels:
# col can be a python int which doesn't have a size attribute
col_size = col.size if not isinstance(col, int) else 0
if first:
# we are implicitly assuming values contains arkouda types and not python lists
# we are implicitly assuming levels contains arkouda types and not python lists
# because we are using obj.size/obj.dtype instead of len(obj)/type(obj)
# this should be made explict using typechecking
self.size = col_size
first = False
else:
if col_size != self.size:
raise ValueError("All columns in MultiIndex must have same length")
self.levels = len(self.values)

def __getitem__(self, key):
from arkouda.series import Series

if isinstance(key, Series):
key = key.values
key = key.levels
return MultiIndex([i[key] for i in self.index])

def __repr__(self):
Expand All @@ -952,11 +962,21 @@ def __eq__(self, v):

@property
def index(self):
return self.values
return self.levels

@property
def nlevels(self) -> int:
"""
Integer number of levels in this MultiIndex.
See Also
--------
Index.nlevels
"""
return len(self.levels)

def memory_usage(self, unit="B"):
"""
Return the memory usage of the MultiIndex values.
Return the memory usage of the MultiIndex levels.
Parameters
----------
Expand Down Expand Up @@ -988,7 +1008,7 @@ def memory_usage(self, unit="B"):
from arkouda.util import convert_bytes

nbytes = 0
for item in self.values:
for item in self.levels:
nbytes += item.nbytes

return convert_bytes(nbytes, unit=unit)
Expand All @@ -1008,7 +1028,7 @@ def set_dtype(self, dtype):
return self

def to_ndarray(self):
return ndarray([convert_if_categorical(val).to_ndarray() for val in self.values])
return ndarray([convert_if_categorical(val).to_ndarray() for val in self.levels])

def to_list(self):
return self.to_ndarray().tolist()
Expand Down Expand Up @@ -1057,7 +1077,7 @@ def register(self, user_defined_name):
args={
"name": user_defined_name,
"objType": self.objType,
"num_idxs": len(self.values),
"num_idxs": len(self.levels),
"idx_names": [
json.dumps(
{
Expand All @@ -1070,9 +1090,9 @@ def register(self, user_defined_name):
)
if isinstance(v, Categorical)
else v.name
for v in self.values
for v in self.levels
],
"idx_types": [v.objType for v in self.values],
"idx_types": [v.objType for v in self.levels],
},
)
self.registered_name = user_defined_name
Expand Down Expand Up @@ -1131,7 +1151,7 @@ def lookup(self, key):
raise TypeError("MultiIndex lookup failure")
# if individual vals convert to pdarrays
if not isinstance(key[0], pdarray):
dt = self.values[0].dtype if isinstance(self.values[0], pdarray) else akint64
dt = self.levels[0].dtype if isinstance(self.levels[0], pdarray) else akint64
key = [akcast(array([x]), dt) for x in key]

return in1d(self.index, key)
Expand Down Expand Up @@ -1200,7 +1220,7 @@ def to_hdf(
**({"segments": obj.segments.name} if obj.segments is not None else {}),
}
)
for obj in self.values
for obj in self.levels
]
return typecast(
str,
Expand All @@ -1212,10 +1232,10 @@ def to_hdf(
"file_format": _file_type_to_int(file_type),
"write_mode": _mode_str_to_int(mode),
"objType": self.objType,
"num_idx": len(self.values),
"num_idx": len(self.levels),
"idx": index_data,
"idx_objTypes": [obj.objType for obj in self.values],
"idx_dtypes": [str(obj.dtype) for obj in self.values],
"idx_objTypes": [obj.objType for obj in self.levels],
"idx_dtypes": [str(obj.dtype) for obj in self.levels],
},
),
)
Expand Down Expand Up @@ -1252,7 +1272,7 @@ def update_hdf(
RuntimeError
Raised if a server-side error is thrown saving the index
TypeError
Raised if the Index values are a list.
Raised if the Index levels are a list.
Notes
------
Expand All @@ -1271,8 +1291,8 @@ def update_hdf(
_repack_hdf,
)

if isinstance(self.values, list):
raise TypeError("Unable update hdf when Index values are a list.")
if isinstance(self.levels, list):
raise TypeError("Unable update hdf when Index levels are a list.")

# determine the format (single/distribute) that the file was saved in
file_type = _get_hdf_filetype(prefix_path + "*")
Expand All @@ -1289,7 +1309,7 @@ def update_hdf(
**({"segments": obj.segments.name} if obj.segments is not None else {}),
}
)
for obj in self.values
for obj in self.levels
]

generic_msg(
Expand All @@ -1300,10 +1320,10 @@ def update_hdf(
"file_format": _file_type_to_int(file_type),
"write_mode": _mode_str_to_int("append"),
"objType": self.objType,
"num_idx": len(self.values),
"num_idx": len(self.levels),
"idx": index_data,
"idx_objTypes": [obj.objType for obj in self.values],
"idx_dtypes": [str(obj.dtype) for obj in self.values],
"idx_objTypes": [obj.objType for obj in self.levels],
"idx_dtypes": [str(obj.dtype) for obj in self.levels],
"overwrite": True,
},
),
Expand Down
2 changes: 1 addition & 1 deletion arkouda/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def topn(self, n: int = 10) -> Series:

def _reindex(self, idx):
if isinstance(self.index, MultiIndex):
new_index = MultiIndex(self.index[idx].values, name=self.index.name, names=self.index.names)
new_index = MultiIndex(self.index[idx].levels, name=self.index.name, names=self.index.names)
elif isinstance(self.index, Index):
new_index = Index(self.index[idx], name=self.index.name)
else:
Expand Down
Loading

0 comments on commit b435ed4

Please sign in to comment.