Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
drv850 committed Nov 4, 2024
1 parent 11d22f8 commit 8dc344b
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions GenNet_utils/Interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ def get_DeepExplainer_scores(args):

xval = xval[0]
xtest = xtest[0]

yval = yval.flatten()
ytest = ytest.flatten()

xval = xval if args.regression else xval[yval==0]
xtest = xtest if args.regression else xtest[ytest==1]
xval = xval if args.regression else xval[yval==0,:]
xtest = xtest if args.regression else xtest[ytest==1,:]

explainer = shap.DeepExplainer((model.input, model.output), )
print("Created explainer")
Expand Down Expand Up @@ -168,23 +169,17 @@ def get_DFIM_scores(args):

xval = xval[0]
xtest = xtest[0]
yval = yval.flatten()
ytest = ytest.flatten()

print("xval shape", xval.shape)
print("yval shape", yval.shape)


print("xtest shape", xtest.shape)
print("ytestl shape", ytest.shape)


if np.unique(np.array(ytest)).shape[0] > 2:
args.regression = True
else:
args.regression = False


xval = xval if args.regression else xval[yval==0]
xtest = xtest if args.regression else xtest[ytest==1]
xval = xval if args.regression else xval[yval==0,:]
xtest = xtest if args.regression else xtest[ytest==1,:]

explainer = shap.DeepExplainer((model.input, model.output), xval)
print("Created explainer")
Expand Down Expand Up @@ -246,8 +241,7 @@ def get_pathexplain_scores(args):

yval = yval.flatten()
ytest = ytest.flatten()

print("Shapes",xval.shape, xtest.shape)


if np.unique(np.array(ytest)).shape[0] > 2:
args.regression = True
Expand Down

0 comments on commit 8dc344b

Please sign in to comment.