-
Notifications
You must be signed in to change notification settings - Fork 1
/
fairness_checks.py
55 lines (40 loc) · 2.29 KB
/
fairness_checks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import pandas as pd
from itertools import product
import plotly.express as px
def sensitive_feature_combinations(dataset_original, sensitive_features, target_column, bins=5):
    """Check 16: Biased data. Count class labels for every combination of sensitive-feature values.

    The resulting table is used to plot the stacked bar chart
    (see ``plot_stacked_barchart``).

    Numeric (non-boolean) sensitive features are first discretised with
    ``pd.cut`` into ``bins`` equal-width intervals. Every combination
    (Cartesian product) of the values observed in the sensitive features is
    then matched against the rows, and the target labels of the matching rows
    are counted. Missing sensitive values (NaN/None) form their own group
    instead of being silently dropped.

    Parameters
    ----------
    dataset_original : pd.DataFrame
        Input data. It is copied, never modified in place.
    sensitive_features : list of str
        Names of the sensitive-feature columns.
    target_column : str
        Name of the class-label column. Missing labels are not counted.
    bins : int, default 5
        Number of intervals used to bin numeric sensitive features.

    Returns
    -------
    pd.DataFrame
        Indexed by the combination tuples (in Cartesian-product order), with
        columns ``sensitive_features`` (combination values joined by ", "),
        ``count`` (total of the label counts in that row), then one integer
        column per non-missing target label in order of first appearance.
        Combinations for which ``value_counts`` returns nothing (no matching
        rows) are omitted.
    """
    dataset = dataset_original.copy()
    for feat in sensitive_features:
        column = dataset[feat]
        # is_numeric_dtype() is also True for bool columns; cutting a binary
        # flag into numeric intervals yields meaningless labels, so skip those.
        if pd.api.types.is_numeric_dtype(column) and not pd.api.types.is_bool_dtype(column):
            dataset[feat] = pd.cut(column, bins=bins)

    combos = []
    rows = []
    # Cartesian product of the values observed in each sensitive feature.
    for combo in product(*[dataset[feat].unique() for feat in sensitive_features]):
        mask = pd.Series(True, index=dataset.index)
        for feat, value in zip(sensitive_features, combo):
            if pd.api.types.is_scalar(value) and pd.isna(value):
                # NaN never equals itself, so `==` would silently drop these rows.
                mask &= dataset[feat].isna()
            else:
                mask &= dataset[feat] == value
        label_counts = dataset.loc[mask, target_column].value_counts().to_dict()
        if label_counts:  # skip combinations that match no rows
            combos.append(combo)
            rows.append(label_counts)

    # value_counts() drops missing labels, so they must not become columns either.
    labels = list(dataset[target_column].dropna().unique())
    # Build positionally: tuple keys containing NaN only match by object
    # identity, so keying the rows by combo tuple is unreliable.
    result = pd.DataFrame(rows).reindex(columns=labels).fillna(0).astype(int)
    result.index = pd.Index(combos)
    result['count'] = result.sum(axis=1)
    result['sensitive_features'] = [', '.join(map(str, combo)) for combo in combos]
    return result[['sensitive_features', 'count'] + labels]
def plot_stacked_barchart(sensitive_feature_counts_table):
    """Build a horizontal stacked bar chart of class-label counts per sensitive subgroup.

    Parameters
    ----------
    sensitive_feature_counts_table : pd.DataFrame
        A table laid out like the output of ``sensitive_feature_combinations``:
        a ``sensitive_features`` column, a ``count`` column, and one column per
        class label (the number of label columns varies).

    Returns
    -------
    Figure
        The ``px.bar`` figure: one horizontal bar per subgroup, stacked by
        class label, with the x-axis titled "Count".

    Raises
    ------
    ValueError
        If the table lacks a ``sensitive_features`` or ``count`` column
        (raised by ``list.remove``).
    """
    # Every column except the two bookkeeping columns is a class label to stack.
    label_columns = list(sensitive_feature_counts_table.columns)
    for bookkeeping_column in ('sensitive_features', 'count'):
        label_columns.remove(bookkeeping_column)

    figure = px.bar(
        sensitive_feature_counts_table,
        x=label_columns,
        y='sensitive_features',
        orientation='h',
        barmode='stack',
        title='Sensitive Feature Combinations vs Target Column',
    )
    figure.update_xaxes(title_text='Count')
    return figure