Skip to content

Commit

Permalink
Additional DataSetUtil merge utility method (#2792)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack authored Mar 27, 2018
1 parent 1e178cf commit 423d519
Showing 1 changed file with 29 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,35 @@ public static void setMaskedValuesToZero(INDArray data, INDArray mask) {
Nd4j.getExecutioner().exec(new BroadcastMulOp(data, mask, data, 0, 2));
}

/**
* Merge all of the features arrays into one minibatch.
*
* @param featuresToMerge features to merge. Note that first index is the input array (example) index, the second
* index is the input array.
* Thus to merge 10 examples with 3 input arrays each, featuresToMerge will be indexed
* like featuresToMerge[0..9][0..2]
* @param featureMasksToMerge May be null. If non-null: feature masks to merge
* @return Merged features, and feature masks. Note that feature masks may be added automatically, if required - even
* if no feature masks were present originally
*/
public Pair<INDArray[], INDArray[]> mergeFeatures(@NonNull INDArray[][] featuresToMerge, INDArray[][] featureMasksToMerge) {
int nInArrs = featuresToMerge[0].length;
INDArray[] outF = new INDArray[nInArrs];
INDArray[] outM = null;

for (int i = 0; i < nInArrs; i++) {
Pair<INDArray, INDArray> p = mergeFeatures(featuresToMerge, featureMasksToMerge, i);
outF[i] = p.getFirst();
if (p.getSecond() != null) {
if (outM == null) {
outM = new INDArray[nInArrs];
}
outM[i] = p.getSecond();
}
}
return new Pair<>(outF, outM);
}

/**
* Merge the specified features and mask arrays (i.e., concatenate the examples)
*
Expand Down

0 comments on commit 423d519

Please sign in to comment.