Skip to content

Commit

Permalink
updating the suppression when total
Browse files Browse the repository at this point in the history
  • Loading branch information
Maha Albashir authored and Maha Albashir committed Sep 22, 2023
1 parent cb1d287 commit 6ff77e5
Show file tree
Hide file tree
Showing 3 changed files with 568 additions and 810 deletions.
250 changes: 127 additions & 123 deletions acro/acro_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,17 +182,12 @@ def crosstab( # pylint: disable=too-many-arguments,too-many-locals
columns,
values,
aggfunc=pperc_funcs,
margins=margins,
dropna=dropna
margins=margins,
dropna=dropna,
)
# nk values check
masks["nk-rule"] = pd.crosstab( # type: ignore
index,
columns,
values,
aggfunc=nk_funcs,
margins=margins,
dropna=dropna
index, columns, values, aggfunc=nk_funcs, margins=margins, dropna=dropna
)
# check for missing values -- currently unsupported
if CHECK_MISSING_VALUES:
Expand Down Expand Up @@ -232,129 +227,127 @@ def crosstab( # pylint: disable=too-many-arguments,too-many-locals
if self.suppress:
table = safe_table
if margins:
if aggfunc is None:
table = table.drop(margins_name, axis=1)
rows_total = table.sum(axis=1)
table.loc[:, margins_name] = rows_total
table = table.drop(margins_name, axis=0)
cols_total = table.sum(axis=0)
table.loc[margins_name] = cols_total
if aggfunc == "mean":
count_table = pd.crosstab( # type: ignore
index,
columns,
# initialize a list to store queries for true cells
true_cell_queries = []
for _, mask in masks.items():
# drop the name of the mask
mask = mask.droplevel(0, axis=1)
# identify level names for rows and columns
index_level_names = mask.index.names
column_level_names = mask.columns.names

# iterate through the masks to identify the true cells and extract queries
for column_level_values in mask.columns:
for index_level_values in mask.index:
if (
mask.loc[index_level_values, column_level_values]
# == True
):
if isinstance(index_level_values, tuple):
index_query = " & ".join(
[
f"({level} == {val})"
if isinstance(val, (int, float))
else f'({level} == "{val}")'
for level, val in zip(
index_level_names, index_level_values
)
]
)
else:
index_query = " & ".join(
[
f"({index_level_names} == {index_level_values})"
if isinstance(
index_level_values, (int, float)
)
else (
f"({index_level_names}"
f'== "{index_level_values}")'
)
]
)
if isinstance(column_level_values, tuple):
column_query = " & ".join(
[
f"({level} == {val})"
if isinstance(val, (int, float))
else f'({level} == "{val}")'
for level, val in zip(
column_level_names, column_level_values
)
]
)
else:
column_query = " & ".join(
[
f"({column_level_names} == {column_level_values})"
if isinstance(
column_level_values, (int, float)
)
else (
f"({column_level_names}"
f'== "{column_level_values}")'
)
]
)
query = f"{index_query} & {column_query}"
true_cell_queries.append(query)

# delete the duplication
true_cell_queries = list(set(true_cell_queries))

# create dataframe from the index and columns parameters
if isinstance(index, list):
index_df = pd.concat(index, axis=1)
elif isinstance(index, pd.Series):
index_df = pd.DataFrame({index.name: index})
if isinstance(columns, list):
columns_df = pd.concat(columns, axis=1)
elif isinstance(columns, pd.Series):
columns_df = pd.DataFrame({columns.name: columns})
data = pd.concat([index_df, columns_df], axis=1)

# apply the queries to the data
for query in true_cell_queries:
query = str(query).replace("['", "").replace("']", "")
data = data.query(f"not ({query})")

# get the index and columns from the data after the queries are applied
try:
if isinstance(index, list):
index_new = []
for _, val in enumerate(index):
index_new.append(data[val.name])
else:
index_new = data[index.name]

if isinstance(columns, list):
columns_new = []
for _, val in enumerate(columns):
columns_new.append(data[val.name])
else:
columns_new = data[columns.name]

# apply the crosstab with the new index and columns
table = pd.crosstab( # type: ignore
index_new,
columns_new,
values=values,
rownames=rownames,
colnames=colnames,
aggfunc="count",
aggfunc=aggfunc,
margins=margins,
margins_name=margins_name,
dropna=dropna,
normalize=normalize,
)
count_table = count_table.where(table.notna(), other=np.nan)
columns_to_keep = table.columns
count_table = count_table[columns_to_keep]
if not isinstance(
count_table.columns, pd.MultiIndex
) and not isinstance(count_table.index, pd.MultiIndex):
count_table = count_table.drop(margins_name, axis=1)
count_table.loc[:, margins_name] = count_table.sum(axis=1)
count_table = count_table.drop(margins_name, axis=0)
count_table.loc[(margins_name)] = count_table.sum(axis=0)
table[margins_name] = 1
table.loc[margins_name, :] = 1
multip_table = count_table * table
table[margins_name] = (
multip_table.drop(margins_name, axis=1).sum(axis=1)
/ multip_table[margins_name]
)
table.loc[margins_name, :] = (
multip_table.drop(margins_name, axis=0).sum()
/ multip_table.loc[margins_name, :]
)
table.loc[margins_name, margins_name] = (
multip_table.drop(index=margins_name, columns=margins_name)
.sum()
.sum()
) / multip_table.loc[margins_name, margins_name]

if isinstance(count_table.columns, pd.MultiIndex) and isinstance(
count_table.index, pd.MultiIndex
): # multidimensional columns and rows
count_table = count_table.drop(margins_name, axis=1, level=0)
count_table.loc[:, margins_name] = count_table.sum(axis=1)
count_table = count_table.drop(margins_name, axis=0)
count_table.loc[(margins_name, ""), :] = count_table.sum(axis=0)
table[margins_name] = 1
table.loc[margins_name, :] = 1
multip_table = count_table * table
table[margins_name] = (
multip_table.drop(margins_name, axis=1, level=0).sum(axis=1)
/ multip_table[margins_name]
)
table.loc[(margins_name, ""), :] = (
multip_table.drop(margins_name, axis=0).sum()
/ multip_table.loc[(margins_name, ""), :]
)
table.loc[margins_name, margins_name] = (
multip_table.drop(index=margins_name, columns=margins_name)
.sum()
.sum()
) / multip_table.loc[margins_name, margins_name][0]

if isinstance(
count_table.columns, pd.MultiIndex
) and not isinstance(
count_table.index, pd.MultiIndex
): # multidimensional columns
count_table = count_table.drop(margins_name, axis=1, level=0)
count_table.loc[:, margins_name] = count_table.sum(axis=1)
count_table = count_table.drop(margins_name, axis=0)
count_table.loc[(margins_name)] = count_table.sum(axis=0)
table[margins_name] = 1
table.loc[margins_name, :] = 1
multip_table = count_table * table
table[margins_name] = (
multip_table.drop(margins_name, axis=1, level=0).sum(axis=1)
/ multip_table[margins_name]
)
table.loc[margins_name, :] = (
multip_table.drop(margins_name, axis=0).sum()
/ multip_table.loc[margins_name, :]
)
table.loc[margins_name, margins_name] = (
multip_table.drop(index=margins_name, columns=margins_name)
.sum()
.sum()
) / multip_table.loc[margins_name, margins_name][0]

if isinstance(count_table.index, pd.MultiIndex) and not isinstance(
count_table.columns, pd.MultiIndex
): # multidimensional rows
count_table = count_table.where(table.notna(), other=np.nan)
columns_to_keep = table.columns
count_table = count_table[columns_to_keep]
count_table = count_table.drop(margins_name, axis=1)
count_table.loc[:, margins_name] = count_table.sum(axis=1)
count_table = count_table.drop(margins_name, axis=0)
count_table.loc[(margins_name, ""), :] = count_table.sum(axis=0)
table[margins_name] = 1
table.loc[margins_name, :] = 1
multip_table = count_table * table
table[margins_name] = (
multip_table.drop(margins_name, axis=1).sum(axis=1)
/ multip_table[margins_name]
)
table.loc[(margins_name, ""), :] = (
multip_table.drop(margins_name, axis=0).sum()
/ multip_table.loc[(margins_name, ""), :]
)
table.loc[margins_name, margins_name] = (
multip_table.drop(index=margins_name, columns=margins_name)
.sum()
.sum()
) / multip_table.loc[margins_name, margins_name][0]
except ValueError:
logger.info(
"All the cells in this data are discolsive."
" Thus suppression can not be applied"
)
return None

# record output
self.results.add(
Expand Down Expand Up @@ -652,6 +645,7 @@ def plot( # pylint: disable=too-many-arguments,too-many-locals
)
return plot


def rounded_survival_table(survival_table):
"""Calculates the rounded surival function."""
death_censored = (
Expand Down Expand Up @@ -695,6 +689,7 @@ def rounded_survival_table(survival_table):
survival_table["rounded_survival_fun"] = rounded_survival_func
return survival_table


def get_aggfunc(aggfunc: str | None) -> str | None:
"""Checks whether an aggregation function is allowed and returns the
appropriate function.
Expand Down Expand Up @@ -724,6 +719,7 @@ def get_aggfunc(aggfunc: str | None) -> str | None:
logger.debug("aggfunc: %s", func)
return func


def get_aggfuncs(
aggfuncs: str | list[str] | None,
) -> str | list[str] | None:
Expand Down Expand Up @@ -760,6 +756,7 @@ def get_aggfuncs(
return functions
raise ValueError("aggfuncs must be: either str or list[str]") # pragma: no cover


def agg_negative(vals: Series) -> bool:
"""Aggregation function that returns whether any values are negative.
Expand All @@ -775,6 +772,7 @@ def agg_negative(vals: Series) -> bool:
"""
return vals.min() < 0


def agg_missing(vals: Series) -> bool:
"""Aggregation function that returns whether any values are missing.
Expand All @@ -790,6 +788,7 @@ def agg_missing(vals: Series) -> bool:
"""
return vals.isna().sum() != 0


def agg_p_percent(vals: Series) -> bool:
"""Aggregation function that returns whether the p percent rule is violated.
Expand Down Expand Up @@ -823,6 +822,7 @@ def agg_p_percent(vals: Series) -> bool:
p_val: float = sub_total / sorted_vals.iloc[0] if total > 0 else 1
return p_val < SAFE_PRATIO_P


def agg_nk(vals: Series) -> bool:
"""Aggregation function that returns whether the top n items account for
more than k percent of the total.
Expand All @@ -844,6 +844,7 @@ def agg_nk(vals: Series) -> bool:
return (n_total / total) > SAFE_NK_K
return False


def agg_threshold(vals: Series) -> bool:
"""Aggregation function that returns whether the number of contributors is
below a threshold.
Expand All @@ -860,6 +861,7 @@ def agg_threshold(vals: Series) -> bool:
"""
return vals.count() < THRESHOLD


def apply_suppression(
table: DataFrame, masks: dict[str, DataFrame]
) -> tuple[DataFrame, DataFrame]:
Expand Down Expand Up @@ -913,6 +915,7 @@ def apply_suppression(
logger.info("outcome_df:\n%s", utils.prettify_table_string(outcome_df))
return safe_df, outcome_df


def get_table_sdc(masks: dict[str, DataFrame], suppress: bool) -> dict:
"""Returns the SDC dictionary using the suppression masks.
Expand Down Expand Up @@ -945,6 +948,7 @@ def get_table_sdc(masks: dict[str, DataFrame], suppress: bool) -> dict:
sdc["cells"][name].append([int(row_index), int(col_index)])
return sdc


def get_summary(sdc: dict) -> tuple[str, str]:
"""Returns the status and summary of the suppression masks.
Expand Down
Loading

0 comments on commit 6ff77e5

Please sign in to comment.