Skip to content

Commit

Permalink
that seems to work
Browse files Browse the repository at this point in the history
  • Loading branch information
Quafadas committed Oct 2, 2024
1 parent a89b610 commit 99ff031
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 1 deletion.
121 changes: 120 additions & 1 deletion vecxt/jvm/src/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,15 @@ import scala.util.chaining.*
import jdk.incubator.vector.VectorMask
import jdk.incubator.vector.ByteVector
import jdk.incubator.vector.DoubleVector
import jdk.incubator.vector.IntVector
import jdk.incubator.vector.VectorSpecies
import jdk.incubator.vector.VectorShape
import jdk.incubator.vector.VectorOperators
import scala.compiletime.constValue

import vecxt.BoundsCheck.BoundsCheck
import scala.compiletime.ops.double
import scala.annotation.static

object arrays:

Expand Down Expand Up @@ -294,14 +299,128 @@ object arrays:
temp
end sum

// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda#:~:text=A%20simple%20and%20common%20parallel%20algorithm#:~:text=A%20simple%20and%20common%20parallel%20algorithm
inline def cumsum: Unit =
var i = 1
val spd: VectorSpecies[java.lang.Double] = DoubleVector.SPECIES_PREFERRED
val spi: VectorSpecies[java.lang.Integer] =
VectorSpecies.of(java.lang.Integer.TYPE, VectorShape.forBitSize(spd.vectorBitSize() / 2))

def doubleLength = spd.length()
// println(doubleLength)
def intLength = spi.length()
// println(intLength)
// println(vec.length)

var dBound = Math.log(vec.length) / Math.log(2)
// println(dBound - 1)

var d: Int = 0
while d < (dBound - 1) do
var k: Int = 0
val dPow2 = Math.pow(2, d).toInt
val dPow2_1 = Math.pow(2, d + 1).toInt
// println("---------------------------------")
// println(s"--d loop $d")
// println(dPow2)
// println(dPow2_1)

while k < spi.loopBound((vec.length - 1)) do
val idxs = IntVector.broadcast(spi, k).addIndex(dPow2_1)

val x_idx = idxs.add(dPow2_1 - 1)
val part1Idx = idxs.add(dPow2 - 1)
val part2Idx = idxs.add(dPow2_1 - 1)

val mask = part2Idx.lt(vec.length - 1).cast(spd)
// println(s"indexes")
// println(s"k: $k")
// println(idxs.toArray().mkString(","))
// println(x_idx.toArray().mkString(","))
// println(part1Idx.toArray().mkString(","))
// println(part2Idx.toArray().mkString(","))

val part1 = DoubleVector.fromArray(spd, vec, 0, part1Idx.toArray(), 0, mask)
val part2 = DoubleVector.fromArray(spd, vec, 0, part2Idx.toArray(), 0, mask)

// println(s"new vecs")
// println(part1.toArray().mkString(","))
// println(part2.toArray().mkString(","))

// println(s"combined")
// println(part1.add(part2).toArray().mkString(","))
part1.add(part2).intoArray(vec, 0, x_idx.toArray(), 0, mask)
// println(s"new vec")
// println(vec.mkString(","))
k += intLength
end while
d += 1
end while
// println(Math.pow(2, dBound.toInt))

println("UP SWEEP COMPLETE")
println(vec.mkString("[", ",", "]"))
println("-----")

while d > 1 do
val dPow2 = Math.pow(2, d).toInt
val dPow2_2 = Math.pow(2, d - 2).toInt
val dPow2_1 = Math.pow(2, d - 1).toInt
println(s"d : $d")
var k = 0
while k < spi.loopBound(vec.length - 1) do
println(s"dPow2_1 : $dPow2_1")
val idxs = IntVector.broadcast(spi, k).addIndex(dPow2_1)
val idxsInsert = IntVector.broadcast(spi, k).addIndex(dPow2_1).add(dPow2_2).sub(1)
println("idxs")
println(idxs.sub(1).toArray().mkString(","))
println(idxsInsert.toArray().mkString(","))

val mask = idxs.compare(
VectorOperators.LT,
vec.length - 1
) // .and(idxs.lanewise(VectorOperators.GT, vec.length - 1)).and(idxs. (0)).cast(spd)

val mask2 = idxs.compare(VectorOperators.GT, 0)

val finalM = mask.and(mask2).cast(spd)

println("mask")
println(finalM.toArray().mkString(","))

val xtract = DoubleVector.fromArray(spd, vec, 0, idxs.sub(1).toArray(), 0, finalM)
println("xtract")
println(xtract.toArray().mkString(","))
val current = DoubleVector.fromArray(spd, vec, 0, idxsInsert.toArray(), 0, finalM)
println("current")
println(current.toArray().mkString(","))

current.add(xtract).intoArray(vec, 0, idxsInsert.toArray(), 0, finalM)
println("ITRT END")
println(vec.mkString(","))
println("----- END")

k += intLength
end while
d -= 1
end while

var i = Math.pow(2, dBound.toInt).toInt
while i < vec.length do
vec(i) = vec(i - 1) + vec(i)
i = i + 1
end while
end cumsum

def cumsum2: Array[Double] =
val vec2 = vec.clone()
var i = 1
while i < vec2.length do
vec2(i) = vec2(i - 1) + vec2(i)
i = i + 1
end while
vec2
end cumsum2

inline def dot(v1: Array[Double])(using inline boundsCheck: BoundsCheck): Double =
dimCheck(vec, v1)
blas.ddot(vec.length, vec, 1, v1, 1)
Expand Down
12 changes: 12 additions & 0 deletions vecxt/test/src/array.test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ class ArrayExtensionSuite extends munit.FunSuite:
assert(v1(2) == 6)
}

test("cumsum2".only) {
val v1 =
NArray[Double](1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)

println("cumsum2")
println(v1.cumsum2.printArr)

println("cumsum")
println(v1.tap(_.cumsum).printArr)

}

test("increments") {
val v1 = NArray[Double](1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)
v1.increments.foreach(d => assertEqualsDouble(d, 1.0, 0.0001))
Expand Down

0 comments on commit 99ff031

Please sign in to comment.