From 92b8a7b21a7e1ac083aba9c99d0fc81d3505394a Mon Sep 17 00:00:00 2001 From: Simon Parten Date: Wed, 25 Dec 2024 21:23:02 +0100 Subject: [PATCH] trace --- vecxt/jvm/src/rpt.scala | 1 + vecxt/src/doublematrix.scala | 14 ++++++++------ vecxt/test/src/matrix.test.scala | 6 ++++++ 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/vecxt/jvm/src/rpt.scala b/vecxt/jvm/src/rpt.scala index 0041c5e..5092a5c 100644 --- a/vecxt/jvm/src/rpt.scala +++ b/vecxt/jvm/src/rpt.scala @@ -9,6 +9,7 @@ import vecxt.reinsurance.Retentions.* f(X;retention, limit) = MIN(MAX(X - retention, 0), limit)) Note: mutates the input array + TODO: SIMD */ object rpt: extension (vec: Array[Double]) diff --git a/vecxt/src/doublematrix.scala b/vecxt/src/doublematrix.scala index b947a77..b2db415 100644 --- a/vecxt/src/doublematrix.scala +++ b/vecxt/src/doublematrix.scala @@ -1,15 +1,11 @@ package vecxt import vecxt.BoundsCheck.BoundsCheck -import vecxt.JsDoubleMatrix.* import vecxt.JvmDoubleMatrix.* -import vecxt.MatrixHelper.* -import vecxt.MatrixInstance.* -import vecxt.NativeDoubleMatrix.* import vecxt.arrays.* import vecxt.matrix.* -import vecxt.rangeExtender.MatrixRange.RangeExtender -import vecxt.rangeExtender.MatrixRange.range + +import vecxt.matrixUtil.diag object DoubleMatrix: @@ -19,6 +15,12 @@ object DoubleMatrix: inline def *:*=(d: Double): Unit = m.raw.multInPlace(d) + inline def trace = + if m.shape(0) != m.shape(1) then throw new IllegalArgumentException("Matrix must be square") + end if + m.diag.sum + end trace + // inline def >=(d: Double): Matrix[Boolean] = // Matrix[Boolean](m.raw >= d, m.shape)(using BoundsCheck.DoBoundsCheck.no) diff --git a/vecxt/test/src/matrix.test.scala b/vecxt/test/src/matrix.test.scala index 4a67137..59fc283 100644 --- a/vecxt/test/src/matrix.test.scala +++ b/vecxt/test/src/matrix.test.scala @@ -276,6 +276,12 @@ class MatrixExtensionSuite extends FunSuite: assertEquals(mat.raw(1), 7.0) } + test("trace") { + val mat = Matrix[Double](NArray(1.0, 2.0, 3.0, 4.0), (2, 2)) + println(mat.printMat) + assertEquals(mat.trace, 5.0) + } + test("Matrix column extraction") { val array = NArray[Double](1.0, 2.0, 3.0, 4.0) val matrix = Matrix[Double](array, (2, 2))