Skip to content

Commit

Permalink
Merge pull request #21 from mathematicalmichael/feature/shape-handling
Browse files Browse the repository at this point in the history
enforce ravel
  • Loading branch information
mathematicalmichael authored Nov 11, 2020
2 parents 60faaab + e981bed commit c05fca0
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions src/mud/funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,15 @@ def mud_sol(A, b, y=None, mean=None, cov=None, data_cov=None):
if mean is None: mean = np.zeros((A.shape[1],1))
if y is None: y = np.zeros((A.shape[0],1))

x = y - b - A@mean
z = y.ravel() - b.ravel() - (A@mean).ravel()
z = z.reshape(-1,1)

# compute once for re-use
pre = A@cov@A.T
ipc = np.linalg.pinv(pre)

# using `makeRi` would waste FLOPS since we already computed `ipc`
# idc = np.linalg.inv(data_cov)
# Ri = np.linalg.inv(cov) - A.T @ ipc @ A
# up_cov = np.linalg.inv(A.T@idc@A + Ri)
# update = up_cov @ A.T @ idc
# mud_point = mean.ravel() + (update @ x).ravel()

mud_point = mean.ravel() + (cov@A.T@ipc@x).ravel()
update = cov@A.T@ipc
mud_point = mean.ravel() + (update@z).ravel()
return mud_point.reshape(-1,1)


Expand All @@ -144,8 +140,8 @@ def mud_sol_alt(A, b, y=None, mean=None, cov=None, data_cov=None):
if mean is None: mean = np.zeros((A.shape[1],1))
if y is None: y = np.zeros((A.shape[0],1))

x = y - b - A@mean
x = x.reshape(-1,1)
z = y.ravel() - b.ravel() - (A@mean).ravel()
z = z.reshape(-1,1)

# compute once for re-use
idc = np.linalg.inv(data_cov)
Expand All @@ -156,7 +152,7 @@ def mud_sol_alt(A, b, y=None, mean=None, cov=None, data_cov=None):
# Form derived via Hua's identity + Woodbury
up_cov = cov - cov@A.T@ipc@(pred_cov - data_cov)@ipc@A@cov
update = up_cov @ A.T @ idc
mud_point = mean.ravel() + (update @ x).ravel()
mud_point = mean.ravel() + (update @ z).ravel()

# mud_point = mean.ravel() + (cov@A.T@ipc@x).ravel()
return mud_point.reshape(-1,1)
Expand All @@ -168,11 +164,12 @@ def map_sol(A, b, y=None, mean=None, cov=None, data_cov=None, w=1):
if mean is None: mean = np.zeros((A.shape[1],1))
if y is None: y = np.zeros((A.shape[0],1))

x = y - b - A@mean
x = x.reshape(-1,1)
z = y.ravel() - b.ravel() - (A@mean).ravel()
z = z.reshape(-1,1)

precision = np.linalg.inv(A.T@np.linalg.inv(data_cov)@A + w*np.linalg.inv(cov))
update = precision@A.T@np.linalg.inv(data_cov)
map_point = mean.ravel() + (update @ x).ravel()
map_point = mean.ravel() + (update @ z).ravel()
return map_point.reshape(-1,1)


Expand Down

0 comments on commit c05fca0

Please sign in to comment.