Skip to content

Commit

Permalink
outer array product
Browse files Browse the repository at this point in the history
  • Loading branch information
Quafadas committed Dec 25, 2024
1 parent 0c6c843 commit a0cbf2c
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
17 changes: 17 additions & 0 deletions vecxt/js-native/src/array.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ object JsNativeDoubleArrays:

extension (vec: NArray[Double])

inline def outer(other: NArray[Double])(using ClassTag[Double]): Matrix[Double] =
val n = vec.length
val m = other.length
val out: NArray[Double] = NArray.ofSize[Double](n * m)

var i = 0
while i < n do
var j = 0
while j < m do
out(j * n + i) = vec(i) * other(j)
j = j + 1
end while
i = i + 1
end while
Matrix[Double](out, (n, m))(using BoundsCheck.DoBoundsCheck.no)
end outer

inline def <(num: Double): NArray[Boolean] =
logicalIdx((a, b) => a < b, num)

Expand Down
24 changes: 24 additions & 0 deletions vecxt/jvm/src/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package vecxt
import scala.util.chaining.*

import vecxt.BoundsCheck.BoundsCheck
import vecxt.matrix.Matrix

import dev.ludovic.netlib.blas.JavaBLAS.getInstance as blas
import jdk.incubator.vector.ByteVector
Expand Down Expand Up @@ -473,6 +474,29 @@ object arrays:
ranks
end elementRanks

inline def outer(other: Array[Double])(using ClassTag[Double]): Matrix[Double] =
val n = vec.length
val m = other.length
val out = new Array[Double](n * m)

var j = 0
while j < m do
var i = 0
val tmp = DoubleVector.broadcast(spd, other(j))
while i < spd.loopBound(n) do
DoubleVector.fromArray(spd, vec, i).mul(tmp).intoArray(out, j * n + i)
i = i + spdl
end while

while i < n do
out(j * n + i) = vec(i) * other(j)
i = i + 1
end while
j = j + 1
end while
Matrix(out, (n, m))(using BoundsCheck.DoBoundsCheck.no)
end outer

def variance: Double =
// https://www.cuemath.com/sample-variance-formula/
val μ = vec.mean
Expand Down
18 changes: 18 additions & 0 deletions vecxt/test/src/array.test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,24 @@ class ArrayExtensionSuite extends munit.FunSuite:
assertEquals(v_idx3.trues, 1)
}

test("outer product".only) {
val v1 = NArray[Double](1.0, 2.0, 3.0)
val v2 = NArray[Double](4.0, 5.0)
val outer = v1.outer(v2)
val shouldBe = Matrix.fromRows[Double](
NArray(
NArray(4.0, 5.0),
NArray(8.0, 10.0),
NArray(12.0, 15.0)
)
)
assertEquals(outer.rows, 3)
assertEquals(outer.cols, 2)

assertVecEquals(outer.raw, shouldBe.raw)

}

test("norm") {
assertEqualsDouble(v_fill.norm, Math.sqrt(1 + 4 + 9 + 16), 0.00001)
}
Expand Down

0 comments on commit a0cbf2c

Please sign in to comment.