Skip to content

Commit

Permalink
Merge pull request #18 from mathematicalmichael/develop
Browse files Browse the repository at this point in the history
fix norm bug
  • Loading branch information
mathematicalmichael authored Nov 3, 2020
2 parents 1a17f3b + 59fb749 commit 155dbd4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/mud/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def mynorm(X, mat):
"""
Y = (np.linalg.inv(mat) @ X)
result = np.sum(X * Y, axis=0)
return result[0]
return result


def full_functional(operator, inputs, data, initial_mean, initial_cov, observed_mean=0, observed_cov=1):
Expand Down
12 changes: 6 additions & 6 deletions tests/test_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@

class TestNorm(unittest.TestCase):

def test_identity_induced_norm(self):
def test_identity_induced_norm_on_vector(self):
# Arrange
X = np.random.rand(2,1)
X = np.random.rand(2,1) # single vector
mat = np.eye(2)

# Act
result = mdn.mynorm(X, mat)
check = np.linalg.norm(X)**2
check = np.linalg.norm(X, axis=0)**2

# Assert
assert isinstance(result,float)
self.assertAlmostEqual(result, check, 12)
assert isinstance(result, np.ndarray)
self.assertAlmostEqual(result[0], check[0], 12)

def test_scaled_identity_induced_norm(self):
# iterate over a few scaling factors
Expand All @@ -36,7 +36,7 @@ def test_scaled_identity_induced_norm(self):
check = np.linalg.norm(X)**2

# Assert
self.assertAlmostEqual(result, check, 12)
self.assertAlmostEqual(result[0], check, 12)

class TestFunctionals_2to1(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 155dbd4

Please sign in to comment.