Commit

Merge pull request tensorly#252 from merajhashemi/mh-reg
Support multi-dimensional output for CP regression
aarmey authored Jul 13, 2024
2 parents a21f844 + 6cb1d77 commit 59c1126
Showing 3 changed files with 100 additions and 27 deletions.
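
For context, this change lets CPRegressor handle tensor-valued responses: fit now accepts y of shape (n_samples, O_1, ..., O_q) and predict returns outputs of the same shape. A minimal usage sketch, mirroring the new test added at the bottom of this diff (shapes and synthetic data are illustrative only):

import tensorly.backend as T
from tensorly.base import partial_tensor_to_vec
from tensorly.random import random_cp
from tensorly.regression.cp_regression import CPRegressor

rng = T.check_random_state(0)
# Ground-truth weight tensor of shape (I_1, I_2, I_3, O_1, O_2) = (12, 5, 4, 4, 3)
W_true = random_cp(shape=(12, 5, 4, 4, 3), rank=3, full=True, random_state=rng)
X = T.randn((1200, 12, 5, 4), seed=rng)  # (n_samples, I_1, I_2, I_3)
y = T.reshape(
    T.dot(partial_tensor_to_vec(X), T.reshape(W_true, (-1, 4 * 3))), (-1, 4, 3)
)  # tensor-valued response of shape (n_samples, O_1, O_2)

estimator = CPRegressor(weight_rank=3, tol=1e-8, reg_W=0.0, n_iter_max=200)
estimator.fit(X[:1000], y[:1000])
y_pred = estimator.predict(X[1000:])  # shape (200, 4, 3)
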
1 change: 0 additions & 1 deletion requirements.txt
@@ -3,4 +3,3 @@ scipy
pytest
pytest-cov
pytest-randomly
python_version >= "3.6"
91 changes: 66 additions & 25 deletions tensorly/regression/cp_regression.py
@@ -1,4 +1,4 @@
import numpy as np
import math
from ..base import partial_tensor_to_vec, partial_unfold
from ..tenalg import khatri_rao
from ..cp_tensor import cp_to_tensor, cp_to_vec
@@ -20,8 +20,8 @@ class CPRegressor:
rank of the CP decomposition of the regression weights
tol : float
convergence value
reg_W : int, optional, default is 1
regularisation on the weights
reg_W : float, optional, default is 1
l2 regularisation constant for the regression weights (:math:`reg_W * \sum_i ||factors[i]||_F^2`)
n_iter_max : int, optional, default is 100
maximum number of iteration
random_state : None, int or RandomState, optional, default is None
@@ -68,9 +68,8 @@ def fit(self, X, y):
Parameters
----------
X : ndarray
tensor data of shape (n_samples, N1, ..., NS)
y : 1D-array of shape (n_samples, )
X : tensor data of shape (n_samples, I_1, ..., I_p)
y : tensor of shape (n_samples, O_1, ..., O_q)
labels associated with each sample
Returns
@@ -79,12 +78,12 @@ def fit(self, X, y):
"""
rng = T.check_random_state(self.random_state)

# Initialise randomly the weights
# Initialise the weights randomly
W = []
for i in range(
1, T.ndim(X)
): # The first dimension of X is the number of samples
for i in range(1, T.ndim(X)): # The first dimension is the number of samples
W.append(T.tensor(rng.randn(X.shape[i], self.weight_rank), **T.context(X)))
for i in range(1, T.ndim(y)):
W.append(T.tensor(rng.randn(y.shape[i], self.weight_rank), **T.context(X)))

# Norm of the weight tensor at each iteration
norm_W = []
@@ -93,19 +92,48 @@ def fit(self, X, y):
for iteration in range(self.n_iter_max):
# Optimise each factor of W
for i in range(len(W)):
phi = T.reshape(
T.dot(
partial_unfold(X, i, skip_begin=1), khatri_rao(W, skip_matrix=i)
),
(X.shape[0], -1),
)
inv_term = T.dot(T.transpose(phi), phi) + self.reg_W * T.tensor(
np.eye(phi.shape[1]), **T.context(X)
)
W[i] = T.reshape(
T.solve(inv_term, T.dot(T.transpose(phi), y)),
(X.shape[i + 1], self.weight_rank),
)
if i < T.ndim(X) - 1:
X_unfolded = partial_unfold(X, i, skip_begin=1)
phi = T.dot(
X_unfolded,
T.reshape(
khatri_rao(W, skip_matrix=i), (X_unfolded.shape[-1], -1)
),
)
phi = T.transpose(
T.reshape(
phi, (X.shape[0], X.shape[i + 1], -1, self.weight_rank)
),
(0, 2, 1, 3),
)
phi = T.reshape(phi, (-1, X.shape[i + 1] * self.weight_rank))
y_reshaped = T.reshape(y, (-1,))
inv_term = T.dot(T.transpose(phi), phi) + self.reg_W * T.eye(
phi.shape[1], **T.context(X)
)
W[i] = T.reshape(
T.solve(inv_term, T.dot(T.transpose(phi), y_reshaped)),
(-1, self.weight_rank),
)
else:
X_unfolded = partial_tensor_to_vec(X, skip_begin=1)
phi = T.dot(
X_unfolded,
T.reshape(
khatri_rao(W, skip_matrix=i), (X_unfolded.shape[-1], -1)
),
)
phi = T.reshape(phi, (-1, self.weight_rank))
y_reshaped = T.reshape(
T.moveaxis(y, i - T.ndim(X) + 2, -1),
(-1, y.shape[i - T.ndim(X) + 2]),
)
inv_term = T.dot(T.transpose(phi), phi) + self.reg_W * T.eye(
phi.shape[1], **T.context(X)
)
W[i] = T.transpose(
T.solve(inv_term, T.dot(T.transpose(phi), y_reshaped))
)

weight_tensor_ = cp_to_tensor((weights, W))
norm_W.append(T.norm(weight_tensor_, 2))
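
Each factor update above reduces to a ridge-regularised least-squares solve: a design matrix phi is built from the partially unfolded X and the Khatri-Rao product of the remaining factors, and W[i] comes from the normal equations (phi^T phi + reg_W * I) w = phi^T y. A standalone NumPy sketch of that step (illustrative only; the implementation above goes through the tensorly backend so it stays framework-agnostic):

import numpy as np

def ridge_factor_update(phi, y_vec, reg_W):
    # Solve (phi^T phi + reg_W * I) w = phi^T y, as in each W[i] update.
    gram = phi.T @ phi + reg_W * np.eye(phi.shape[1])
    return np.linalg.solve(gram, phi.T @ y_vec)

# Toy shapes: 100 flattened samples, a design matrix with 8 columns.
rng = np.random.default_rng(0)
phi = rng.standard_normal((100, 8))
y_vec = rng.standard_normal(100)
w = ridge_factor_update(phi, y_vec, reg_W=1.0)  # shape (8,)
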
@@ -134,6 +162,19 @@ def predict(self, X):
Parameters
----------
X : ndarray
tensor data of shape (n_samples, N1, ..., NS)
tensor data of shape (n_samples, I_1, ..., I_p)
"""
return T.dot(partial_tensor_to_vec(X), self.vec_W_)
out_shape = (-1, *self.weight_tensor_.shape[T.ndim(X) - 1 :])
if T.ndim(self.weight_tensor_) > T.ndim(X) - 1:
weight_shape = (
-1,
int(math.prod(self.weight_tensor_.shape[T.ndim(X) - 1 :])),
)
else:
weight_shape = (-1,)
return T.reshape(
T.dot(
partial_tensor_to_vec(X), T.reshape(self.weight_tensor_, weight_shape)
),
out_shape,
)
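
The new predict flattens the input modes of X, multiplies by the weight tensor reshaped to (prod of input dims, prod of output dims), and reshapes the product back to (n_samples, O_1, ..., O_q). A NumPy sketch with illustrative shapes (the real method uses partial_tensor_to_vec and math.prod, and also handles the scalar-output case where weight_shape is (-1,)):

import numpy as np

n_samples = 10
X = np.random.randn(n_samples, 12, 5, 4)   # (n_samples, I_1, I_2, I_3)
W = np.random.randn(12, 5, 4, 4, 3)        # learned weight tensor

X_vec = X.reshape(n_samples, -1)            # flatten all input modes per sample
W_mat = W.reshape(12 * 5 * 4, 4 * 3)        # input modes x output modes
y_pred = (X_vec @ W_mat).reshape(n_samples, 4, 3)  # (n_samples, O_1, O_2)
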
35 changes: 34 additions & 1 deletion tensorly/regression/tests/test_cp_regression.py
@@ -4,7 +4,8 @@
from ...base import tensor_to_vec, partial_tensor_to_vec
from ...metrics.regression import RMSE
from ... import backend as T
from ...testing import assert_
from ...random import random_cp
from ...testing import assert_, assert_allclose


def test_CPRegressor():
@@ -52,3 +53,35 @@ def test_CPRegressor():
estimator.weight_rank == 5,
msg="set_params did not correctly set the given parameters",
)


def test_multidim_CPRegressor():
tol = 1e-3
rng = T.check_random_state(1234)

regression_weights = random_cp(
shape=(12, 5, 4, 4, 3), rank=3, full=True, random_state=rng
)
X = T.randn((1200, 12, 5, 4), seed=rng)
y = T.reshape(
T.dot(partial_tensor_to_vec(X), T.reshape(regression_weights, (-1, 4 * 3))),
(-1, 4, 3),
)
X_train = X[:1000]
X_test = X[1000:]
y_train = y[:1000]
y_test = y[1000:]

estimator = CPRegressor(
weight_rank=3, tol=1e-8, reg_W=0.0, n_iter_max=200, verbose=True
)
estimator.fit(X_train, y_train)
y_pred = estimator.predict(X_test)
error = RMSE(y_test, y_pred)
assert_(error <= tol, msg=f"CP Regressor : RMSE is too large, {error} > {tol}")
assert_allclose(
estimator.weight_tensor_,
regression_weights,
atol=tol,
err_msg="CPRegressor did not converge to the correct weights",
)
