Skip to content

Commit

Permalink
Test floorDiv() against float32 inputs (#7809)
Browse files Browse the repository at this point in the history
  • Loading branch information
hujiajie committed Jul 10, 2023
1 parent c4d1199 commit dcf06fb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tfjs-backend-webgl/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ const customInclude = (testName: string) => {
'draw on canvas context',
// https://github.com/tensorflow/tfjs/issues/7618
'numbers exceed float32 precision',
// float32 inputs with nonzero fractional part should not be rounded
'floorDiv float32',
];
for (const subStr of toExclude) {
if (testName.includes(subStr)) {
Expand Down
7 changes: 7 additions & 0 deletions tfjs-backend-webgpu/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ const TEST_FILTERS: TestFilter[] = [
'indices invalid',
],
},
{
startsWith: 'floorDiv ',
excludes: [
// float32 inputs with nonzero fractional part should not be rounded
'floorDiv float32',
],
},

// exclude unsupported kernels and to be fixed cases
{
Expand Down
9 changes: 9 additions & 0 deletions tfjs-core/src/ops/floordiv_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,13 @@ describeWithFlags('floorDiv', ALL_ENVS, () => {
expect(result.shape).toEqual(a.shape);
expectArraysClose(await result.data(), [1, 1, -3, -8]);
});

it('floorDiv float32', async () => {
const a = tf.tensor1d([0.0, -6.0, 5.9, -5.9], 'float32');
const b = tf.tensor1d([3.0, 3.0, -3.1, -3.1], 'float32');
const result = tf.floorDiv(a, b);

expect(result.shape).toEqual(a.shape);
expectArraysClose(await result.data(), [0, -2, -2, 1]);
});
});

0 comments on commit dcf06fb

Please sign in to comment.