Skip to content

Commit

Permalink
MNT api update for benchopt 1.4 (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
mathurinm authored Jun 4, 2024
1 parent 83da1e7 commit e510093
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 9 deletions.
8 changes: 4 additions & 4 deletions objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class Objective(BaseObjective):
min_benchopt_version = "1.3"
min_benchopt_version = "1.4"
name = "Ridge Regression"

parameters = {
Expand All @@ -21,13 +21,13 @@ def set_data(self, X, y, X_test=None, y_test=None):
self.X_test, self.y_test = X_test, y_test
self.n_features = self.X.shape[1]

def get_one_solution(self):
def get_one_result(self):
n_features = self.n_features
if self.fit_intercept:
n_features += 1
return np.zeros(n_features)
return dict(beta=np.zeros(n_features))

def compute(self, beta):
def evaluate_result(self, beta):
# compute residuals
test_loss = None
if self.X_test is not None:
Expand Down
2 changes: 1 addition & 1 deletion solvers/cd.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,4 @@ def sparse_cd(X_data, X_indices, X_indptr, y, lmbd, L, n_iter):
return w

def get_result(self):
return self.w
return dict(beta=self.w)
3 changes: 2 additions & 1 deletion solvers/glmnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,5 @@ def get_result(self):
as_matrix = robjects.r["as"]
coefs = np.array(as_matrix(results["beta"], "matrix"))
beta = coefs.flatten()
return np.r_[beta, results["a0"]] if self.fit_intercept else beta
beta = np.r_[beta, results["a0"]] if self.fit_intercept else beta
return dict(beta=beta)
2 changes: 1 addition & 1 deletion solvers/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ def get_result(self):
beta = self.ridge.coef_.flatten()
if self.fit_intercept:
beta = np.r_[beta, self.ridge.intercept_]
return beta
return dict(beta=beta)
2 changes: 1 addition & 1 deletion solvers/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def run(self, n_iter):
self.coef = coef

def get_result(self):
return self.coef
return dict(beta=self.coef)
2 changes: 1 addition & 1 deletion solvers/snapml.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def run(self, n_iter):
self.coef = coef

def get_result(self):
return self.coef
return dict(beta=self.coef)

0 comments on commit e510093

Please sign in to comment.