diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 023ff57679..c4e0d6456f 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -333,8 +333,8 @@ def log_jac_det(self, value, *rv_inputs): def extend_axis(array, axis): n = array.shape[axis] + 1 sum_vals = array.sum(axis, keepdims=True) - norm = sum_vals / (np.sqrt(n) + n) - fill_val = norm - sum_vals / np.sqrt(n) + norm = sum_vals / (at.sqrt(n) + n) + fill_val = norm - sum_vals / at.sqrt(n) out = at.concatenate([array, fill_val], axis=axis) return out - norm @@ -346,8 +346,8 @@ def extend_axis_rev(array, axis): n = array.shape[normalized_axis] last = at.take(array, [-1], axis=normalized_axis) - sum_vals = -last * np.sqrt(n) - norm = sum_vals / (np.sqrt(n) + n) + sum_vals = -last * at.sqrt(n) + norm = sum_vals / (at.sqrt(n) + n) slice_before = (slice(None, None),) * normalized_axis return array[slice_before + (slice(None, -1),)] + norm