Skip to content

Commit

Permalink
Initial draft
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmey committed Aug 16, 2024
1 parent 78a45f6 commit cd041b3
Showing 1 changed file with 1 addition and 36 deletions.
37 changes: 1 addition & 36 deletions tensorly/decomposition/_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,7 @@ def initialize_cp(
elif isinstance(init, (tuple, list, CPTensor)):
# TODO: Test this
try:
if normalize_factors is True:
warnings.warn(
"It is not recommended to initialize a tensor with normalizing. Consider normalizing the tensor before using this function"
)

kt = CPTensor(init)
weights, factors = kt

if tl.all(weights == 1):
kt = CPTensor((None, factors))
else:
weights_avg = tl.prod(weights) ** (1.0 / tl.shape(weights)[0])
for i in range(len(factors)):
factors[i] = factors[i] * weights_avg
kt = CPTensor((None, factors))

return kt
except ValueError:
raise ValueError(
"If initialization method is a mapping, then it must "
Expand Down Expand Up @@ -242,7 +226,6 @@ def parafac(
sparsity=None,
l2_reg=0,
mask=None,
cvg_criterion="abs_rec_error",
fixed_modes=None,
svd_mask_repeats=5,
linesearch=False,
Expand Down Expand Up @@ -282,10 +265,6 @@ def parafac(
the values are missing and 1 everywhere else. Note: if tensor is
sparse, then mask should also be sparse with a fill value of 1 (or
True). Allows for missing values [2]_
cvg_criterion : {'abs_rec_error', 'rec_error'}, optional
Stopping criterion for ALS, works if `tol` is not None.
If 'rec_error', ALS stops at current iteration if ``(previous rec_error - current rec_error) < tol``.
If 'abs_rec_error', ALS terminates when `|previous rec_error - current rec_error| < tol`.
sparsity : float or int
If `sparsity` is not None, we approximate tensor as a sum of low_rank_component and sparse_component, where low_rank_component = cp_to_tensor((weights, factors)). `sparsity` denotes desired fraction or number of non-zero elements in the sparse_component of the `tensor`.
fixed_modes : list, default is None
Expand Down Expand Up @@ -511,14 +490,7 @@ def parafac(
f"iteration {iteration}, reconstruction error: {rec_error}, decrease = {rec_error_decrease}, unnormalized = {unnorml_rec_error}"
)

if cvg_criterion == "abs_rec_error":
stop_flag = tl.abs(rec_error_decrease) < tol
elif cvg_criterion == "rec_error":
stop_flag = rec_error_decrease < tol
else:
raise TypeError("Unknown convergence criterion")

if stop_flag:
if rec_error_decrease < tol:
if verbose:
print(f"PARAFAC converged after {iteration} iterations")
break
Expand Down Expand Up @@ -796,10 +768,6 @@ class CP(DecompositionMixin):
the values are missing and 1 everywhere else. Note: if tensor is
sparse, then mask should also be sparse with a fill value of 1 (or
True). Allows for missing values [2]_
cvg_criterion : {'abs_rec_error', 'rec_error'}, optional
Stopping criterion for ALS, works if `tol` is not None.
If 'rec_error', ALS stops at current iteration if ``(previous rec_error - current rec_error) < tol``.
If 'abs_rec_error', ALS terminates when `|previous rec_error - current rec_error| < tol`.
sparsity : float or int
If `sparsity` is not None, we approximate tensor as a sum of low_rank_component and sparse_component, where low_rank_component = cp_to_tensor((weights, factors)). `sparsity` denotes desired fraction or number of non-zero elements in the sparse_component of the `tensor`.
fixed_modes : list, default is None
Expand Down Expand Up @@ -850,7 +818,6 @@ def __init__(
sparsity=None,
l2_reg=0,
mask=None,
cvg_criterion="abs_rec_error",
fixed_modes=None,
svd_mask_repeats=5,
linesearch=False,
Expand All @@ -868,7 +835,6 @@ def __init__(
self.sparsity = sparsity
self.l2_reg = l2_reg
self.mask = mask
self.cvg_criterion = cvg_criterion
self.fixed_modes = fixed_modes
self.svd_mask_repeats = svd_mask_repeats
self.linesearch = linesearch
Expand Down Expand Up @@ -901,7 +867,6 @@ def fit_transform(self, tensor):
sparsity=self.sparsity,
l2_reg=self.l2_reg,
mask=self.mask,
cvg_criterion=self.cvg_criterion,
fixed_modes=self.fixed_modes,
svd_mask_repeats=self.svd_mask_repeats,
linesearch=self.linesearch,
Expand Down

0 comments on commit cd041b3

Please sign in to comment.