-
Notifications
You must be signed in to change notification settings - Fork 2
/
exact_sp.py
28 lines (24 loc) · 1.11 KB
/
exact_sp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
from tqdm import tqdm
def get_true_KNN(x_trn, x_tst, *, progress=True):
    """Rank every training point by Euclidean distance to each test point.

    Args:
        x_trn: 2-D array of training features, one row per point (N rows).
        x_tst: 2-D array of test features, one row per point (N_tst rows),
            with the same number of columns as ``x_trn``.
        progress: keyword-only; when True (the default) a tqdm progress bar
            is shown over the test points.

    Returns:
        (N_tst, N) integer array. Row ``j`` holds the training indices sorted
        from nearest to farthest from test point ``j``. Training points at
        exactly equal distance are ordered by ascending training index.
    """
    N = x_trn.shape[0]
    N_tst = x_tst.shape[0]
    # Allocate integer indices directly instead of a float matrix that is cast afterwards.
    x_tst_knn_gt = np.empty((N_tst, N), dtype=int)
    rows = tqdm(range(N_tst)) if progress else range(N_tst)
    for i_tst in rows:
        # One vectorised pass over all training rows replaces N separate norm calls.
        dist_gt = np.linalg.norm(x_trn - x_tst[i_tst], axis=1)
        # A stable sort makes tie-breaking deterministic (lower training index first).
        x_tst_knn_gt[i_tst] = np.argsort(dist_gt, kind="stable")
    return x_tst_knn_gt
def compute_single_unweighted_knn_class_shapley(x_trn, y_trn, x_tst_knn_gt, y_tst, K, *, progress=True):
    """Exact Shapley value of each training point for an unweighted KNN classifier.

    For test point ``j``, let ``alpha_1, ..., alpha_N`` (1-indexed) be the
    training indices ordered nearest-first, taken from ``x_tst_knn_gt[j]``.
    Let ``m(i) = 1[y_trn[alpha_i] == y_tst[j]]``. The values are then given by
    the recursion (Jia et al., 2019, "Efficient Task-Specific Data Valuation
    for Nearest Neighbor Algorithms"):

        s(alpha_N) = m(N) / N
        s(alpha_i) = s(alpha_{i+1}) + (m(i) - m(i+1)) / K * min(K, i) / i

    Args:
        x_trn: training features; only ``x_trn.shape[0]`` (N) is read.
        y_trn: training labels, indexed positionally by training index.
        x_tst_knn_gt: (N_tst, N) integer array. Row ``j`` lists the training
            indices sorted nearest-first for test point ``j`` (for example,
            the output of ``get_true_KNN``). Each row is expected to be a
            permutation of ``range(N)``.
        y_tst: test labels, one per row of ``x_tst_knn_gt``.
        K: number of neighbours; must be >= 1.
        progress: keyword-only; when True (the default) a tqdm progress bar
            is shown over the test points.

    Returns:
        (N_tst, N) float array. Entry ``[j, i]`` is the Shapley value of
        training point ``i`` with respect to test point ``j``.

    Raises:
        ValueError: if ``K < 1``.
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K!r}")
    N = x_trn.shape[0]
    N_tst = x_tst_knn_gt.shape[0]
    y_trn = np.asarray(y_trn)
    sp_gt = np.zeros((N_tst, N))
    if N == 0:
        # No training points, so there is nothing to value.
        return sp_gt
    # Per-rank factors, hoisted out of the loop. Entry t is 0-indexed rank t,
    # i.e. 1-indexed rank t + 1, for t = 0..N-2.
    ranks = np.arange(1, N)
    caps = np.minimum(K, ranks)
    rows = tqdm(range(N_tst)) if progress else range(N_tst)
    for j in rows:
        order = x_tst_knn_gt[j]
        match = (y_trn[order] == y_tst[j]).astype(int)
        # steps[t] = s[t] - s[t+1]; evaluated left-to-right exactly like the scalar recursion.
        steps = (match[:-1] - match[1:]) / K * caps / ranks
        # Unroll the recursion from the farthest neighbour inward. A running
        # sum that starts at s[N-1] and adds steps[N-2], steps[N-3], ...
        # yields s[N-1], s[N-2], ..., s[0].
        far_to_near = np.cumsum(np.concatenate(([match[-1] / N], steps[::-1])))
        sp_gt[j, order] = far_to_near[::-1]
    return sp_gt