Skip to content

Commit

Permalink
add original tsr code
Browse files Browse the repository at this point in the history
  • Loading branch information
khairulislam committed Oct 15, 2023
1 parent 41706da commit 8f0da14
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 125 deletions.
11 changes: 8 additions & 3 deletions exp/exp_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,10 @@ def interpret(self, dataloader):
end = datetime.now()
print(f'Experiment ended at {end}. Total time taken {end - start}.')
if self.args.tsr:
self.dump_results(results, f'tsr_{name}.csv')
if self.args.threshold == 0.55:
self.dump_results(results, f'tsr_{name}_orig.csv')
else:
self.dump_results(results, f'tsr_{name}.csv')
else:
self.dump_results(results, f'{name}.csv')

Expand All @@ -181,7 +184,8 @@ def evaluate_classifier(
self.args, explainer, inputs=inputs,
sliding_window_shapes=sliding_window_shapes,
strides=strides, baselines=baselines,
additional_forward_args=additional_forward_args
additional_forward_args=additional_forward_args,
threshold=self.args.threshold
)
else:
attr = compute_classifier_attr(
Expand Down Expand Up @@ -242,7 +246,8 @@ def evaluate_regressor(
self.args, explainer, inputs=inputs,
sliding_window_shapes=sliding_window_shapes,
strides=strides, baselines=baselines,
additional_forward_args=additional_forward_args
additional_forward_args=additional_forward_args,
threshold=self.args.threshold
)
else:
attr = compute_regressor_attr(
Expand Down
20 changes: 0 additions & 20 deletions interpret.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -287,26 +287,6 @@
")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 24, 36, 7])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"attr[0].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
6 changes: 6 additions & 0 deletions result_classification.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,9 @@ flag val, acc 0.92301, f1 0.51507, auc 0.90794
classification_MICN_mimic_iii_mimic_ftMS_sl48_ll48_pl0_dm128_nh4_el2_dl1_df256_fc3_ebtimeF_dtTrue
flag test, acc 0.92866, f1 0.45695, auc 0.90359

classification_Crossformer_mimic_iii_mimic_ftMS_sl48_ll48_pl0_dm128_nh4_el2_dl1_df256_fc3_ebtimeF_dtTrue
flag val, acc 0.90126, f1 0.40420, auc 0.83894

classification_Crossformer_mimic_iii_mimic_ftMS_sl48_ll48_pl0_dm128_nh4_el2_dl1_df256_fc3_ebtimeF_dtTrue
flag test, acc 0.90735, f1 0.37537, auc 0.83755

18 changes: 18 additions & 0 deletions result_long_term_forecast.txt
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,21 @@ flag:val, mse:0.10076, mae:0.22866, rmse: 0.31743.
long_term_forecast_MICN_traffic_custom_ftS_sl96_ll12_pl24_dm128_nh4_el2_dl1_df256_fc3_ebtimeF_dtTrue
flag:test, mse:0.22352, mae:0.29335, rmse: 0.47278.

long_term_forecast_LightTS_electricity_custom_ftS_sl96_ll12_pl24_dm128_nh4_el2_dl1_df256_fc3_ebtimeF_dtTrue
flag:val, mse:0.31445, mae:0.41969, rmse: 0.56076.

long_term_forecast_LightTS_electricity_custom_ftS_sl96_ll12_pl24_dm128_nh4_el2_dl1_df256_fc3_ebtimeF_dtTrue
flag:test, mse:0.22642, mae:0.35058, rmse: 0.47584.

long_term_forecast_Crossformer_electricity_custom_ftS_sl96_ll12_pl24_dm128_nh4_el2_dl1_df256_fc3_ebtimeF_dtTrue
flag:val, mse:0.2201, mae:0.33177, rmse: 0.46915.

long_term_forecast_Crossformer_electricity_custom_ftS_sl96_ll12_pl24_dm128_nh4_el2_dl1_df256_fc3_ebtimeF_dtTrue
flag:test, mse:0.12923, mae:0.25335, rmse: 0.35949.

long_term_forecast_Crossformer_traffic_custom_ftS_sl96_ll12_pl24_dm128_nh4_el2_dl1_df256_fc3_ebtimeF_dtTrue
flag:val, mse:0.088908, mae:0.20552, rmse: 0.29817.

long_term_forecast_Crossformer_traffic_custom_ftS_sl96_ll12_pl24_dm128_nh4_el2_dl1_df256_fc3_ebtimeF_dtTrue
flag:test, mse:0.22583, mae:0.29375, rmse: 0.47521.

127 changes: 127 additions & 0 deletions result_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import pandas as pd

metric_map = {
'electricity': ['mae', 'mse'],
'traffic': ['mae', 'mse'],
'mimic_iii': ['auc', 'accuracy', 'cross_entropy']
}

datasets = ['electricity', 'traffic', 'mimic_iii']
models = ['DLinear', 'MICN', 'LSTM', 'Crossformer']
attr_methods = ['feature_ablation', 'occlusion', 'augmented_occlusion', 'feature_permutation', 'winIT']

short_form = {
'feature_ablation': 'FA',
'occlusion':'FO',
'augmented_occlusion': 'AFO',
'feature_permutation': 'FP',
'winIT': 'WinIT'
}

# for dataset in datasets:
# print(dataset)
# tsr_better_comp = 0
# tsr_better_suff = 0
# comp_count = suff_count = 0

# for attr_method in attr_methods:
# for model in models: # , 'Crossformer'
# df = pd.read_csv(f'results/{dataset}_{model}/{attr_method}.csv')
# df = df[df['area']==0.05]

# df_tsr = pd.read_csv(f'results/{dataset}_{model}/tsr_{attr_method}.csv')
# df_tsr = df_tsr[df_tsr['area']==0.05]

# for metric in metric_map[dataset]:
# [comp, suff] = df[df['metric']==metric][['comp', 'suff']].values[0]
# # print(comp, suff)

# [tsr_comp, tsr_suff] = df_tsr[df_tsr['metric']==metric][['comp', 'suff']].values[0]
# # print(tsr_comp, tsr_suff)

# if metric in ['accuracy', 'auc']:
# if comp > tsr_comp: tsr_better_comp +=1
# if suff < tsr_suff: tsr_better_suff += 1
# else:
# if comp < tsr_comp: tsr_better_comp +=1
# if suff > tsr_suff: tsr_better_suff += 1
# comp_count +=1
# suff_count +=1
# break

# print(f'TSR better for: comp {tsr_better_comp}, suff {tsr_better_suff}. Total cases: {suff_count}')
# print(f'TSR improve comprehensiveness on {100.0* tsr_better_comp/comp_count:0.4f}\%, \
# and sufficiency on {tsr_better_suff*100.0/suff_count:0.4f}\% cases.\n')

def reduce_df(df:pd.DataFrame):
# df[df['area']==0.05]
return df.groupby('metric').aggregate('mean')[['comp', 'suff']].reset_index()

def print_row(item):
print(f'& {item:0.4g} ', end='')

for dataset in datasets:
print(f'printing latex table for {dataset} dataset.\n')
print(f" & {' & '.join(models)} & {' & '.join(models)} \\\\ \\hline")

tsr_better_comp = 0
tsr_better_suff = 0
comp_count = suff_count = 0

for attr_method in attr_methods:
print(f'{short_form[attr_method]} ', end='')
for model in models:
df = pd.read_csv(f'results/{dataset}_{model}/{attr_method}.csv')
df = reduce_df(df)

for metric in metric_map[dataset]:
[comp, _] = df[df['metric']==metric][['comp', 'suff']].values[0]
if metric in ['auc', 'accuracy']:
comp = 1- comp
break
print_row(comp)

for model in models:
df = pd.read_csv(f'results/{dataset}_{model}/{attr_method}.csv')
df = reduce_df(df)
for metric in metric_map[dataset]:
[_, suff] = df[df['metric']==metric][['comp', 'suff']].values[0]
if metric in ['auc', 'accuracy']:
suff = 1- suff
break
print_row(suff)
print('\\\\')

print('\\hline')
for attr_method in attr_methods:
if attr_method == 'winIT': continue
print(f'WTSR+{short_form[attr_method]} ', end='')
for model in models:
df_tsr = pd.read_csv(f'results/{dataset}_{model}/tsr_{attr_method}.csv')
df_tsr = reduce_df(df_tsr)

for metric in metric_map[dataset]:
[tsr_comp, _] = df_tsr[df_tsr['metric']==metric][['comp', 'suff']].values[0]
if metric in ['auc', 'accuracy']:
tsr_comp = 1- tsr_comp
break

print_row(tsr_comp)

for model in models:
df_tsr = pd.read_csv(f'results/{dataset}_{model}/tsr_{attr_method}.csv')
df_tsr = reduce_df(df_tsr)

for metric in metric_map[dataset]:
[_, tsr_suff] = df_tsr[df_tsr['metric']==metric][['comp', 'suff']].values[0]
if metric in ['auc', 'accuracy']:
tsr_suff = 1- tsr_suff
break
print_row(tsr_suff)

print('\\\\')
print('\\hline\n\n')
# break



9 changes: 9 additions & 0 deletions results/electricity_Crossformer/tsr_feature_permutation.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
metric,area,comp,suff
mae,0.05,11.820683,11.788181
mae,0.075,13.516388,10.315534
mae,0.1,14.248765,9.522623
mae,0.15,15.422024,7.927388
mse,0.05,10.148638,9.057515
mse,0.075,12.760038,7.038268
mse,0.1,13.971034,6.098217
mse,0.15,16.04396,4.389713
9 changes: 9 additions & 0 deletions results/electricity_DLinear/tsr_integrated_gradients_orig.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
metric,area,comp,suff
mae,0.05,1.486402,12.089916
mae,0.075,1.663017,11.305623
mae,0.1,1.829393,10.807415
mae,0.15,2.307418,9.467468
mse,0.05,0.206225,9.598049
mse,0.075,0.239076,8.608364
mse,0.1,0.271157,7.933516
mse,0.15,0.375217,5.939165
51 changes: 0 additions & 51 deletions results/national_illness_Transformer/interpretations_test.csv

This file was deleted.

51 changes: 0 additions & 51 deletions results/national_illness_Transformer/interpretations_val.csv

This file was deleted.

0 comments on commit 8f0da14

Please sign in to comment.