From 9bededd7314d19c796c4dac40c9ac9e624d88aec Mon Sep 17 00:00:00 2001 From: Linus Hamlin <78953007+lilinus@users.noreply.github.com> Date: Wed, 11 Dec 2024 18:56:53 +0100 Subject: [PATCH] Fix TensorExtensions.StdDev (#110392) * Fix TensorExtensions.StdDev * Add constraint to ref * Revert "Add constraint to ref" This reverts commit f740f503a9ef0b763f71442fd8b4e578142f9c83. * Revert "Fix TensorExtensions.StdDev" This reverts commit c21298471fe1ea272930c3e5427bf986dba275b6. * Use pow method * Use existing variable --- .../Numerics/Tensors/netcore/TensorExtensions.cs | 12 +++++++++++- .../System.Numerics.Tensors/tests/TensorTests.cs | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs index eec1e295e3c77..1274b20b8cf9e 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorExtensions.cs @@ -3521,7 +3521,17 @@ public static T StdDev(in ReadOnlyTensorSpan x) TensorPrimitives.Abs(output, output); TensorPrimitives.Pow((ReadOnlySpan)output, T.CreateChecked(2), output); T sum = TensorPrimitives.Sum((ReadOnlySpan)output); - return T.CreateChecked(sum / T.CreateChecked(x._shape._memoryLength)); + T variance = sum / T.CreateChecked(x._shape._memoryLength); + + if (typeof(T) == typeof(float)) + { + return T.CreateChecked(MathF.Sqrt(float.CreateChecked(variance))); + } + if (typeof(T) == typeof(double)) + { + return T.CreateChecked(Math.Sqrt(double.CreateChecked(variance))); + } + return T.Pow(variance, T.CreateChecked(0.5)); } #endregion diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs index 96e2e9ede1313..a418230aff14a 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorTests.cs @@ -1113,7 +1113,7 @@ public static float StdDev(float[] values) { sum += MathF.Pow(values[i] - mean, 2); } - return sum / values.Length; + return MathF.Sqrt(sum / values.Length); } [Fact]