Skip to content

Commit

Permalink
adding tests and fixing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
mahaalbashir committed Sep 28, 2023
1 parent b4514f8 commit 6daf851
Show file tree
Hide file tree
Showing 3 changed files with 1,530 additions and 139 deletions.
56 changes: 21 additions & 35 deletions acro/acro_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,12 +541,7 @@ def create_crosstab_masks( # pylint: disable=too-many-arguments,too-many-locals
masks["negative"] = negative
# p-percent check
masks["p-ratio"] = pd.crosstab( # type: ignore
index,
columns,
values,
aggfunc=pperc_funcs,
margins=margins,
dropna=dropna,
index, columns, values, aggfunc=pperc_funcs, margins=margins, dropna=dropna
)
# nk values check
masks["nk-rule"] = pd.crosstab( # type: ignore
Expand Down Expand Up @@ -837,7 +832,7 @@ def apply_suppression(
outcome_df += tmp_df
except TypeError:
logger.warning("problem mask %s is not binary", name)
except ValueError as error:
except ValueError as error: # pragma: no cover
error_message = (
f"An error occurred with the following details"
f":\n Name: {name}\n Mask: {mask}\n Table: {table}"
Expand Down Expand Up @@ -949,53 +944,42 @@ def get_queries(masks, aggfunc) -> list[str]:
# identify level names for rows and columns
index_level_names = mask.index.names
column_level_names = mask.columns.names

# iterate through the masks to identify the true cells and extract queries
for column_level_values in mask.columns:
for index_level_values in mask.index:
if mask.loc[index_level_values, column_level_values]:
if isinstance(index_level_values, tuple):
for col_index, col_label in enumerate(mask.columns):
for row_index, row_label in enumerate(mask.index):
if mask.iloc[row_index, col_index]:
if isinstance(row_label, tuple):
index_query = " & ".join(
[
f"({level} == {val})"
if isinstance(val, (int, float))
else f'({level} == "{val}")'
for level, val in zip(
index_level_names, index_level_values
)
for level, val in zip(index_level_names, row_label)
]
)
else:
index_query = " & ".join(
[
f"({index_level_names} == {index_level_values})"
if isinstance(index_level_values, (int, float))
else (
f"({index_level_names}"
f'== "{index_level_values}")'
)
f"({index_level_names} == {row_label})"
if isinstance(row_label, (int, float))
else (f"({index_level_names}" f'== "{row_label}")')
]
)
if isinstance(column_level_values, tuple):
if isinstance(col_label, tuple):
column_query = " & ".join(
[
f"({level} == {val})"
if isinstance(val, (int, float))
else f'({level} == "{val}")'
for level, val in zip(
column_level_names, column_level_values
)
for level, val in zip(column_level_names, col_label)
]
)
else:
column_query = " & ".join(
[
f"({column_level_names} == {column_level_values})"
if isinstance(column_level_values, (int, float))
else (
f"({column_level_names}"
f'== "{column_level_values}")'
)
f"({column_level_names} == {col_label})"
if isinstance(col_label, (int, float))
else (f"({column_level_names}" f'== "{col_label}")')
]
)
query = f"{index_query} & {column_query}"
Expand Down Expand Up @@ -1051,17 +1035,19 @@ def get_index_columns(index, columns, data) -> tuple[list | Series, list | Serie
List | Series
The columns extracted from the data.
"""
shift = 1
if isinstance(index, list):
index_new = []
for _, val in enumerate(index):
index_new.append(data[val.name])
for i in range(len(index)):
index_new.append(data.iloc[:, i])
shift = len(index)
else:
index_new = data[index.name]

if isinstance(columns, list):
columns_new = []
for _, val in enumerate(columns):
columns_new.append(data[val.name])
for i in range(shift, shift + len(columns)):
columns_new.append(data.iloc[:, i])
else:
columns_new = data[columns.name]
return index_new, columns_new
Expand Down
Loading

0 comments on commit 6daf851

Please sign in to comment.