Skip to content

Commit

Permalink
Tests fixes (#2773)
Browse files Browse the repository at this point in the history
* skip node import if there are no properties

* Max tf name fix

* axis normalization

* keepDims :/

* get rid of default F order. C ftw!

* pow fix for tf import

* nullcheck
  • Loading branch information
raver119 authored Mar 21, 2018
1 parent 152f892 commit 24e0073
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public SameDiff importGraph(GRAPH_TYPE tfGraph) {
//for each variable
for(Map.Entry<String,TENSOR_TYPE> entry : variablesForGraph.entrySet()) {
if(dataTypeForTensor(entry.getValue()) == DataBuffer.Type.UNKNOWN) {
val var = importState.getSameDiff().var(entry.getKey(),null,new ZeroInitScheme('f'));
val var = importState.getSameDiff().var(entry.getKey(),null,new ZeroInitScheme('c'));
//mark as place holder for validating resolution later.
if(isPlaceHolder(entry.getValue())) {
importState.getSameDiff().addAsPlaceHolder(var.getVarName());
Expand All @@ -175,7 +175,7 @@ public SameDiff importGraph(GRAPH_TYPE tfGraph) {
diff.associateArrayWithVariable(arr,var);
}
else if(getShapeFromTensor(entry.getValue()) == null) {
val var = importState.getSameDiff().var(entry.getKey(),null,new ZeroInitScheme('f'));
val var = importState.getSameDiff().var(entry.getKey(),null,new ZeroInitScheme('c'));
//mark as place holder for validating resolution later.

//note that this vertex id can still be a place holder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,11 @@ public void initFunctionFromProperties(String mappedTfName, DifferentialFunction
val tfProperties = properties.get(mappedTfName);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
val attributeAdapters = on.attributeAdaptersForFunction();

// if there are no properties announced for this function - just return
if (tfProperties == null)
return;

for(val entry : tfProperties.entrySet()) {
val tfAttrName = entry.getValue().getTfAttrName();
val currentField = fields.get(entry.getKey());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,13 @@ else if(hasReductionIndices(nodeDef)) {
for(int i = 0; i < graph.getNodeCount(); i++) {
if (graph.getNode(i).getName().equals(nodeDef.getName() + "/reduction_indices")) {
reductionNode = graph.getNode(i);
val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", graph.getNode(i), graph);
val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", reductionNode, graph);

boolean keepAxis = nodeDef.getAttrOrThrow("keep_dims").getB();

// keepAxis = false by default
int[] dimensions = ArrayUtils.add(arr.data().asInt(), 0, keepAxis ? 1 : 0);

//int[] dimensions = ArrayUtils.add(arr.data().asInt(), 0, keepAxis ? 1 : 0);
int[] dimensions = arr.data().asInt();

this.dimensions = dimensions;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public String onnxName() {

@Override
public String tensorflowName() {
return "reduce_max";
return "Max";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, A
}



@Override
public String opName() {
return "dynamic_stitch";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@

package org.nd4j.linalg.api.ops.impl.transforms;

import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
* Pow function
Expand Down Expand Up @@ -99,6 +105,19 @@ public void init(INDArray x, INDArray y, INDArray z, long n) {
this.extraArgs = new Object[]{pow};
}

@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // second input of the TF Pow node is the exponent tensor
    val weightsName = nodeDef.getInput(1);
    val tmp = initWith.getArrForVarName(weightsName);

    // if the exponent array is already resolved and is a scalar - capture it as this op's pow value;
    // null means the array isn't resolved yet, in which case we leave pow untouched
    if (tmp != null && tmp.isScalar()) {
        this.pow = tmp.getDouble(0);
    }
}

@Override
public String onnxName() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,12 @@
import org.nd4j.linalg.indexing.ShapeOffsetResolution;
import org.nd4j.linalg.util.ArrayUtil;

import java.lang.reflect.Array;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.*;

/**
* Encapsulates all shape related logic (vector of 0 dimension is a scalar is equivalent to
Expand Down Expand Up @@ -306,9 +304,12 @@ else if (dimensions.length == 1 && wholeShape.length == 2) {
* @return the shape of the result array as the result of the reduce
*/
public static int[] getReducedShape(int[] wholeShape, int[] dimensions, boolean keepDims, boolean newFormat) {
// we need to normalize dimensions, in case they have negative values or unsorted, or whatever
dimensions = Shape.normalizeAxis(wholeShape.length, dimensions);

// strip leading keepDims argument
if (newFormat)
dimensions = Arrays.copyOfRange(dimensions, 1, dimensions.length);
//if (newFormat)
// dimensions = Arrays.copyOfRange(dimensions, 1, dimensions.length);

if (!keepDims)
if (!newFormat)
Expand Down Expand Up @@ -2446,6 +2447,39 @@ public static boolean wholeArrayDimension(int... arr) {
return arr.length == 1 && arr[0] == Integer.MAX_VALUE;
}

/**
 * Removes duplicate values from the given axis array, preserving the
 * order of first appearance. Arrays of length 0 or 1 are returned as-is.
 *
 * @param array input axis array
 * @return array with duplicates removed, first-seen order preserved
 */
public static int[] uniquify(int[] array) {
    if (array.length <= 1)
        return array;

    // LinkedHashSet keeps insertion order while dropping repeats
    Set<Integer> seen = new LinkedHashSet<>();
    for (int value : array)
        seen.add(value);

    int[] result = new int[seen.size()];
    int idx = 0;
    for (int value : seen)
        result[idx++] = value;

    return result;
}

/**
 * Normalizes an axis array against the given rank: negative axes are
 * shifted into [0, rank), the result is sorted ascending, and duplicates
 * are removed. Integer.MAX_VALUE is passed through (whole-array marker).
 *
 * @param rank rank of the array the axes refer to
 * @param axis axes to normalize, may contain negative values
 * @return sorted, de-duplicated, non-negative axis array
 * @throws ND4JIllegalStateException if any axis is out of range for the rank
 */
public static int[] normalizeAxis(int rank, int... axis) {
    // map negative axis values into the [0, rank) range
    final int[] normalized = new int[axis.length];

    for (int i = 0; i < axis.length; i++) {
        final int a = axis[i] < 0 ? axis[i] + rank : axis[i];

        // Integer.MAX_VALUE is the "all dimensions" sentinel and is allowed through
        if (a < 0 || (a >= rank && a != Integer.MAX_VALUE))
            throw new ND4JIllegalStateException("Axis array " + Arrays.toString(axis) + " contains values above rank " + rank);

        normalized[i] = a;
    }

    // sort ascending, then drop possible duplicates
    Arrays.sort(normalized);
    return uniquify(normalized);
}

/**
*
* Compare the contents of a buffer and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1444,6 +1444,15 @@ protected CudaContext invoke(TransformOp op) {
if (extraz.get() == null)
extraz.set(new PointerPointer(32));

// Pow operations might be special
if (op.opNum() == 7) {
if (op.y() != null && op.y().isScalar()) {
Nd4j.getExecutioner().commit();
op.setY(Nd4j.valueArrayOf(op.x().shape(), op.y().getDouble(0)));
Nd4j.getExecutioner().commit();
}
}

CudaContext context = allocator.getFlowController().prepareAction(op.z(), op.x(), op.y());

if (CudaEnvironment.getInstance().getConfiguration().isDebug())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ public INDArray exec(IndexAccumulation op, int... dimension) {
if (extraz.get() == null)
extraz.set(new PointerPointer(32));

Arrays.sort(dimension);
dimension = Shape.normalizeAxis(op.x().rank(), dimension);

for (int i = 0; i < dimension.length; i++) {
if (dimension[i] < 0)
dimension[i] += op.x().rank();
Expand Down Expand Up @@ -260,7 +261,7 @@ public INDArray exec(IndexAccumulation op, int... dimension) {

@Override
public INDArray exec(Accumulation op, int... dimension) {
Arrays.sort(dimension);
dimension = Shape.normalizeAxis(op.x().rank(), dimension);


validateDataType(Nd4j.dataType(), op);
Expand Down Expand Up @@ -561,7 +562,7 @@ else if (op.y() != null && op.getOpType() == Op.Type.REDUCE3) {
* @param dimension
*/
private void invoke(ScalarOp op, int[] dimension) {
Arrays.sort(dimension);
dimension = Shape.normalizeAxis(op.x().rank(), dimension);
// do tad magic
/**
* Returns the {@link Shape#createShapeInformation(int[], int[], int, int, char)}
Expand Down Expand Up @@ -684,6 +685,13 @@ private void exec(TransformOp op) {

PointerPointer dummy = extraz.get();

// Pow operations might be special
if (op.opNum() == 7) {
if (op.y() != null && op.y().isScalar()) {
op.setY(Nd4j.valueArrayOf(op.x().shape(), op.y().getDouble(0)));
}
}

/**
* This is the {@link org.nd4j.linalg.api.ops.impl.transforms.IsMax}
* operation.
Expand Down Expand Up @@ -819,7 +827,7 @@ public INDArray exec(BroadcastOp op, int... dimension) {
long st = profilingHookIn(op);
if(dimension == null)
dimension = new int[] {Integer.MAX_VALUE};
Arrays.sort(dimension);
dimension = Shape.normalizeAxis(op.x().rank(), dimension);

validateDataType(Nd4j.dataType(), op);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

Expand Down Expand Up @@ -131,6 +132,46 @@ public void testKeepDimsShape_4_F() throws Exception {
}


@Test
public void testAxisNormalization_1() throws Exception {
    // negative axis -2 in rank 2 should map to 0; result comes back sorted
    final int[] axis = {1, -2};
    final int[] expected = {0, 1};

    assertArrayEquals(expected, Shape.normalizeAxis(2, axis));
}

@Test
public void testAxisNormalization_2() throws Exception {
    // -2 normalizes to 0, duplicating the explicit 0 - duplicate must be removed
    final int[] axis = {1, -2, 0};
    final int[] expected = {0, 1};

    assertArrayEquals(expected, Shape.normalizeAxis(2, axis));
}

@Test(expected = ND4JIllegalStateException.class)
public void testAxisNormalization_3() throws Exception {
    // axis 2 is out of range for rank 2, so normalizeAxis is expected to throw
    final int[] axis = {1, -2, 2};
    final int[] expected = {0, 1};

    assertArrayEquals(expected, Shape.normalizeAxis(2, axis));
}

@Test
public void testAxisNormalization_4() throws Exception {
    // already non-negative axes are simply returned in sorted order
    final int[] axis = {1, 2, 0};
    final int[] expected = {0, 1, 2};

    assertArrayEquals(expected, Shape.normalizeAxis(3, axis));
}

@Override
public char ordering() {
return 'c';
Expand Down

0 comments on commit 24e0073

Please sign in to comment.