Skip to content

Commit

Permalink
Move tensor to array functions to utils (#7810)
Browse files Browse the repository at this point in the history
* Move tensor to array functions to utils

* fix lint errors
  • Loading branch information
pforderique authored Jul 11, 2023
1 parent 206c0af commit bba29cb
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 32 deletions.
3 changes: 2 additions & 1 deletion tfjs-layers/src/layers/nlp/tokenizers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import { Tensor, serialization, tensor, tidy} from '@tensorflow/tfjs-core';

import { Layer, LayerArgs } from '../../engine/topology';
import { NotImplementedError, ValueError } from '../../errors';
import { BytePairTokenizerCache, StaticHashTable, bytesToUnicode, createStaticHashtable, removeStringsFromInputs, splitStringsForBpe, tensorArrTo2DArr, tensorToArr } from './tokenizers_utils';
import { BytePairTokenizerCache, StaticHashTable, bytesToUnicode, createStaticHashtable, removeStringsFromInputs, splitStringsForBpe } from './tokenizers_utils';
import { tensorToArr, tensorArrTo2DArr } from './utils';

export declare interface TokenizerOptions {
mode?: 'tokenize' | 'detokenize';
Expand Down
9 changes: 1 addition & 8 deletions tfjs-layers/src/layers/nlp/tokenizers_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import { Tensor, tensor } from '@tensorflow/tfjs-core';
import { ValueError } from '../../errors';
import { matchAll } from './match_all_polyfill';
import { tensorArrTo2DArr, tensorToArr } from './utils';

export function bytesToUnicode(): [Uint8Array, string[]] {
const inclusiveRange = (start: number, end: number) =>
Expand Down Expand Up @@ -249,14 +250,6 @@ export function regexSplit(
});
}

export function tensorToArr(input: Tensor): unknown[] {
return input.dataSync() as unknown as unknown[];
}

export function tensorArrTo2DArr(inputs: Tensor[]): unknown[][] {
return inputs.map(input => tensorToArr(input));
}

export function splitStringsForBpe(
inputs: Tensor, unsplittableTokens?: string[]): Tensor[] {

Expand Down
24 changes: 1 addition & 23 deletions tfjs-layers/src/layers/nlp/tokenizers_utils_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ import { tensor, test_util } from '@tensorflow/tfjs-core';

import { BytePairTokenizerCache, SPLIT_PATTERN_1, bytesToUnicode,
createAltsForUnsplittableTokens, createStaticHashtable, regexSplit,
removeStringsFromInputs, splitStringsForBpe, tensorArrTo2DArr,
tensorToArr } from './tokenizers_utils';
removeStringsFromInputs, splitStringsForBpe } from './tokenizers_utils';
import { expectTensorsClose } from '../../utils/test_utils';

describe('bytesToUnicode', () => {
Expand Down Expand Up @@ -231,24 +230,3 @@ describe('splitStringsForBpe', () => {
expectTensorsClose(result[1], tensor(['black', '.']));
});
});

describe('tensor to array functions', () => {
it('tensorToArr', () => {
const inputStr = tensor(['these', 'are', 'strings', '.']);
const inputNum = tensor([2, 11, 15]);

test_util.expectArraysEqual(
tensorToArr(inputStr) as string[], ['these', 'are', 'strings', '.']);
test_util.expectArraysEqual(tensorToArr(inputNum) as number[], [2, 11, 15]);
});

it('tensorArrTo2DArr', () => {
const inputStr = [tensor(['these', 'are']), tensor(['strings', '.'])];
const inputNum = [tensor([2, 11]), tensor([15])];

test_util.expectArraysEqual(
tensorArrTo2DArr(inputStr) as string[][], [['these', 'are'], ['strings', '.']]);
test_util.expectArraysEqual(
tensorArrTo2DArr(inputNum) as number[][], [[2, 11], [15]]);
});
});
26 changes: 26 additions & 0 deletions tfjs-layers/src/layers/nlp/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import { Tensor } from '@tensorflow/tfjs-core';

export function tensorToArr(input: Tensor): unknown[] {
return Array.from(input.dataSync()) as unknown as unknown[];
}

export function tensorArrTo2DArr(inputs: Tensor[]): unknown[][] {
return inputs.map(input => tensorToArr(input));
}
42 changes: 42 additions & 0 deletions tfjs-layers/src/layers/nlp/utils_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import { tensor, test_util } from '@tensorflow/tfjs-core';
import { tensorArrTo2DArr, tensorToArr } from './utils';

describe('tensor to array functions', () => {
it('tensorToArr', () => {
const inputStr = tensor(['these', 'are', 'strings', '.']);
const inputNum = tensor([2, 11, 15]);

test_util.expectArraysEqual(
tensorToArr(inputStr) as string[], ['these', 'are', 'strings', '.']);
test_util.expectArraysEqual(tensorToArr(inputNum) as number[], [2, 11, 15]);
});

it('tensorArrTo2DArr', () => {
const inputStr = [tensor(['these', 'are']), tensor(['strings', '.'])];
const inputNum = [tensor([2, 11]), tensor([15])];

test_util.expectArraysEqual(
tensorArrTo2DArr(inputStr) as string[][],
[['these', 'are'], ['strings', '.']]
);
test_util.expectArraysEqual(
tensorArrTo2DArr(inputNum) as number[][], [[2, 11], [15]]);
});
});

0 comments on commit bba29cb

Please sign in to comment.