From 982e4c4202e7c6c1196c59c228fd12903b79d32e Mon Sep 17 00:00:00 2001 From: Morgan Pihl Date: Tue, 20 Dec 2022 19:23:25 +0100 Subject: [PATCH] replaces numpy sqrt method with pytensor equivalent (#6405) --- pymc/distributions/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 023ff57679..c4e0d6456f 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -333,8 +333,8 @@ def log_jac_det(self, value, *rv_inputs): def extend_axis(array, axis): n = array.shape[axis] + 1 sum_vals = array.sum(axis, keepdims=True) - norm = sum_vals / (np.sqrt(n) + n) - fill_val = norm - sum_vals / np.sqrt(n) + norm = sum_vals / (at.sqrt(n) + n) + fill_val = norm - sum_vals / at.sqrt(n) out = at.concatenate([array, fill_val], axis=axis) return out - norm @@ -346,8 +346,8 @@ def extend_axis_rev(array, axis): n = array.shape[normalized_axis] last = at.take(array, [-1], axis=normalized_axis) - sum_vals = -last * np.sqrt(n) - norm = sum_vals / (np.sqrt(n) + n) + sum_vals = -last * at.sqrt(n) + norm = sum_vals / (at.sqrt(n) + n) slice_before = (slice(None, None),) * normalized_axis return array[slice_before + (slice(None, -1),)] + norm