Skip to content

Commit

Permalink
replaces numpy sqrt method with pytensor equivalent (#6405)
Browse files Browse the repository at this point in the history
  • Loading branch information
morganstrom authored Dec 20, 2022
1 parent f231d13 commit 982e4c4
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ def log_jac_det(self, value, *rv_inputs):
def extend_axis(array, axis):
n = array.shape[axis] + 1
sum_vals = array.sum(axis, keepdims=True)
norm = sum_vals / (np.sqrt(n) + n)
fill_val = norm - sum_vals / np.sqrt(n)
norm = sum_vals / (at.sqrt(n) + n)
fill_val = norm - sum_vals / at.sqrt(n)

out = at.concatenate([array, fill_val], axis=axis)
return out - norm
Expand All @@ -346,8 +346,8 @@ def extend_axis_rev(array, axis):
n = array.shape[normalized_axis]
last = at.take(array, [-1], axis=normalized_axis)

sum_vals = -last * np.sqrt(n)
norm = sum_vals / (np.sqrt(n) + n)
sum_vals = -last * at.sqrt(n)
norm = sum_vals / (at.sqrt(n) + n)
slice_before = (slice(None, None),) * normalized_axis

return array[slice_before + (slice(None, -1),)] + norm
Expand Down

0 comments on commit 982e4c4

Please sign in to comment.