From d79d8acb9d22ababc79124ef5c041444a23d9269 Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Mon, 26 Feb 2024 15:29:03 +0100 Subject: [PATCH 1/5] Interface should be public for external usage --- .../src/main/java/org/tensorflow/framework/metrics/Metric.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index c8c1df607c2..c2982e9b0b0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -22,7 +22,7 @@ import org.tensorflow.types.family.TNumber; /** Interface for metrics */ -interface Metric { +public interface Metric { /** * Creates a List of Operations to update the metric state based on input values. From 17951df8ef650b0727c548c835491777b2841776 Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Tue, 27 Feb 2024 11:42:00 +0100 Subject: [PATCH 2/5] Fix https://github.com/tensorflow/java/issues/523 --- .../framework/losses/impl/LossesHelper.java | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 3e635b0d957..451706af1d1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -282,11 +282,12 @@ private static Operand reduceWeightedLoss( if (reduction == Reduction.NONE) { loss = weightedLoss; } else { - loss = - tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) { - loss = 
safeMean(tf, loss, weightedLoss.shape().size()); + loss = safeMean(tf, weightedLoss); } + else + loss = tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); + } return loss; } @@ -302,9 +303,9 @@ private static Operand reduceWeightedLoss( * zero, then zero is returned. */ public static Operand safeMean( - Ops tf, Operand losses, long numElements) { - Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); - return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type())); + Ops tf, Operand losses) { + Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses),ReduceSum.keepDims(Boolean.FALSE)); + return tf.math.divNoNan(totalLoss, cast(tf,tf.shape.size(tf.shape(losses)),losses.type())); } /** From 2b985ce36fc69946df412b97d5fd21739cebd485 Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Wed, 28 Feb 2024 15:04:31 +0100 Subject: [PATCH 3/5] Fix google format --- .../framework/losses/impl/LossesHelper.java | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 451706af1d1..6c40149f3de 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -51,7 +51,7 @@ public class LossesHelper { * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction - * . + * . * @param the data type for the labels, predictions and result * @return LossTuple of prediction, label,sampleWeight will * be null. 
Each of them possibly has the last dimension squeezed, sampleWeight @@ -77,7 +77,7 @@ public static LossTuple squeezeOrExpandDimensions( * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction - * . + * . * @param sampleWeights Optional sample weight(s) Operand whose dimensions match * * prediction. @@ -179,7 +179,7 @@ private static Operand maybeExpandWeights( * * @param tf the TensorFlowOps * @param labels Label values, a Tensor whose dimensions match predictions - * . + * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. @@ -194,7 +194,7 @@ public static LossTuple removeSqueezableDimensions( * * @param tf the TensorFlowOps * @param labels Label values, a Operand whose dimensions match predictions - * . + * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). * @param the data type for the labels, predictions and result @@ -222,11 +222,13 @@ public static LossTuple removeSqueezableDimensions( // Use dynamic rank. 
// TODO: hold for lazy select feature, - // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + // Operand rankDiff = tf.math.sub(tf.rank(predictions), + // tf.rank(labels)); if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { /* - * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze - * predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), + * TODO, if we ever get a select that does lazy evaluation, but for now do the + * tf.squeeze predictions = tf.select( + * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); * */ predictions = tf.squeeze(predictions, Squeeze.axis(Collections.singletonList(-1L))); @@ -284,10 +286,10 @@ private static Operand reduceWeightedLoss( } else { if (reduction == Reduction.AUTO || reduction == Reduction.SUM_OVER_BATCH_SIZE) { loss = safeMean(tf, weightedLoss); - } - else - loss = tf.reduceSum(weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); - + } else + loss = + tf.reduceSum( + weightedLoss, allAxes(tf, weightedLoss), ReduceSum.keepDims(Boolean.FALSE)); } return loss; } @@ -302,10 +304,10 @@ private static Operand reduceWeightedLoss( * @return A scalar representing the mean of losses. If numElements is * zero, then zero is returned. 
*/ - public static Operand safeMean( - Ops tf, Operand losses) { - Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses),ReduceSum.keepDims(Boolean.FALSE)); - return tf.math.divNoNan(totalLoss, cast(tf,tf.shape.size(tf.shape(losses)),losses.type())); + public static Operand safeMean(Ops tf, Operand losses) { + Operand totalLoss = + tf.reduceSum(losses, allAxes(tf, losses), ReduceSum.keepDims(Boolean.FALSE)); + return tf.math.divNoNan(totalLoss, cast(tf, tf.shape.size(tf.shape(losses)), losses.type())); } /** @@ -349,7 +351,8 @@ public static Operand rangeCheck( tf.math.logicalAnd( tf.reduceAll(tf.math.greaterEqual(values, minValue), allDims), tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims)); - // Graph and Eager mode need to be handled differently, control dependencies are not allowed in + // Graph and Eager mode need to be handled differently, control dependencies are + // not allowed in // Eager mode if (tf.scope().env().isGraph()) { AssertThat assertThat = @@ -399,7 +402,8 @@ public static Operand valueCheck( } else return values; } else { // use dynamic shape Operand cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0)); - // Graph and Eager mode need to be handled differently, control dependencies are not allowed + // Graph and Eager mode need to be handled differently, control dependencies are + // not allowed // in Eager mode if (tf.scope().env().isGraph()) { AssertThat assertThat = From fa29fc71f4c9e54418c18e9224be3564155e234f Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Mon, 4 Mar 2024 14:13:43 +0100 Subject: [PATCH 4/5] fix https://github.com/tensorflow/java/issues/526 --- .../op/nn/SoftmaxCrossEntropyWithLogits.java | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java index 
a95110c9a96..10106723fba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -42,15 +42,12 @@ public class SoftmaxCrossEntropyWithLogits { *

Usage: * *

-   *   Operand<TFloat32> logits =
-   *       tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
-   *   Operand<TFloat32> labels =
-   *       tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
-   *   Operand<TFloat32> output =
-   *       tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
-   *   // output Shape = [2]
-   *   // dataType = FLOAT (1)
-   *   // values { 0.169846, 0.824745 }
+   * Operand<TFloat32> logits = tf.constant(new float[][] { { 4.0F, 2.0F, 1.0F }, { 0.0F, 5.0F, 1.0F } });
+   * Operand<TFloat32> labels = tf.constant(new float[][] { { 1.0F, 0.0F, 0.0F }, { 0.0F, 0.8F, 0.2F } });
+   * Operand<TFloat32> output = tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+   * // output Shape = [2]
+   * // dataType = FLOAT (1)
+   * // values { 0.169846, 0.824745 }
    * 
* *

Backpropagation will happen into both logits and labels. To @@ -157,7 +154,7 @@ public static Operand softmaxCrossEntr * @return the flattened logits */ private static Operand flattenOuterDims(Scope scope, Operand logits) { - Operand one = Constant.scalarOf(scope, 1L); + Operand one = Constant.arrayOf(scope, 1L); Shape shape = logits.shape(); int ndims = shape.numDimensions(); From b11b918b7f973b707fa056d240d39add479245ba Mon Sep 17 00:00:00 2001 From: Nicolas Feybesse Date: Thu, 7 Mar 2024 13:29:32 +0100 Subject: [PATCH 5/5] Add test to CategoricalCrossentropyTest.java --- .../losses/CategoricalCrossentropyTest.java | 112 ++++++++---------- 1 file changed, 47 insertions(+), 65 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java index 25f5e5a54f1..1be85927d4f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java @@ -17,10 +17,14 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import org.junit.jupiter.api.Test; +import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Session; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; @@ -36,16 +40,8 @@ public void testAllCorrectUnweighted() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - long[] trueArray = { - 1L, 0L, 0L, - 0L, 1L, 0L, - 0L, 0L, 1L - }; - float[] predArray = { - 1.F, 0.F, 0.F, - 0.F, 1.F, 0.F, - 0.F, 0.F, 
1.F - }; + long[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L}; + float[] predArray = {1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); CategoricalCrossentropy instance = new CategoricalCrossentropy(); @@ -55,11 +51,7 @@ public void testAllCorrectUnweighted() { testSession.evaluate(expected, loss); // Test with logits. - float[] logitsArray = { - 10.F, 0.F, 0.F, - 0.F, 10.F, 0.F, - 0.F, 0.F, 10.F - }; + float[] logitsArray = {10.F, 0.F, 0.F, 0.F, 10.F, 0.F, 0.F, 0.F, 10.F}; yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); @@ -85,11 +77,7 @@ public void testInvalidPredictionsRange() { Ops tf = testSession.getTF(); CategoricalCrossentropy instance = new CategoricalCrossentropy(); - float[] trueArray = { - 1L, 0L, 0L, - 0L, 1L, 0L, - 0L, 0L, 1L - }; + float[] trueArray = {1L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 1L}; float[] predArray = {-1.F, 0.F, 0.F, 0.F, 1.F, 0.F, 0.F, 0.F, 1.F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); @@ -111,11 +99,7 @@ public void testUnweighted() { CategoricalCrossentropy instance = new CategoricalCrossentropy(); int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; - float[] predArray = { - .9F, .05F, .05F, - .5F, .89F, .6F, - .05F, .01F, .94F - }; + float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand loss = instance.call(tf, yTrue, yPred); @@ -123,11 +107,7 @@ public void testUnweighted() { testSession.evaluate(expected, loss); // Test with logits. 
- float[] logitsArray = { - 8.F, 1.F, 1.F, - 0.F, 9.F, 1.F, - 2.F, 3.F, 5.F - }; + float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F}; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new CategoricalCrossentropy(true); @@ -145,16 +125,8 @@ public void testScalarWeighted() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); - int[] trueArray = { - 1, 0, 0, - 0, 1, 0, - 0, 0, 1 - }; - float[] predArray = { - .9F, .05F, .05F, - .5F, .89F, .6F, - .05F, .01F, .94F - }; + int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; + float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = tf.constant(2.3F); @@ -166,11 +138,7 @@ public void testScalarWeighted() { testSession.evaluate(expected, loss); // Test with logits. 
- float[] logitsArray = { - 8.F, 1.F, 1.F, - 0.F, 9.F, 1.F, - 2.F, 3.F, 5.F - }; + float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F}; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new CategoricalCrossentropy(true); @@ -189,16 +157,8 @@ public void testSsampleWeighted() { CategoricalCrossentropy instance = new CategoricalCrossentropy(); float[] sampeWeightArray = {1.2F, 3.4F, 5.6F}; - int[] trueArray = { - 1, 0, 0, - 0, 1, 0, - 0, 0, 1 - }; - float[] predArray = { - .9F, .05F, .05F, - .5F, .89F, .6F, - .05F, .01F, .94F - }; + int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; + float[] predArray = {.9F, .05F, .05F, .5F, .89F, .6F, .05F, .01F, .94F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand yPred = tf.reshape(tf.constant(predArray), tf.constant(Shape.of(3, 3))); Operand sampleWeight = @@ -208,11 +168,7 @@ public void testSsampleWeighted() { testSession.evaluate(expected, loss); // Test with logits. - float[] logitsArray = { - 8.F, 1.F, 1.F, - 0.F, 9.F, 1.F, - 2.F, 3.F, 5.F - }; + float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F}; Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); instance = new CategoricalCrossentropy(true); @@ -231,11 +187,7 @@ public void testNoReduction() { // Test with logits. 
int[] trueArray = {1, 0, 0, 0, 1, 0, 0, 0, 1}; - float[] logitsArray = { - 8.F, 1.F, 1.F, - 0.F, 9.F, 1.F, - 2.F, 3.F, 5.F - }; + float[] logitsArray = {8.F, 1.F, 1.F, 0.F, 9.F, 1.F, 2.F, 3.F, 5.F}; Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(3, 3))); Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(3, 3))); @@ -266,4 +218,34 @@ public void testLabelSmoothing() { testSession.evaluate(expected, loss); } } + + @Test + public void testCategoricalCrossEntropyWithDynamicBatchSize() { + try (Graph graph = new Graph()) { + Ops tf = Ops.create(graph); + Operand yPred = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 3))); + Operand yTrue = + tf.reshape(tf.constant(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}), tf.array(3, 3)); + CategoricalCrossentropy instance = new CategoricalCrossentropy(true); + Operand loss = + instance.call(tf, yTrue, yPred); // Throws TFInvalidArgumentException without the fix + try (Session session = new Session(graph); + TFloat32 result = + (TFloat32) + session + .runner() + .feed( + yPred, + TFloat32.tensorOf( + Shape.of(3, 3), + DataBuffers.of( + new float[] {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}))) + .fetch(loss) + .run() + .get(0)) { + if (Math.abs(0.5514477f - result.getFloat()) > 0.01) + throw new IllegalStateException("Invalid result :" + result.getFloat()); + } + } + } }