-
Notifications
You must be signed in to change notification settings - Fork 0
/
complex_stiefel.py
127 lines (104 loc) · 4.67 KB
/
complex_stiefel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import numpy as np
from scipy.linalg import expm
from pymanopt.manifolds.manifold import EuclideanEmbeddedSubmanifold
from pymanopt.tools.multi import multiprod, multiherm, multihconj
class ComplexStiefel(EuclideanEmbeddedSubmanifold):
    """
    Factory class for the complex Stiefel manifold.

    The complex Stiefel manifold St(n, p) is the set of n x p complex
    matrices X with orthonormal columns (X^H X = I_p). It is treated as an
    embedded submanifold of C^{n x p}, regarded as a real Euclidean space
    with inner product Re(trace(G^H H)); its real dimension is 2np - p^2.

    Instantiation requires the dimensions n, p to be specified. Optional
    argument k allows the user to optimize over the product of k complex
    Stiefels.
    Elements are represented as n x p complex matrices (if k == 1), and as
    k x n x p complex arrays if k > 1 (Note that this is different to
    manopt!).
    """

    def __init__(self, n, p, k=1):
        self._n = n
        self._p = p
        self._k = k

        # Check that n is greater than or equal to p
        if n < p or p < 1:
            raise ValueError("Need n >= p >= 1. Values supplied were n = %d "
                             "and p = %d." % (n, p))
        if k < 1:
            raise ValueError("Need k >= 1. Value supplied was k = %d." % k)

        if k == 1:
            name = "Stiefel manifold St(%d, %d)" % (n, p)
        else:
            name = "Product Stiefel manifold St(%d, %d)^%d" % (n, p, k)
        # 2np real coordinates minus p^2 real constraints: X^H X = I_p is a
        # Hermitian p x p equation.
        dimension = int(k * (2 * n * p - p ** 2))
        super().__init__(name, dimension)

    @property
    def typicaldist(self):
        return np.sqrt(self._p * self._k)

    def inner(self, X, G, H):
        """Riemannian metric: the real part of the Frobenius inner product.

        Summation runs over all axes, so this also covers the k x n x p
        product representation.
        """
        return np.real(np.tensordot(np.conjugate(G), H, axes=G.ndim))

    def proj(self, X, U):
        """Orthogonal projection of U onto the tangent space at X.

        Removes the normal component X herm(X^H U), where herm is the
        Hermitian part.
        """
        return U - multiprod(X, multiherm(multiprod(multihconj(X), U)))

    # TODO(nkoep): Implement the weingarten map instead.
    def ehess2rhess(self, X, egrad, ehess, H):
        """Convert a Euclidean Hessian-vector product to a Riemannian one.

        Returns proj_X(ehess - H herm(X^H egrad)), where egrad is the
        Euclidean gradient at X, ehess the Euclidean Hessian applied to H,
        and H a tangent vector at X.
        """
        XtG = multiprod(multihconj(X), egrad)
        symXtG = multiherm(XtG)
        HsymXtG = multiprod(H, symXtG)
        return self.proj(X, ehess - HsymXtG)

    def retr(self, X, G):
        """Retraction via the polar decomposition of X + G.

        X + G = U diag(s) Vh (thin SVD) is mapped to its polar factor
        U @ Vh, which has orthonormal columns. np.linalg.svd and `@` both
        operate on stacked matrices, so the same code handles the k x n x p
        product representation.
        """
        U, _, Vh = np.linalg.svd(X + G, full_matrices=False)
        return U @ Vh

    def norm(self, X, G):
        # Norm on the tangent space is the Euclidean (Frobenius) norm taken
        # over all entries, consistent with `inner`.
        return np.linalg.norm(G)

    @staticmethod
    def _complex_randn(*shape):
        """Return an array of the given shape whose real and imaginary parts
        are independent standard normal draws (real part drawn first)."""
        return np.random.randn(*shape) + 1j * np.random.randn(*shape)

    def rand(self):
        """Random point: the Q factor of a QR decomposition of a random
        complex Gaussian matrix, drawn independently for each factor."""
        if self._k == 1:
            q, _ = np.linalg.qr(self._complex_randn(self._n, self._p))
            return q
        X = np.zeros((self._k, self._n, self._p), dtype=np.complex128)
        for i in range(self._k):
            X[i], _ = np.linalg.qr(self._complex_randn(self._n, self._p))
        return X

    def randvec(self, X):
        """Random unit-norm tangent vector at X.

        Draws a complex Gaussian array shaped like X, projects it onto the
        tangent space, and normalizes it.
        """
        U = self._complex_randn(*np.shape(X))
        U = self.proj(X, U)
        return U / np.linalg.norm(U)

    def transp(self, x1, x2, d):
        # Vector transport by projection onto the tangent space at x2.
        return self.proj(x2, d)

    def _exp_single(self, X, U):
        """Exponential map on one St(n, p) factor (Edelman et al. formula).

        With A = X^H U and S = U^H U:
            Y = [X, U] expm([[A, -S], [I, A]]) [expm(-A); 0].
        Returns a plain ndarray (np.block/np.hstack rather than np.bmat,
        which would produce an np.matrix).
        """
        p = self._p
        XhU = X.conj().T @ U
        UhU = U.conj().T @ U
        W = expm(np.block([[XhU, -UhU],
                           [np.eye(p, dtype=np.complex128), XhU]]))
        Z = np.vstack([expm(-XhU), np.zeros((p, p), dtype=np.complex128)])
        return np.hstack([X, U]) @ W @ Z

    def exp(self, X, U):
        """Riemannian exponential map at X applied to tangent vector U."""
        if self._k == 1:
            return self._exp_single(X, U)
        Y = np.zeros(np.shape(X), dtype=np.complex128)
        for i in range(self._k):
            Y[i] = self._exp_single(X[i], U[i])
        return Y

    def zerovec(self, X):
        if self._k == 1:
            return np.zeros((self._n, self._p), dtype=np.complex128)
        return np.zeros((self._k, self._n, self._p), dtype=np.complex128)