diff --git a/tfjs-backend-webgl/src/setup_test.ts b/tfjs-backend-webgl/src/setup_test.ts index 4bcfbee6b72..0faa95d5a3a 100644 --- a/tfjs-backend-webgl/src/setup_test.ts +++ b/tfjs-backend-webgl/src/setup_test.ts @@ -42,6 +42,8 @@ const customInclude = (testName: string) => { 'draw on canvas context', // https://github.com/tensorflow/tfjs/issues/7618 'numbers exceed float32 precision', + // float32 inputs with nonzero fractional part should not be rounded + 'floorDiv float32', ]; for (const subStr of toExclude) { if (testName.includes(subStr)) { diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index f881a35fbae..19582dd84b8 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -168,6 +168,13 @@ const TEST_FILTERS: TestFilter[] = [ 'indices invalid', ], }, + { + startsWith: 'floorDiv ', + excludes: [ + // float32 inputs with nonzero fractional part should not be rounded + 'floorDiv float32', + ], + }, // exclude unsupported kernels and to be fixed cases { diff --git a/tfjs-core/src/ops/floordiv_test.ts b/tfjs-core/src/ops/floordiv_test.ts index d14d12b9c90..46fe81c8b82 100644 --- a/tfjs-core/src/ops/floordiv_test.ts +++ b/tfjs-core/src/ops/floordiv_test.ts @@ -37,4 +37,13 @@ describeWithFlags('floorDiv', ALL_ENVS, () => { expect(result.shape).toEqual(a.shape); expectArraysClose(await result.data(), [1, 1, -3, -8]); }); + + it('floorDiv float32', async () => { + const a = tf.tensor1d([0.0, -6.0, 5.9, -5.9], 'float32'); + const b = tf.tensor1d([3.0, 3.0, -3.1, -3.1], 'float32'); + const result = tf.floorDiv(a, b); + + expect(result.shape).toEqual(a.shape); + expectArraysClose(await result.data(), [0, -2, -2, 1]); + }); });