Skip to content

Commit

Permalink
[webgpu] Add cast, sigmoid, floorDiv. (tensorflow#1741)
Browse files Browse the repository at this point in the history
FEATURE
  • Loading branch information
annxingyuan authored May 7, 2019
1 parent dc8836a commit e77863d
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 6 deletions.
16 changes: 15 additions & 1 deletion src/backends/webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import './flags_webgpu';

import {DataMover, DataType, ENV, KernelBackend, Rank, ShapeMap, Tensor, Tensor2D, Tensor3D, Tensor4D, util} from '@tensorflow/tfjs-core';
import * as backend_util from '@tensorflow/tfjs-core/dist/backends/backend_util';
import {computeOutShape} from '@tensorflow/tfjs-core/dist/ops/concat_util';
import {Conv2DInfo} from '@tensorflow/tfjs-core/dist/ops/conv_util';
import {upcastType} from '@tensorflow/tfjs-core/dist/types';
Expand Down Expand Up @@ -345,8 +346,17 @@ export class WebGPUBackend extends KernelBackend {
return this.binaryOp(a, b, binary_op.MUL);
}

floorDiv(a: Tensor, b: Tensor): Tensor {
return this.binaryOp(a, b, binary_op.INT_DIV);
}

sigmoid<T extends Tensor>(x: T): T {
const program = new UnaryOpProgram(x.shape, unary_op.SIGMOID);
return this.compileAndRun(program, [x]) as T;
}

relu<T extends Tensor>(x: T): T {
const program = new UnaryOpProgram(unary_op.RELU, x.shape);
const program = new UnaryOpProgram(x.shape, unary_op.RELU);
return this.compileAndRun(program, [x]) as T;
}

Expand All @@ -366,6 +376,10 @@ export class WebGPUBackend extends KernelBackend {
return Tensor.make(shape, {dataId: x.dataId}, x.dtype);
}

cast<T extends Tensor>(x: T, dtype: DataType): T {
return backend_util.castTensor(x, dtype, this);
}

transpose<T extends Tensor>(x: T, perm: number[]): T {
const program = new TransposeProgram(x.shape, perm);
return this.compileAndRun(program, [x]);
Expand Down
21 changes: 21 additions & 0 deletions src/backends/webgpu/src/binaryop_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,25 @@ describe('Binary ops', () => {
tf.test_util.expectArraysClose(
cData, new Float32Array([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22]));
});

it('floor division', async () => {
const a = tf.tensor1d([-6, -6, -5, -4, -3, -3, 3, 3, 2]);
const c = tf.tensor1d([-2, 2, 3, 2, -3, 3, 2, 3, 2]);

const r = tf.floorDiv(a, c);

const rData = await r.data();
tf.test_util.expectArraysClose(
rData, new Float32Array([3, -3, -2, -2, 1, -1, 1, 1, 1]));
});

it('floor division broadcasts', async () => {
const a = tf.tensor1d([-5, -4, 3, 2]);
const c = tf.scalar(2);

const r = tf.floorDiv(a, c);

const rData = await r.data();
tf.test_util.expectArraysClose(rData, new Float32Array([-3, -2, 1, 1]));
});
});
7 changes: 7 additions & 0 deletions src/backends/webgpu/src/kernels/binary_op_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ import {WebGPUProgram} from './webgpu_program';
export const MUL = 'return a * b;';
export const ADD = 'return a + b;';

export const INT_DIV = `
float s = sign(a) * sign(b);
int ia = int(round(a));
int ib = int(round(b));
return float(idiv(ia, ib, s));
`;

export class BinaryOpProgram implements WebGPUProgram {
outputShape: number[];
userCode: string;
Expand Down
4 changes: 3 additions & 1 deletion src/backends/webgpu/src/kernels/unary_op_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ import {WebGPUProgram} from './webgpu_program';

export const RELU = 'return max(a, 0.0);';

export const SIGMOID = `return 1.0 / (1.0 + exp(-1.0 * a));`;

export class UnaryOpProgram implements WebGPUProgram {
outputShape: number[];
userCode: string;
dispatchLayout: {x: number[]};
dispatch: [number, number, number];
variableNames = ['A'];

constructor(op: string, outputShape: number[]) {
constructor(outputShape: number[], op: string) {
this.outputShape = outputShape;
this.dispatchLayout = {x: this.outputShape.map((d, i) => i)};
this.dispatch = computeDispatch(this.dispatchLayout, this.outputShape);
Expand Down
27 changes: 23 additions & 4 deletions src/backends/webgpu/src/shader_preprocessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ export function makeShader(

const SHADER_PREFIX = `
#version 450
int idiv(int a, int b, float sign) {
int res = a / b;
int mod = a % b;
if (sign < 0. && mod != 0) {
res -= 1;
}
return res;
}
`;

const SAMPLING_SNIPPETS = `
Expand Down Expand Up @@ -185,6 +194,14 @@ function getSamplerFromInInfo(inInfo: InputInfo): string {
const dims = ['d0', 'd1', 'd2', 'd3'].slice(0, rank);
const inputs = dims.map(d => `int ${d}`).join(', ');

if (rank < 1) {
return `
float ${funcName}() {
return ${texName}[0];
}
`;
}

return `
float ${funcName}(${inputs}) {
return ${texName}[getFlatIndex(${type}(${dims.join(',')}),
Expand All @@ -209,9 +226,11 @@ function getSamplerAtOutputCoords(

let coordsSnippet = '';

if (inRank > 0) {
if (inRank === 0) {
coordsSnippet = 'coords = 0;';
} else {
if (outRank < 2 && broadcastDims.length >= 1) {
coordsSnippet = 'coords = 0.;';
coordsSnippet = 'coords = 0;';
} else {
coordsSnippet =
broadcastDims.map(d => `coords[${d + rankDiff}] = 0;`).join('\n');
Expand All @@ -222,13 +241,13 @@ function getSamplerAtOutputCoords(
if (outRank < 2 && inRank > 0) {
unpackedCoordsSnippet = 'coords';
} else {
if (inRank > 1) {
if (outRank > 1) {
const coordsType = getCoordsDataType(inRank);
const coordsValues =
inInfo.shape.map((s, i) => `coords[${i + rankDiff}]`).join(', ');
unpackedCoordsSnippet = `${coordsType}(${coordsValues})`;
} else {
unpackedCoordsSnippet = `coords[${rankDiff}]`;
unpackedCoordsSnippet = 'coords';
}
}

Expand Down
9 changes: 9 additions & 0 deletions src/backends/webgpu/src/unaryop_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,13 @@ describe('Unary ops', () => {

tf.test_util.expectArraysClose(cData, new Float32Array([1, 0, 5, 0]));
});

it('sigmoid', async () => {
const a = tf.tensor1d([0, -1, 2, -3]);
const result = tf.sigmoid(a);
const cData = await result.data();

tf.test_util.expectArraysClose(
cData, new Float32Array([0.5, 0.2689, 0.8808, 0.0474]));
});
});

0 comments on commit e77863d

Please sign in to comment.