Hi, I was playing around with the ppi_distribution_label_shift_ci function, supplying dummy values, when I encountered an exception. I'm also not sure whether I defined the nu vector correctly, since I don't know what it is for or how it should be defined. I'd appreciate it if you could clarify that as well. Thank you!
import numpy as np
from ppi_py import ppi_distribution_label_shift_ci
# True labels
Y = np.array([0, 1, 0, 1, 0])
# Predicted labels for labeled data
Yhat = np.array([0, 1, 1, 1, 0])
# Predicted labels for unlabeled data
Yhat_unlabeled = np.array([0, 0, 1, 1, 1, 0, 1])
# Number of classes
K = 2
nu = np.array([0, 1])
# Calling the function
result = ppi_distribution_label_shift_ci(Y, Yhat, Yhat_unlabeled, K, nu)
ValueError Traceback (most recent call last) in <cell line: 19>()
17
18 # Calling the function
---> 19 result = ppi_distribution_label_shift_ci(Y, Yhat, Yhat_unlabeled, K, nu)
20 print("Confidence Interval for class 1 probability:", result)
4 frames /usr/local/lib/python3.10/dist-packages/ppi_py/ppi.py in ppi_distribution_label_shift_ci(Y, Yhat, Yhat_unlabeled, K, nu, alpha, delta, return_counts)
1206 budget_split = 0.999999
1207 epsilon1 = max(
-> 1208 [
1209 linfty_binom(C.sum(axis=0)[k], K, budget_split * delta, Ahat[:, k])
1210 for k in range(K)
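On the `nu` question: if I'm reading the docstring right, the function targets `nu^T f`, where `f` is the class distribution under label shift. In that case `nu = np.array([0, 1])` should select the class-1 proportion, and `[1, 0]` would select class 0. Below is a rough NumPy sketch of the point estimate I assume sits at the center of the interval. It uses the standard confusion-matrix label-shift estimator, and `label_shift_point_estimate` is a hypothetical helper name. This is not ppi_py's code, and it skips the interval part entirely.

```python
import numpy as np

def label_shift_point_estimate(Y, Yhat, Yhat_unlabeled, K, nu):
    """Hypothetical helper (not ppi_py): confusion-matrix estimate of nu^T f."""
    # C[j, l] = number of labeled points with prediction j and true label l
    C = np.zeros((K, K), dtype=int)
    for j in range(K):
        for l in range(K):
            C[j, l] = np.sum((Yhat == j) & (Y == l))
    # Column-normalize: Ahat[j, l] estimates P(Yhat = j | Y = l)
    Ahat = C / C.sum(axis=0)
    # Class frequencies of the predictions on the unlabeled data
    qf = np.bincount(Yhat_unlabeled, minlength=K) / len(Yhat_unlabeled)
    # Invert the confusion relationship to estimate the true class distribution
    qy = np.linalg.solve(Ahat, qf)
    # nu weights the classes; [0, 1] keeps only the class-1 proportion
    return nu @ qy, qy

Y = np.array([0, 1, 0, 1, 0])
Yhat = np.array([0, 1, 1, 1, 0])
Yhat_unlabeled = np.array([0, 0, 1, 1, 1, 0, 1])
estimate, qy = label_shift_point_estimate(Y, Yhat, Yhat_unlabeled, K=2, nu=np.array([0, 1]))
print(qy)        # estimated class distribution, about [0.643, 0.357] (9/14, 5/14)
print(estimate)  # class-1 proportion, 5/14 (about 0.357)
```

Because `Ahat` divides by the per-class counts in `Y`, every class has to appear at least once among the gold-standard labels for this estimate to be defined.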
I tried with a larger N, but it still threw the same error:
import numpy as np
from ppi_py import ppi_distribution_label_shift_ci
# True labels
Y = np.array([1]*100000+[0]*100000)
# Predicted labels for labeled data
Yhat = np.array([1]*120000+[0]*80000)
# Predicted labels for unlabeled data
Yhat_unlabeled = np.array([1]*170000+[0]*30000)
# Number of classes
K = 2
nu = np.array([0, 1])
# Calling the function
result = ppi_distribution_label_shift_ci(Y, Yhat, Yhat_unlabeled, K, nu)
print("Confidence Interval for class 1 probability:", result)
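To narrow down the failure, here is a sketch that recomputes, for the larger example, the two arguments visible in the failing frame: `C.sum(axis=0)[k]` and `Ahat[:, k]`. I'm assuming, from the names only, that `C` is the predicted-by-true confusion matrix and `Ahat` is its column-normalized version. The `bincount` construction is just my own shortcut, not how ppi_py builds it.

```python
import numpy as np

K = 2
Y = np.array([1] * 100000 + [0] * 100000)
Yhat = np.array([1] * 120000 + [0] * 80000)

# Confusion matrix C[j, l] = #(Yhat == j, Y == l), built in one pass:
# each (j, l) pair maps to flat index j * K + l, then reshape row-major
C = np.bincount(Yhat * K + Y, minlength=K * K).reshape(K, K)
col_counts = C.sum(axis=0)  # per-true-class sample sizes (assumed first argument)
Ahat = C / col_counts       # column k is Ahat[:, k] (assumed last argument)

for k in range(K):
    # column 0 comes out as [0.8, 0.2], column 1 as [0.0, 1.0]
    print(k, col_counts[k], Ahat[:, k])
```

Both examples end up with an exact 0 in column 1 of `Ahat` (no true class-1 point is predicted as 0), in case that is relevant.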