Skip to content

Commit

Permalink
Added back INDArray constructors to Mmul (#2795)
Browse files Browse the repository at this point in the history
  • Loading branch information
hvesalai authored and AlexDBlack committed Mar 27, 2018
1 parent e3dfca8 commit 8041327
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
Expand Down Expand Up @@ -73,6 +74,23 @@ public Mmul(SameDiff sameDiff,
this(sameDiff,i_v1,i_v2,MMulTranspose.allFalse());
}

/**
*
* @param x
* @param y
* @param z
*/
public Mmul(INDArray x,
INDArray y,
INDArray z,
MMulTranspose mMulTranspose) {
super(null, new INDArray[]{x, y}, z == null ? null : new INDArray[]{z});
if (mMulTranspose != null) {
this.mMulTranspose = mMulTranspose;
addIArgument(ArrayUtil.fromBoolean(mMulTranspose.isTransposeA()),
ArrayUtil.fromBoolean(mMulTranspose.isTransposeB()));
}
}


public Mmul() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.accum.LogSumExp;
import org.nd4j.linalg.api.ops.impl.accum.Mmul;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.indexing.BooleanIndexing;
Expand Down Expand Up @@ -419,9 +420,24 @@ public void testMMul() {

INDArray test = arr.mmul(arr.transpose());
assertEquals(getFailureMessage(), assertion, test);

}

@Test
public void testMmulOp() {
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
INDArray z = Nd4j.create(2, 2);
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
MMulTranspose mMulTranspose = MMulTranspose.builder()
.transposeB(true)
.a(arr)
.b(arr)
.build();

DynamicCustomOp op = new Mmul(arr, arr, z, mMulTranspose);
Nd4j.getExecutioner().exec(op);

assertEquals(getFailureMessage(), assertion, z);
}


@Test
Expand Down

0 comments on commit 8041327

Please sign in to comment.