From dcf06fb550307ee49c0e92dece2145d4bc903291 Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Mon, 10 Jul 2023 08:53:22 +0800 Subject: [PATCH] Test floorDiv() against float32 inputs (#7809) --- tfjs-backend-webgl/src/setup_test.ts | 2 ++ tfjs-backend-webgpu/src/setup_test.ts | 7 +++++++ tfjs-core/src/ops/floordiv_test.ts | 9 +++++++++ 3 files changed, 18 insertions(+) diff --git a/tfjs-backend-webgl/src/setup_test.ts b/tfjs-backend-webgl/src/setup_test.ts index 4bcfbee6b72..0faa95d5a3a 100644 --- a/tfjs-backend-webgl/src/setup_test.ts +++ b/tfjs-backend-webgl/src/setup_test.ts @@ -42,6 +42,8 @@ const customInclude = (testName: string) => { 'draw on canvas context', // https://github.com/tensorflow/tfjs/issues/7618 'numbers exceed float32 precision', + // float32 inputs with nonzero fractional part should not be rounded + 'floorDiv float32', ]; for (const subStr of toExclude) { if (testName.includes(subStr)) { diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index f881a35fbae..19582dd84b8 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -168,6 +168,13 @@ const TEST_FILTERS: TestFilter[] = [ 'indices invalid', ], }, + { + startsWith: 'floorDiv ', + excludes: [ + // float32 inputs with nonzero fractional part should not be rounded + 'floorDiv float32', + ], + }, // exclude unsupported kernels and to be fixed cases { diff --git a/tfjs-core/src/ops/floordiv_test.ts b/tfjs-core/src/ops/floordiv_test.ts index d14d12b9c90..46fe81c8b82 100644 --- a/tfjs-core/src/ops/floordiv_test.ts +++ b/tfjs-core/src/ops/floordiv_test.ts @@ -37,4 +37,13 @@ describeWithFlags('floorDiv', ALL_ENVS, () => { expect(result.shape).toEqual(a.shape); expectArraysClose(await result.data(), [1, 1, -3, -8]); }); + + it('floorDiv float32', async () => { + const a = tf.tensor1d([0.0, -6.0, 5.9, -5.9], 'float32'); + const b = tf.tensor1d([3.0, 3.0, -3.1, -3.1], 'float32'); + const result = tf.floorDiv(a, b); + + expect(result.shape).toEqual(a.shape); + expectArraysClose(await result.data(), [0, -2, -2, 1]); + }); });