Skip to content

Commit

Permalink
Resolve some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmey committed Dec 2, 2024
1 parent f3809d5 commit 2896e47
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tensorly/utils/jointdiag.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def joint_matrix_diagonalization(
and Signal Processing (ICASSP 2006), vol. IV, pp. 1137-1140, Toulouse, France, May 2006.
Args:
X (_type_): n matrices, organized in tensor of dimension (k,k,n), for joint diagonalization.
X (_type_): n matrices, organized in a single tensor of dimension (k, k, n).
max_n_iter (int, optional): Maximum iteration number. Defaults to 50.
threshold (float, optional): Threshold for decrease in error indicating convergence. Defaults to 1e-10.
verbose (bool, optional): Output progress information during diagonalization. Defaults to False.
Expand All @@ -34,10 +34,10 @@ def joint_matrix_diagonalization(
Tensor: X after joint diagonalization.
Tensor: The transformation matrix resulting in the diagonalization.
"""
X = tl.tensor(X, **tl.context(X))
D = tl.shape(X)[0] # Dimension of square matrix slices
X = tl.copy()
matrix_dimension = tl.shape(X)[0] # Dimension of square matrix slices
assert tl.ndim(X) == 3, "Input must be a 3D tensor"
assert D == X.shape[1], "All slices must be square"
assert matrix_dimension == X.shape[1], "All matrices must be square."

# Initial error calculation
# Transpose is because np.tril operates on the last two dimensions
Expand All @@ -47,17 +47,17 @@ def joint_matrix_diagonalization(
print(f"Sweep # 0: e = {e:.3e}")

# Additional output parameters
Q_total = tl.eye(D)
Q_total = tl.eye(matrix_dimension)

for k in range(max_n_iter):
# loop over all pairs of slices
for p, q in combinations(range(D), 2):
for p, q in combinations(range(matrix_dimension), 2):
# Finds matrix slice with greatest variability among diagonal elements
d_ = X[p, p, :] - X[q, q, :]
h = tl.argmax(tl.abs(d_))

# List of indices
all_but_pq = list(set(range(D)) - set([p, q]))
all_but_pq = list(set(range(matrix_dimension)) - set([p, q]))

# Compute certain quantities
dh = d_[h]
Expand Down

0 comments on commit 2896e47

Please sign in to comment.