From a0cbf2c19705209683da3463ce8893d7c3d65c08 Mon Sep 17 00:00:00 2001 From: Simon Parten Date: Wed, 25 Dec 2024 22:15:50 +0100 Subject: [PATCH] outer array product --- vecxt/js-native/src/array.scala | 17 +++++++++++++++++ vecxt/jvm/src/arrays.scala | 24 ++++++++++++++++++++++++ vecxt/test/src/array.test.scala | 18 ++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/vecxt/js-native/src/array.scala b/vecxt/js-native/src/array.scala index 3dff968..5564786 100644 --- a/vecxt/js-native/src/array.scala +++ b/vecxt/js-native/src/array.scala @@ -71,6 +71,23 @@ object JsNativeDoubleArrays: extension (vec: NArray[Double]) + inline def outer(other: NArray[Double])(using ClassTag[Double]): Matrix[Double] = + val n = vec.length + val m = other.length + val out: NArray[Double] = NArray.ofSize[Double](n * m) + + var i = 0 + while i < n do + var j = 0 + while j < m do + out(j * n + i) = vec(i) * other(j) + j = j + 1 + end while + i = i + 1 + end while + Matrix[Double](out, (n, m))(using BoundsCheck.DoBoundsCheck.no) + end outer + inline def <(num: Double): NArray[Boolean] = logicalIdx((a, b) => a < b, num) diff --git a/vecxt/jvm/src/arrays.scala b/vecxt/jvm/src/arrays.scala index cd004f2..20c287a 100644 --- a/vecxt/jvm/src/arrays.scala +++ b/vecxt/jvm/src/arrays.scala @@ -18,6 +18,7 @@ package vecxt import scala.util.chaining.* import vecxt.BoundsCheck.BoundsCheck +import vecxt.matrix.Matrix import dev.ludovic.netlib.blas.JavaBLAS.getInstance as blas import jdk.incubator.vector.ByteVector @@ -473,6 +474,29 @@ object arrays: ranks end elementRanks + inline def outer(other: Array[Double])(using ClassTag[Double]): Matrix[Double] = + val n = vec.length + val m = other.length + val out = new Array[Double](n * m) + + var j = 0 + while j < m do + var i = 0 + val tmp = DoubleVector.broadcast(spd, other(j)) + while i < spd.loopBound(n) do + DoubleVector.fromArray(spd, vec, i).mul(tmp).intoArray(out, j * n + i) + i = i + spdl + end while + + while i < n do + out(j * n + i) = vec(i) * other(j) + i = i + 1 + end while + j = j + 1 + end while + Matrix(out, (n, m))(using BoundsCheck.DoBoundsCheck.no) + end outer + def variance: Double = // https://www.cuemath.com/sample-variance-formula/ val μ = vec.mean diff --git a/vecxt/test/src/array.test.scala b/vecxt/test/src/array.test.scala index 5bd0fde..d8771ca 100644 --- a/vecxt/test/src/array.test.scala +++ b/vecxt/test/src/array.test.scala @@ -223,6 +223,24 @@ class ArrayExtensionSuite extends munit.FunSuite: assertEquals(v_idx3.trues, 1) } + test("outer product".only) { + val v1 = NArray[Double](1.0, 2.0, 3.0) + val v2 = NArray[Double](4.0, 5.0) + val outer = v1.outer(v2) + val shouldBe = Matrix.fromRows[Double]( + NArray( + NArray(4.0, 5.0), + NArray(8.0, 10.0), + NArray(12.0, 15.0) + ) + ) + assertEquals(outer.rows, 3) + assertEquals(outer.cols, 2) + + assertVecEquals(outer.raw, shouldBe.raw) + + } + test("norm") { assertEqualsDouble(v_fill.norm, Math.sqrt(1 + 4 + 9 + 16), 0.00001) }