Skip to content

Commit

Permalink
[webgpu] Add readSync just when data lives on the CPU so getParamValu…
Browse files Browse the repository at this point in the history
…e works during model conversion. (tensorflow#1765)

FEATURE
  • Loading branch information
annxingyuan authored May 22, 2019
1 parent e76d682 commit 5e96913
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
31 changes: 29 additions & 2 deletions src/backends/webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import {DataMover, DataType, ENV, KernelBackend, Rank, ShapeMap, Tensor, Tensor2
import * as backend_util from '@tensorflow/tfjs-core/dist/backends/backend_util';
import {computeOutShape} from '@tensorflow/tfjs-core/dist/ops/concat_util';
import {Conv2DInfo} from '@tensorflow/tfjs-core/dist/ops/conv_util';
import {upcastType} from '@tensorflow/tfjs-core/dist/types';
import {TypedArray, upcastType} from '@tensorflow/tfjs-core/dist/types';
import {assert} from '@tensorflow/tfjs-core/dist/util';
import * as shaderc from '@webgpu/shaderc';

Expand Down Expand Up @@ -146,14 +146,41 @@ export class WebGPUBackend extends KernelBackend {
return mapped.slice(0);
}

private convertAndCacheOnCPU(dataId: DataId, float32Values: Float32Array):
TypedArray {
const texData = this.tensorMap.get(dataId);

// TODO: implement release GPU data.
// TODO: add backend_webgl float32ToTypedArray to util and use that here.

texData.values = float32Values;
return texData.values as TypedArray;
}

// TODO: Remove once this is fixed:
// https://github.com/tensorflow/tfjs/issues/1595
readSync(dataId: object): Float32Array|Int32Array|Uint8Array {
const texData = this.tensorMap.get(dataId);
const {values} = texData;

if (values == null) {
throw new Error(
'WebGPU readSync is only available for CPU-resident tensors.');
}

return values;
}

async read(dataId: object): Promise<Float32Array|Int32Array|Uint8Array> {
if (!this.tensorMap.has(dataId)) {
throw new Error(`Tensor ${dataId} was not registered!`);
}
const info = this.tensorMap.get(dataId);
const data = await this.getBufferData(info);

return new Float32Array(data);
const dataAsFloat32Array = new Float32Array(data);
this.convertAndCacheOnCPU(dataId, dataAsFloat32Array);
return dataAsFloat32Array;
}

private getAndSavePipeline(
Expand Down
35 changes: 35 additions & 0 deletions src/backends/webgpu/src/backend_webgpu_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs-core';
import {describeWebGPU} from './test_util';

describeWebGPU('backend webgpu', () => {
it('readSync should throw if tensors are on the GPU', async () => {
const a = tf.tensor2d([1, 2, 3, 4], [2, 2]);
const b = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);

const c = tf.matMul(a, b);
expect(() => c.dataSync())
.toThrowError(
'WebGPU readSync is only available for CPU-resident tensors.');

await c.data();
// Now that data has been downloaded to the CPU, dataSync should work.
expect(() => c.dataSync()).not.toThrow();
});
});

0 comments on commit 5e96913

Please sign in to comment.