From 139f595d6ec2cb57b73f59fc42510e2dbf4961c6 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 28 Jun 2023 08:35:34 +0800 Subject: [PATCH] [webgpu] Implement draw API (#7749) FEATURE * [webgpu] Implement draw API * Support bgra8-unorm * Mark texture externael and rename to_pixels to draw * Nit --------- Co-authored-by: Linchenn <40653845+Linchenn@users.noreply.github.com> --- tfjs-backend-webgl/src/setup_test.ts | 2 +- tfjs-backend-webgpu/src/backend_webgpu.ts | 24 +- tfjs-backend-webgpu/src/base.ts | 8 +- tfjs-backend-webgpu/src/draw_webgpu.ts | 79 ++++++ tfjs-backend-webgpu/src/from_pixels_webgpu.ts | 4 +- tfjs-backend-webgpu/src/kernels/Draw.ts | 85 +++++++ .../src/register_all_kernels.ts | 2 + tfjs-backend-webgpu/src/setup_test.ts | 6 - tfjs-backend-webgpu/src/webgpu_program.ts | 24 +- tfjs-core/src/ops/draw_test.ts | 226 +++++++++++++----- 10 files changed, 374 insertions(+), 86 deletions(-) create mode 100644 tfjs-backend-webgpu/src/draw_webgpu.ts create mode 100644 tfjs-backend-webgpu/src/kernels/Draw.ts diff --git a/tfjs-backend-webgl/src/setup_test.ts b/tfjs-backend-webgl/src/setup_test.ts index 20fd1cfe11c..4bcfbee6b72 100644 --- a/tfjs-backend-webgl/src/setup_test.ts +++ b/tfjs-backend-webgl/src/setup_test.ts @@ -39,7 +39,7 @@ const customInclude = (testName: string) => { 'throws when index is out of bound', // otsu tests for threshold op is failing on windows 'method otsu', - 'Draw on 2d context', + 'draw on canvas context', // https://github.com/tensorflow/tfjs/issues/7618 'numbers exceed float32 precision', ]; diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index 9b5faa9f345..ba3ffecd2a6 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -909,7 +909,8 @@ export class WebGPUBackend extends KernelBackend { // program size, program defined uniforms. let programUniform: ProgramUniform = []; let bufferShapes: number[][] = []; - if (!program.isFromPixels) { + const uniformsType = 'int32'; + if (program.pixelsOpType == null) { programUniform.push( {type: 'float32', data: [NaN]}, {type: 'float32', data: [Infinity]}); bufferShapes = inputs.concat(output).map(d => d.shape); @@ -919,14 +920,16 @@ export class WebGPUBackend extends KernelBackend { const strides = util.computeStrides(d); programUniform.push({type: uniformsType, data: strides}); }); - if (program.size) { - const size = util.sizeFromShape(program.outputShape); - programUniform.push({ - type: uniformsType, - data: - [program.outputComponent ? size / program.outputComponent : size] - }); - } + } else { + const strides = util.computeStrides(output.shape); + programUniform.push({type: uniformsType, data: strides}); + } + if (program.size) { + const size = util.sizeFromShape(program.outputShape); + programUniform.push({ + type: uniformsType, + data: [program.outputComponent ? size / program.outputComponent : size] + }); } if (programDefinedUniform) { @@ -986,7 +989,8 @@ export class WebGPUBackend extends KernelBackend { if (shouldTimeProgram || env().get('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE') as - number <= this.dispatchCountInPass) { + number <= this.dispatchCountInPass || + program.pixelsOpType === webgpu_program.PixelsOpType.DRAW) { this.endComputePassEncoder(); if (shouldTimeProgram) { this.activeTimers.push( diff --git a/tfjs-backend-webgpu/src/base.ts b/tfjs-backend-webgpu/src/base.ts index 8f4ad52877c..1de43149b07 100644 --- a/tfjs-backend-webgpu/src/base.ts +++ b/tfjs-backend-webgpu/src/base.ts @@ -33,9 +33,15 @@ if (isWebGPUSupported()) { const adapter = await navigator.gpu.requestAdapter(gpuDescriptor); const deviceDescriptor: GPUDeviceDescriptor = {}; + const requiredFeatures = []; if (adapter.features.has('timestamp-query')) { - deviceDescriptor.requiredFeatures = ['timestamp-query']; + requiredFeatures.push('timestamp-query'); } + if (adapter.features.has('bgra8unorm-storage')) { + requiredFeatures.push(['bgra8unorm-storage']); + } + deviceDescriptor.requiredFeatures = + requiredFeatures as Iterable; const adapterLimits = adapter.limits; deviceDescriptor.requiredLimits = { diff --git a/tfjs-backend-webgpu/src/draw_webgpu.ts b/tfjs-backend-webgpu/src/draw_webgpu.ts new file mode 100644 index 00000000000..608e8099d98 --- /dev/null +++ b/tfjs-backend-webgpu/src/draw_webgpu.ts @@ -0,0 +1,79 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {DataType} from '@tensorflow/tfjs-core'; + +import {getMainHeaderString as main, PixelsOpType, WebGPUProgram} from './webgpu_program'; +import {computeDispatch, flatDispatchLayout} from './webgpu_util'; + +export class DrawProgram implements WebGPUProgram { + variableNames = ['Image']; + uniforms = 'alpha: f32,'; + outputShape: number[]; + shaderKey: string; + dispatchLayout: {x: number[]}; + dispatch: [number, number, number]; + workgroupSize: [number, number, number] = [64, 1, 1]; + type: DataType; + textureFormat: GPUTextureFormat; + pixelsOpType = PixelsOpType.DRAW; + size = true; + + constructor( + outShape: number[], type: DataType, textureFormat: GPUTextureFormat) { + this.outputShape = outShape; + this.dispatchLayout = flatDispatchLayout(this.outputShape); + this.dispatch = computeDispatch( + this.dispatchLayout, this.outputShape, this.workgroupSize); + this.type = type; + this.textureFormat = textureFormat; + this.shaderKey = `draw_${type}_${textureFormat}`; + } + + getUserCode(): string { + let calculateResult; + const value = this.type === 'float32' ? 'value' : 'value / 255.0'; + calculateResult = ` + if (uniforms.numChannels == 1) { + rgba[0] = ${value}; + rgba[1] = ${value}; + rgba[2] = ${value}; + } else { + rgba[d] = ${value}; + }`; + + const userCode = ` + @group(0) @binding(0) var outImage : texture_storage_2d<${ + this.textureFormat}, write>; + ${main('index')} { + if (index < uniforms.size) { + var rgba = vec4(0.0, 0.0, 0.0, uniforms.alpha); + for (var d = 0; d < uniforms.numChannels; d = d + 1) { + let value = f32(inBuf[index * uniforms.numChannels + d]); + ${calculateResult} + } + rgba.x = rgba.x * rgba.w; + rgba.y = rgba.y * rgba.w; + rgba.z = rgba.z * rgba.w; + let coords = getCoordsFromIndex(index); + textureStore(outImage, vec2(coords.yx), rgba); + } + } + `; + return userCode; + } +} diff --git a/tfjs-backend-webgpu/src/from_pixels_webgpu.ts b/tfjs-backend-webgpu/src/from_pixels_webgpu.ts index 0d3daa2709b..92a720f0249 100644 --- a/tfjs-backend-webgpu/src/from_pixels_webgpu.ts +++ b/tfjs-backend-webgpu/src/from_pixels_webgpu.ts @@ -15,13 +15,13 @@ * ============================================================================= */ -import {getMainHeaderString as main, WebGPUProgram} from './webgpu_program'; +import {getMainHeaderString as main, PixelsOpType, WebGPUProgram} from './webgpu_program'; import {computeDispatch, flatDispatchLayout} from './webgpu_util'; export class FromPixelsProgram implements WebGPUProgram { dispatch: [number, number, number]; dispatchLayout: {x: number[]}; - isFromPixels = true; + pixelsOpType = PixelsOpType.FROM_PIXELS; outputShape: number[] = [0]; shaderKey: string; importVideo: boolean; diff --git a/tfjs-backend-webgpu/src/kernels/Draw.ts b/tfjs-backend-webgpu/src/kernels/Draw.ts new file mode 100644 index 00000000000..6ccbb7b367b --- /dev/null +++ b/tfjs-backend-webgpu/src/kernels/Draw.ts @@ -0,0 +1,85 @@ +/** + * @license + * Copyright 2023 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use backend file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, TensorInfo} from '@tensorflow/tfjs-core'; +import {Draw, DrawAttrs, DrawInputs,} from '@tensorflow/tfjs-core'; + +import {WebGPUBackend} from '../backend_webgpu'; +import {DrawProgram} from '../draw_webgpu'; + +export function draw( + args: {inputs: DrawInputs, backend: WebGPUBackend, attrs: DrawAttrs}): + TensorInfo { + const {inputs, backend, attrs} = args; + const {image} = inputs; + const {canvas, options} = attrs; + const [height, width] = image.shape.slice(0, 2); + const {imageOptions} = options || {}; + const alpha = imageOptions ?.alpha || 1; + + // 'rgba8unorm' should work on macOS according to + // https://bugs.chromium.org/p/chromium/issues/detail?id=1298618. But + // failed on macOS/M2. So use 'bgra8unorm' first when available. + const format = backend.device.features.has('bgra8unorm-storage') ? + 'bgra8unorm' : + 'rgba8unorm'; + const outShape = [height, width]; + const program = new DrawProgram(outShape, image.dtype, format); + canvas.width = width; + canvas.height = height; + const backendName = 'webgpu'; + let gpuContext = canvas.getContext(backendName); + let canvasWebGPU; + if (!gpuContext) { + canvasWebGPU = new OffscreenCanvas(width, height); + gpuContext = canvasWebGPU.getContext(backendName); + } + const numChannels = image.shape.length === 3 ? image.shape[2] : 1; + gpuContext.configure({ + device: backend.device, + format, + usage: GPUTextureUsage.STORAGE_BINDING, + alphaMode: 'premultiplied' + }); + + const outputDtype = 'int32'; + const output = backend.makeTensorInfo(outShape, outputDtype); + const info = backend.tensorMap.get(output.dataId); + info.resource = gpuContext.getCurrentTexture(); + info.external = true; + + const uniformData = + [{type: 'uint32', data: [numChannels]}, {type: 'float32', data: [alpha]}]; + backend.runWebGPUProgram(program, [image], outputDtype, uniformData, output); + + if (canvasWebGPU) { + const canvas2dContext = canvas.getContext('2d'); + if (!canvas2dContext) { + throw new Error( + `Please make sure this canvas has only been used for 2d or webgpu context!`); + } + canvas2dContext.drawImage(canvasWebGPU, 0, 0); + } + backend.disposeData(output.dataId); + return image; +} + +export const drawConfig: KernelConfig = { + kernelName: Draw, + backendName: 'webgpu', + kernelFunc: draw as unknown as KernelFunc +}; diff --git a/tfjs-backend-webgpu/src/register_all_kernels.ts b/tfjs-backend-webgpu/src/register_all_kernels.ts index f7b4b97868b..26d2ef3a10c 100644 --- a/tfjs-backend-webgpu/src/register_all_kernels.ts +++ b/tfjs-backend-webgpu/src/register_all_kernels.ts @@ -65,6 +65,7 @@ import {diagConfig} from './kernels/Diag'; import {dilation2DConfig} from './kernels/Dilation2D'; import {dilation2DBackpropFilterConfig} from './kernels/Dilation2DBackpropFilter'; import {dilation2DBackpropInputConfig} from './kernels/Dilation2DBackpropInput'; +import {drawConfig} from './kernels/Draw'; import {einsumConfig} from './kernels/Einsum'; import {eluConfig} from './kernels/Elu'; import {eluGradConfig} from './kernels/EluGrad'; @@ -229,6 +230,7 @@ const kernelConfigs: KernelConfig[] = [ dilation2DConfig, dilation2DBackpropFilterConfig, dilation2DBackpropInputConfig, + drawConfig, einsumConfig, eluConfig, eluGradConfig, diff --git a/tfjs-backend-webgpu/src/setup_test.ts b/tfjs-backend-webgpu/src/setup_test.ts index 90594b6301c..f881a35fbae 100644 --- a/tfjs-backend-webgpu/src/setup_test.ts +++ b/tfjs-backend-webgpu/src/setup_test.ts @@ -126,12 +126,6 @@ const TEST_FILTERS: TestFilter[] = [ 'canvas and image match', // Failing on Linux ], }, - { - startsWith: 'Draw', - excludes: [ - 'on 2d context', - ] - }, { startsWith: 'sign ', excludes: [ diff --git a/tfjs-backend-webgpu/src/webgpu_program.ts b/tfjs-backend-webgpu/src/webgpu_program.ts index 4743cd5320f..f36591040d2 100644 --- a/tfjs-backend-webgpu/src/webgpu_program.ts +++ b/tfjs-backend-webgpu/src/webgpu_program.ts @@ -19,6 +19,11 @@ import {backend_util, DataType, DataTypeMap, env, Rank, TensorInfo, util} from ' import {symbolicallyComputeStrides} from './shader_util'; +export enum PixelsOpType { + FROM_PIXELS, + DRAW +} + export interface WebGPUProgram { // Whether to use atomic built-in functions. atomic?: boolean; @@ -27,10 +32,10 @@ export interface WebGPUProgram { // dispatchLayout enumerates how tensor dimensions are distributed among // dispatch x,y,z dimensions. dispatchLayout: {x: number[], y?: number[], z?: number[]}; - isFromPixels?: boolean; // By default, the output data component is 1. outputComponent?: number; outputShape: number[]; + pixelsOpType?: PixelsOpType; // The unique key to distinguish different shader source code. shaderKey: string; // Whether to use output size for bounds checking. @@ -219,16 +224,23 @@ function makeShader( } `); - if (program.isFromPixels) { + if (program.pixelsOpType != null) { + const inoutSnippet = program.pixelsOpType === PixelsOpType.FROM_PIXELS ? + `@group(0) @binding(0) var result: array<${ + dataTypeToGPUType(outputData.dtype, program.outputComponent)}>;` : + `@group(0) @binding(1) var inBuf : array<${ + dataTypeToGPUType(inputInfo[0].dtype, program.outputComponent)}>;`; + const outShapeStridesType = + outputData.shape.length === 3 ? 'vec2' : 'i32'; prefixSnippets.push(` struct Uniform { + outShapeStrides : ${outShapeStridesType}, size : i32, numChannels : i32, - outShapeStrides : vec2, + alpha : f32, }; - @group(0) @binding(0) var result: array<${ - dataTypeToGPUType(outputData.dtype, program.outputComponent)}>; + ${inoutSnippet} @group(0) @binding(2) var uniforms: Uniform; `); const useGlobalIndex = isFlatDispatchLayout(program); @@ -339,7 +351,7 @@ export function makeShaderKey( program: WebGPUProgram, inputsData: InputInfo[], output: TensorInfo): string { let key = program.shaderKey; - if (program.isFromPixels) { + if (program.pixelsOpType != null) { return key; } diff --git a/tfjs-core/src/ops/draw_test.ts b/tfjs-core/src/ops/draw_test.ts index 0b72254e180..dd7500a1b47 100644 --- a/tfjs-core/src/ops/draw_test.ts +++ b/tfjs-core/src/ops/draw_test.ts @@ -19,86 +19,192 @@ import * as tf from '../index'; import {BROWSER_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose, expectArraysEqual} from '../test_util'; -class MockContext { - data: ImageData; - - getImageData() { - return this.data; - } - - putImageData(data: ImageData, x: number, y: number) { - this.data = data; +function readPixelsFromCanvas( + canvas: OffscreenCanvas, contextType: string, width: number, + height: number) { + let actualData; + if (contextType === '2d') { + const ctx = canvas.getContext(contextType); + actualData = ctx.getImageData(0, 0, width, height).data; + } else { + const offscreenCanvas = new OffscreenCanvas(width, height); + const ctx = offscreenCanvas.getContext('2d'); + ctx.drawImage(canvas, 0, 0); + actualData = new Uint8ClampedArray( + ctx.getImageData(0, 0, ctx.canvas.width, ctx.canvas.height).data); } + return actualData; } -class MockCanvas { - context: MockContext; +function convertToRGBA( + data: number[], shape: number[], dtype: string, alpha = 1) { + const [height, width] = shape.slice(0, 2); + const depth = shape.length === 2 ? 1 : shape[2]; + const multiplier = dtype === 'float32' ? 255 : 1; + const bytes = new Uint8ClampedArray(width * height * 4); + + for (let i = 0; i < height * width; ++i) { + const rgba = [0, 0, 0, 255 * alpha]; - constructor(public width: number, public height: number) {} + for (let d = 0; d < depth; d++) { + const value = data[i * depth + d]; - getContext(type: '2d'): MockContext { - if (this.context == null) { - this.context = new MockContext(); + if (dtype === 'float32') { + if (value < 0 || value > 1) { + throw new Error( + `Tensor values for a float32 Tensor must be in the ` + + `range [0 - 1] but encountered ${value}.`); + } + } else if (dtype === 'int32') { + if (value < 0 || value > 255) { + throw new Error( + `Tensor values for a int32 Tensor must be in the ` + + `range [0 - 255] but encountered ${value}.`); + } + } + + if (depth === 1) { + rgba[0] = value * multiplier; + rgba[1] = value * multiplier; + rgba[2] = value * multiplier; + } else { + rgba[d] = value * multiplier; + } } - return this.context; + + const j = i * 4; + bytes[j + 0] = Math.round(rgba[0]); + bytes[j + 1] = Math.round(rgba[1]); + bytes[j + 2] = Math.round(rgba[2]); + bytes[j + 3] = Math.round(rgba[3]); + } + return bytes; +} + +function drawAndReadback( + contextType: string, data: number[], shape: number[], dtype: string, + alpha = 1, canvasUsedAs2d = false) { + const [height, width] = shape.slice(0, 2); + let img; + if (shape.length === 3) { + img = tf.tensor3d( + data, shape as [number, number, number], dtype as keyof tf.DataTypeMap); + } else { + img = tf.tensor2d( + data, shape as [number, number], dtype as keyof tf.DataTypeMap); } + const canvas = new OffscreenCanvas(width, height); + if (canvasUsedAs2d) { + canvas.getContext('2d'); + } + const drawOptions = {contextOptions: {contextType}, imageOptions: {alpha}}; + // tslint:disable-next-line:no-any + tf.browser.draw(img, canvas as any, drawOptions); + const actualData = readPixelsFromCanvas(canvas, contextType, width, height); + const expectedData = convertToRGBA(data, shape, dtype, alpha); + img.dispose(); + return [actualData, expectedData]; } -describeWithFlags('Draw on 2d context', BROWSER_ENVS, () => { +// CPU and GPU handle pixel value differently. The epsilon may possibly grow +// after each draw and read back. The empirical value is 3.0. +const DRAW_EPSILON = 3.0; + +describeWithFlags('draw on canvas context', BROWSER_ENVS, (env) => { + let contextType: string; + beforeAll(async () => { + await tf.setBackend(env.name); + contextType = env.name === 'cpu' ? '2d' : env.name; + }); + it('draw image with 4 channels and int values', async () => { - const data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - const img = tf.tensor3d(data, [2, 2, 4], 'int32'); - const canvas = new MockCanvas(2, 2); - const ctx = canvas.getContext('2d'); - - // tslint:disable-next-line:no-any - tf.browser.draw(img, canvas as any, {contextOptions: {contextType: '2d'}}); - expectArraysEqual(ctx.getImageData().data, data); + const data = + [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160]; + const shape = [2, 2, 4]; + const dtype = 'int32'; + const startNumTensors = tf.memory().numTensors; + const [actualData, expectedData] = + drawAndReadback(contextType, data, shape, dtype); + expect(tf.memory().numTensors).toEqual(startNumTensors); + expectArraysClose(actualData, expectedData, DRAW_EPSILON); + }); + + it('draw image with 4 channels and int values, alpha=0.5', async () => { + const data = + [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160]; + const shape = [2, 2, 4]; + const dtype = 'int32'; + const startNumTensors = tf.memory().numTensors; + const [actualData, expectedData] = + drawAndReadback(contextType, data, shape, dtype, 0.5); + expect(tf.memory().numTensors).toEqual(startNumTensors); + expectArraysClose(actualData, expectedData, DRAW_EPSILON); }); it('draw image with 4 channels and float values', async () => { const data = - [.1, .2, .3, .4, .5, .6, .7, .8, .9, .1, .11, .12, .13, .14, .15, .16]; - const img = tf.tensor3d(data, [2, 2, 4]); - const canvas = new MockCanvas(2, 2); - const ctx = canvas.getContext('2d'); - - // tslint:disable-next-line:no-any - tf.browser.draw(img, canvas as any, {contextOptions: {contextType: '2d'}}); - const actualData = ctx.getImageData().data; - const expectedData = data.map(e => Math.round(e * 255)); - expectArraysClose(actualData, expectedData, 1); + [.1, .2, .3, .4, .5, .6, .7, .8, .09, .1, .11, .12, .13, .14, .15, .16]; + const shape = [2, 2, 4]; + const dtype = 'float32'; + const [actualData, expectedData] = + drawAndReadback(contextType, data, shape, dtype); + expectArraysClose(actualData, expectedData, DRAW_EPSILON); + }); + + it('draw image with 3 channels and int values', async () => { + const data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + const shape = [2, 2, 3]; + const dtype = 'int32'; + const [actualData, expectedData] = + drawAndReadback(contextType, data, shape, dtype); + expectArraysEqual(actualData, expectedData); + }); + + it('draw image with 3 channels and int values, alpha=0.5', async () => { + const data = [101, 32, 113, 14, 35, 76, 17, 38, 59, 70, 81, 92]; + const shape = [2, 2, 3]; + const dtype = 'int32'; + const alpha = 0.5; + const [actualData, expectedData] = + drawAndReadback(contextType, data, shape, dtype, alpha); + expectArraysClose(actualData, expectedData, DRAW_EPSILON); + }); + + it('draw image with 3 channels and float values', async () => { + const data = [.1, .2, .3, .4, .5, .6, .7, .8, .9, .1, .11, .12]; + const shape = [2, 2, 3]; + const dtype = 'float32'; + const [actualData, expectedData] = + drawAndReadback(contextType, data, shape, dtype); + expectArraysClose(actualData, expectedData, DRAW_EPSILON); }); it('draw 2D image in grayscale', async () => { - const data = [1, 2, 3, 4]; - const img = tf.tensor2d(data, [2, 2], 'int32'); - const canvas = new MockCanvas(2, 2); - const ctx = canvas.getContext('2d'); - - // tslint:disable-next-line:no-any - tf.browser.draw(img, canvas as any, {contextOptions: {contextType: '2d'}}); - const actualData = ctx.getImageData().data; - const expectedData = - [1, 1, 1, 255, 2, 2, 2, 255, 3, 3, 3, 255, 4, 4, 4, 255]; + const data = [100, 12, 90, 64]; + const shape = [2, 2]; + const dtype = 'int32'; + const [actualData, expectedData] = + drawAndReadback(contextType, data, shape, dtype); expectArraysEqual(actualData, expectedData); }); it('draw image with alpha=0.5', async () => { - const data = [1, 2, 3, 4]; - const img = tf.tensor3d(data, [2, 2, 1], 'int32'); - const canvas = new MockCanvas(2, 2); - const ctx = canvas.getContext('2d'); - - const drawOptions = { - contextOptions: {contextType: '2d'}, - imageOptions: {alpha: 0.5} - }; - // tslint:disable-next-line:no-any - tf.browser.draw(img, canvas as any, drawOptions); - const actualData = ctx.getImageData().data; - const expectedData = - [1, 1, 1, 128, 2, 2, 2, 128, 3, 3, 3, 128, 4, 4, 4, 128]; - expectArraysEqual(actualData, expectedData); + const data = [101, 212, 113, 14, 35, 76, 17, 38, 59, 70, 81, 92]; + const shape = [6, 2, 1]; + const dtype = 'int32'; + const alpha = 0.5; + const [actualData, expectedData] = + drawAndReadback(contextType, data, shape, dtype, alpha); + expectArraysClose(actualData, expectedData, DRAW_EPSILON); + }); + + it('draw image works when canvas has been used for 2d', async () => { + const data = [101, 212, 113, 14, 35, 76, 17, 38, 59, 70, 81, 92]; + const shape = [6, 2, 1]; + const dtype = 'int32'; + // Set canvasUsedAs2d to true so the canvas will be first used for 2d. + const [actualData, expectedData] = + drawAndReadback(contextType, data, shape, dtype, 1, true); + expectArraysClose(actualData, expectedData, DRAW_EPSILON); }); });