Skip to content

Commit

Permalink
[webgpu] Concat kernel + logical input samplers. (tensorflow#1740)
Browse files Browse the repository at this point in the history
FEATURE
  • Loading branch information
annxingyuan authored May 7, 2019
1 parent a6acfce commit dc8836a
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 32 deletions.
26 changes: 25 additions & 1 deletion src/backends/webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@

import './flags_webgpu';

import {DataMover, DataType, ENV, KernelBackend, Rank, ShapeMap, Tensor, Tensor3D, Tensor4D, util} from '@tensorflow/tfjs-core';
import {DataMover, DataType, ENV, KernelBackend, Rank, ShapeMap, Tensor, Tensor2D, Tensor3D, Tensor4D, util} from '@tensorflow/tfjs-core';
import {computeOutShape} from '@tensorflow/tfjs-core/dist/ops/concat_util';
import {Conv2DInfo} from '@tensorflow/tfjs-core/dist/ops/conv_util';
import {upcastType} from '@tensorflow/tfjs-core/dist/types';
import * as shaderc from '@webgpu/shaderc';

import * as binary_op from './kernels/binary_op_webgpu';
import {BinaryOpProgram} from './kernels/binary_op_webgpu';
import {ConcatProgram} from './kernels/concat_webgpu';
import {Conv2DMMProgram} from './kernels/conv2d_mm_webgpu';
import {Conv2DNaiveProgram} from './kernels/conv2d_naive_webgpu';
import {MatMulPackedProgram} from './kernels/matmul_packed_webgpu';
Expand Down Expand Up @@ -317,6 +319,28 @@ export class WebGPUBackend extends KernelBackend {
Tensor4D;
}

/**
 * Joins `tensors` along `axis` using a single ConcatProgram dispatch.
 *
 * A lone input is handed back untouched. Otherwise every input is reshaped
 * to 2D -- the dimensions before `axis` are folded into the row count and
 * the dimensions from `axis` onward into the column count -- so that the
 * program only ever has to join matrices along their columns. The program
 * result is reshaped to the full-rank concatenated shape before returning.
 *
 * @param tensors Tensors to concatenate.
 * @param axis Dimension along which to concatenate.
 * @returns The sole input when `tensors.length === 1`, else the joined
 *     tensor.
 */
concat(tensors: Tensor[], axis: number): Tensor {
  if (tensors.length === 1) {
    return tensors[0];
  }

  // Open question: is there a cap on how many buffers one WebGPU program can
  // bind? If so, split the inputs and recurse, along these lines:
  //   if (tensors.length > MAX_SSBOS_FOR_WEBGPU_PROGRAM) {
  //     const midIndex = Math.floor(tensors.length / 2);
  //     const leftSide = this.concat(tensors.slice(0, midIndex), axis);
  //     const rightSide = this.concat(tensors.slice(midIndex), axis);
  //     return this.concat([leftSide, rightSide], axis);
  //   }

  const inputShapes = tensors.map(t => t.shape);
  const outShape = computeOutShape(inputShapes, axis);

  // Fold [d0, ..., d(axis-1)] into rows and [d(axis), ...] into columns.
  const toMatrix = (t: Tensor): Tensor2D => {
    const rows = util.sizeFromShape(t.shape.slice(0, axis));
    const cols = util.sizeFromShape(t.shape.slice(axis));
    return t.reshape([rows, cols]) as Tensor2D;
  };
  const matrices = tensors.map(toMatrix);

  const program = new ConcatProgram(matrices.map(m => m.shape));
  const joined = this.compileAndRun(program, matrices) as Tensor;
  return joined.reshape(outShape);
}

multiply(a: Tensor, b: Tensor): Tensor {
return this.binaryOp(a, b, binary_op.MUL);
}
Expand Down
75 changes: 75 additions & 0 deletions src/backends/webgpu/src/concat_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs-core';

import * as tfwebgpu from './index';

// Specs for concatenation through the WebGPU backend: rank-1, rank-2 (both
// axes, two and three inputs) and rank-3 along the innermost axis.
describe('concat', () => {
  // Hold the suite until the `tfwebgpu.ready` promise settles.
  beforeAll(async () => tfwebgpu.ready);

  it('3 + 5', async () => {
    const first = tf.tensor1d([3]);
    const second = tf.tensor1d([5]);

    const joined = tf.concat1d([first, second]);
    const want = [3, 5];
    const got = await joined.data();
    tf.test_util.expectArraysClose(got, new Float32Array(want));
  });

  it('[[3]] + [[5]], axis=0', async () => {
    const axis = 0;
    const first = tf.tensor2d([3], [1, 1]);
    const second = tf.tensor2d([5], [1, 1]);

    const joined = tf.concat2d([first, second], axis);
    const want = [3, 5];

    expect(joined.shape).toEqual([2, 1]);
    const got = await joined.data();
    tf.test_util.expectArraysClose(got, new Float32Array(want));
  });

  it('[[1, 2],[3, 4]] + [[5, 6],[7, 8]] + [[9, 10],[11, 12]], axis=1',
     async () => {
       const axis = 1;
       const first = tf.tensor2d([[1, 2], [3, 4]]);
       const second = tf.tensor2d([[5, 6], [7, 8]]);
       const third = tf.tensor2d([[9, 10], [11, 12]]);

       const joined = tf.concat2d([first, second, third], axis);
       // Row-major: each output row is the matching rows laid side by side.
       const want = [1, 2, 5, 6, 9, 10, 3, 4, 7, 8, 11, 12];

       expect(joined.shape).toEqual([2, 6]);
       const got = await joined.data();
       tf.test_util.expectArraysClose(got, new Float32Array(want));
     });

  it('concat axis=2', async () => {
    const narrow = tf.tensor3d([1, 11, 2, 22, 3, 33, 4, 44], [2, 2, 2]);
    const wide = tf.tensor3d(
        [5, 55, 555, 6, 66, 666, 7, 77, 777, 8, 88, 888], [2, 2, 3]);

    const joined = tf.concat3d([narrow, wide], 2);
    expect(joined.shape).toEqual([2, 2, 5]);

    const got = await joined.data();
    const want = new Float32Array([
      1, 11, 5, 55, 555, 2, 22, 6, 66, 666,
      3, 33, 7, 77, 777, 4, 44, 8, 88, 888
    ]);
    tf.test_util.expectArraysClose(got, want);
  });
});
68 changes: 68 additions & 0 deletions src/backends/webgpu/src/kernels/concat_webgpu.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as concat_util from '@tensorflow/tfjs-core/dist/ops/concat_util';
import {computeDispatch} from '../webgpu_util';
import {WebGPUProgram} from './webgpu_program';

/**
 * Shader program that concatenates 2D inputs along their column axis
 * (axis 1).
 *
 * Input `i` is exposed to the shader under the name `T${i}`. For every
 * output coordinate the generated `main()` compares the output column with
 * the running column boundaries of the inputs and copies the value from the
 * input that owns that column, shifting the column index back into that
 * input's own range.
 */
export class ConcatProgram implements WebGPUProgram {
  outputShape: number[];
  userCode: string;
  dispatchLayout: {x: number[], y: number[]};
  dispatch: [number, number, number];
  variableNames: string[];

  /**
   * @param shapes `[rows, cols]` of each input, in concatenation order. Must
   *     contain at least one shape.
   * @throws Error if `shapes` is empty.
   */
  constructor(shapes: Array<[number, number]>) {
    if (shapes.length === 0) {
      throw new Error('ConcatProgram requires at least one input shape.');
    }

    this.outputShape =
        concat_util.computeOutShape(shapes, 1 /* axis */) as [number, number];
    this.variableNames = shapes.map((_, i) => `T${i}`);

    this.dispatchLayout = {x: [0], y: [1]};
    this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);

    // offsets[i] is the first output column that lies past input i. The last
    // input needs no boundary: it takes whatever columns remain.
    const offsets: number[] = [];
    let boundary = 0;
    for (let i = 0; i < shapes.length - 1; i++) {
      boundary += shapes[i][1];
      offsets.push(boundary);
    }

    const snippets: string[] = [];
    if (offsets.length === 0) {
      // Single input: there is no boundary to test, so copy straight from
      // T0. Emitting an else-branch here would reference T1, which is not
      // listed in variableNames.
      snippets.push(`setOutput(coords.x, coords.y, getT0(yR, yC));`);
    } else {
      snippets.push(
          `if (yC < ${offsets[0]}) ` +
          `setOutput(coords.x, coords.y, getT0(yR, yC));`);

      for (let i = 1; i < offsets.length; i++) {
        const shift = offsets[i - 1];
        snippets.push(
            `else if (yC < ${offsets[i]}) ` +
            `setOutput(coords.x, coords.y, getT${i}(yR, yC-${shift}));`);
      }

      const lastIndex = offsets.length;
      const lastShift = offsets[offsets.length - 1];
      snippets.push(
          `else setOutput(coords.x, coords.y, ` +
          `getT${lastIndex}(yR, yC-${lastShift}));`);
    }

    this.userCode = `
void main() {
ivec2 coords = getOutputCoords();
int yR = coords.x;
int yC = coords.y;
${snippets.join('\n ')}
}
`;
  }
}
8 changes: 4 additions & 4 deletions src/backends/webgpu/src/kernels/conv2d_naive_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,19 @@ export class Conv2DNaiveProgram implements WebGPUProgram {
float readInp(uint batch, uint row, uint col, uint chan) {
ivec4 coord = ivec4(batch, row, col, chan);
return coordIsValid(coord, xShape) ? x[getFlatIndex(coord, xShape)] : 0;
return coordIsValid(coord, xShape) ? getX(coord) : 0;
}
float readFilt(uint row, uint col, uint xChannel, uint outChannel) {
ivec4 coord = ivec4(row, col, xChannel, outChannel);
ivec4 shape = ivec4(filterDims, xShape[3], outShape[3]);
return coordIsValid(coord, shape) ? W[getFlatIndex(coord, shape)] : 0;
return coordIsValid(coord, shape) ?
getW(row, col, xChannel, outChannel) : 0;
}
void writeResult(uint batch, uint row, uint col, uint chan, float value) {
ivec4 coord = ivec4(batch, row, col, chan);
if (coordIsValid(coord, outShape)) {
result[getFlatIndex(coord, outShape)] = value;
setOutput(batch, row, col, chan, value);
}
}
Expand Down
5 changes: 2 additions & 3 deletions src/backends/webgpu/src/kernels/maxpool_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,13 @@ export class MaxPoolProgram implements WebGPUProgram {
if (xC < 0 || xC >= convDims.x) {
return 0.0;
}
return x[getFlatIndex(ivec4(batch, xR, xC, d), xShape)];
return getX(batch, xR, xC, d);
}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d = coords[3];
uint index = getFlatIndex(coords, outShape);
if (all(lessThan(coords, outShape))) {
ivec2 xRCCorner = coords.yz * stride - pad;
Expand All @@ -73,7 +72,7 @@ export class MaxPoolProgram implements WebGPUProgram {
minMaxValue = max(value, minMaxValue);
}
}
setOutput(index, minMaxValue);
setOutput(batch, coords[1], coords[2], d, minMaxValue);
}
}
`;
Expand Down
10 changes: 5 additions & 5 deletions src/backends/webgpu/src/kernels/pad_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ export class PadProgram implements WebGPUProgram {
const startValue = rank > 1 ? `${type}(${start})` : `${start}`;
const endValue = rank > 1 ? `${type}(${end})` : `${end}`;

const xShapeValue =
rank > 1 ? `${type}(${xShape.join(',')})` : `${xShape[0]}`;

const leftPadCondition =
rank > 1 ? `any(lessThan(outC, start))` : `outC < start`;
const rightPadCondition =
rank > 1 ? `any(greaterThanEqual(outC, end))` : `outC >= end`;

const unpackedCoords = rank > 1 ?
['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank) :
'coords';

this.userCode = `
${type} start = ${startValue};
${type} end = ${endValue};
Expand All @@ -62,8 +63,7 @@ export class PadProgram implements WebGPUProgram {
setOutput(index, ${constantValue});
} else {
${type} coords = outC - start;
${type} xShape = ${xShapeValue};
setOutput(index, x[getFlatIndex(coords, xShape)]);
setOutput(index, getX(${unpackedCoords}));
}
}
`;
Expand Down
16 changes: 5 additions & 11 deletions src/backends/webgpu/src/kernels/resize_bilinear_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ export class ResizeBilinearProgram implements WebGPUProgram {
const adjustWidth = alignCorners && newWidth > 1;

this.userCode = `
float getValue(int b, int r, int c, int d) {
return x[getFlatIndex(ivec4(b, r, c, d), xShape)];
}
void main() {
ivec4 coords = getOutputCoords();
Expand All @@ -62,8 +58,6 @@ export class ResizeBilinearProgram implements WebGPUProgram {
vec2 effectiveInputOverOutputRatioRC =
effectiveInSize / effectiveOutSize;
uint index = getFlatIndex(coords, outShape);
// Fractional source index
vec2 sourceFracIndexRC = vec2(rc) * effectiveInputOverOutputRatioRC;
Expand All @@ -72,18 +66,18 @@ export class ResizeBilinearProgram implements WebGPUProgram {
ivec2 sourceCeilRC = ivec2(
min(xShape.yz - 1.0, ceil(sourceFracIndexRC)));
float topLeft = getValue(b, sourceFloorRC.x, sourceFloorRC.y, d);
float bottomLeft = getValue(b, sourceCeilRC.x, sourceFloorRC.y, d);
float topRight = getValue(b, sourceFloorRC.x, sourceCeilRC.y, d);
float bottomRight = getValue(b, sourceCeilRC.x, sourceCeilRC.y, d);
float topLeft = getX(b, sourceFloorRC.x, sourceFloorRC.y, d);
float bottomLeft = getX(b, sourceCeilRC.x, sourceFloorRC.y, d);
float topRight = getX(b, sourceFloorRC.x, sourceCeilRC.y, d);
float bottomRight = getX(b, sourceCeilRC.x, sourceCeilRC.y, d);
vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);
float top = topLeft + (topRight - topLeft) * fracRC.y;
float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;
float newValue = top + (bottom - top) * fracRC.x;
setOutput(index, newValue);
setOutput(b, coords[1], coords[2], d, newValue);
}
}
`;
Expand Down
Loading

0 comments on commit dc8836a

Please sign in to comment.