Skip to content

Commit

Permalink
Merge pull request #112 from Atry/fix-broadcast
Browse files Browse the repository at this point in the history
Fix broadcast method
  • Loading branch information
Atry authored Mar 26, 2018
2 parents 83901f4 + 2e77368 commit 7e6c25e
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ trait OpenCLKernelBuilder extends AllExpressions {
def extract: Element = {
val numberOfRows = originalShape.length
val numberOfColumns = matrix.length / numberOfRows
if (matrix.length % numberOfRows != 0) {
if (matrix.length != numberOfRows * numberOfColumns) {
throw new IllegalStateException()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ trait Tensors extends OpenCL {
if (i < length) {
shape(i) match {
case di if di == newShape(i) =>
matrix1(i * (length + 1) + i) = 1.0
matrix1(i * (newLength + 1) + i) = 1.0
case 1 =>
case _ =>
throw new IllegalArgumentException(
Expand Down
62 changes: 62 additions & 0 deletions Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,66 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
.run
.toScalaFuture

"matrix multiplication" in doTensors
.map { tensors =>
import tensors._

def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = {
val Array(i, j) = matrix1.shape
val Array(`j`, k) = matrix2.shape
val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k))

product.unzip(1).reduce(_ + _)

}

val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f)))
val matrix2 = Tensor(
Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f)))

matrixMultiply(matrix1, matrix2).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]")

}
.run
.toScalaFuture

"broadcast" in doTensors
.map { tensors =>
import tensors._

val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f)))
matrix1.broadcast(Array(2, 3, 4)).toString should be(
"[[[1.0,1.0,1.0,1.0],[2.0,2.0,2.0,2.0],[3.0,3.0,3.0,3.0]],[[4.0,4.0,4.0,4.0],[5.0,5.0,5.0,5.0],[6.0,6.0,6.0,6.0]]]")
}
.run
.toScalaFuture

"unrolled matrix multiplication" in doTensors
.map { tensors =>
import tensors._

def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = {

val columns1 = matrix1.unzip(1)

Tensor.zip(matrix2.unzip(1).map { column2: Tensor =>
(columns1 zip column2.unzip(0))
.map {
case (l: Tensor, r: Tensor) =>
l * r.broadcast(l.shape)
}
.reduce[Tensor](_ + _)
})
}

matrixMultiply(
Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))),
Tensor(
Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f)))
).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]")

}
.run
.toScalaFuture

}

0 comments on commit 7e6c25e

Please sign in to comment.