Add PReLU math backend function and differentiable graph operation. (t…
wang2bo2 authored and Nikhil Thorat committed Dec 20, 2017
1 parent 673afff commit ba7d597
Showing 13 changed files with 335 additions and 9 deletions.
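
For reference, PReLU (parametric ReLU, He et al. 2015) generalizes leaky ReLU by making the negative slope a learnable per-element tensor alpha instead of a fixed scalar. Element-wise, the forward function and the derivative implemented by the kernels in this commit are:

f(x_i) = \begin{cases} x_i, & x_i > 0 \\ \alpha_i x_i, & x_i \le 0 \end{cases}
\qquad
\frac{df}{dx_i} = \begin{cases} 1, & x_i > 0 \\ \alpha_i, & x_i < 0 \\ 0, & x_i = 0 \end{cases}

(Both the CPU and WebGL derivative kernels below return 0 at x_i = 0.)
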
33 changes: 31 additions & 2 deletions src/graph/graph.ts
@@ -276,6 +276,16 @@ export class Graph {
return this.addNodeAndReturnOutput(new LeakyReLUNode(this, x, alpha));
}

/**
* Computes PReLU of x element-wise.
* @param x The input tensor to the PReLU.
* @param alpha Negative slope coefficient tensor.
* @return The tensor representing the PReLU operation.
*/
prelu(x: Tensor, alpha: Tensor): Tensor {
return this.addNodeAndReturnOutput(new PReLUNode(this, x, alpha));
}

/**
* Computes Elu of x element-wise.
* @param x the input tensor to the Elu.
@@ -531,8 +541,8 @@ export class AddNode extends Node {
constructor(graph: Graph, private t1: Tensor, private t2: Tensor) {
super(
graph, 'Add', {t1, t2},
new Tensor(util.sizeFromShape(t1.shape) === 1
? t2.shape
: (t1.shape.length < t2.shape.length ? t2.shape : t1.shape)));
}

@@ -790,6 +800,25 @@ export class LeakyReLUNode extends Node {
validate() {}
}

/**
* PReLUNode represents a PReLU operation in the graph.
* @hidden
*/
export class PReLUNode extends Node {
static readonly X = 'x';
static readonly ALPHA = 'alpha';

constructor(graph: Graph, private x: Tensor, private alpha: Tensor) {
super(graph, 'PReLU', {x, alpha}, new Tensor(x.shape));
}

validate() {
util.assert(util.arraysEqual(this.x.shape, this.alpha.shape),
'Error adding pRelu op: the ' +
`shapes x: ${this.x.shape} and alpha: ${this.alpha.shape} must match.`);
}
}

export class EluNode extends Node {
static readonly X = 'x';
constructor(graph: Graph, x: Tensor) {
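
A minimal usage sketch of the new graph method (it mirrors the validation tests in graph_test.ts just below; x and alpha must have identical shapes or PReLUNode.validate throws):

const g = new Graph();
// Both input tensors share the shape [5, 4]; the output tensor does too.
const out = g.prelu(new Tensor([5, 4]), new Tensor([5, 4]));
console.log(out.shape); // [5, 4]
// Mismatched shapes fail validation:
// g.prelu(new Tensor([5, 4]), new Tensor([1, 2, 3])) throws.
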
19 changes: 19 additions & 0 deletions src/graph/graph_test.ts
@@ -518,6 +518,25 @@ describe('leakyRelu validation', () => {
});
});

describe('pRelu validation', () => {
let g: Graph;

beforeEach(() => {
g = new Graph();
});

it('Different shapes throws', () => {
expect(() => g.prelu(new Tensor([5, 4]), new Tensor([1, 2, 3])))
.toThrowError();
});

it('Same size does not throw', () => {
expect(g.prelu(new Tensor([5, 4]), new Tensor([5, 4])).shape).toEqual([
5, 4
]);
});
});

describe('elu validation', () => {
let g: Graph;

7 changes: 5 additions & 2 deletions src/graph/operation_emitter.ts
@@ -16,7 +16,7 @@
*/

// tslint:disable-next-line:max-line-length
import {AddNode, ArgMaxEqualsNode, ArgMaxNode, Concat3DNode, Convolution2DNode, DivideNode, ExpNode, EluNode, FusedLinearCombinationNode, LogNode, MatMulNode, MaxPoolNode, MeanSquaredCostNode, MultiplyNode, Node, ReduceSumNode, ReLUNode, PReLUNode, LeakyReLUNode, ReshapeNode, SigmoidNode, SoftmaxCrossEntropyCostNode, SoftmaxNode, SquareNode, SubtractNode, TanHNode} from './graph';
import * as graph_util from './graph_util';
import {Add} from './ops/add';
import {ArgMax} from './ops/argmax';
@@ -25,7 +25,7 @@ import {Concat3D} from './ops/concat3d';
import {Convolution2D} from './ops/convolution';
import {Divide} from './ops/divide';
// tslint:disable-next-line:max-line-length
import {ReLU, Sigmoid, Square, TanH, LeakyReLU, PReLU, Elu} from './ops/element_wise_activation';
import {MeanSquaredCost} from './ops/element_wise_cost';
import {Exp} from './ops/exp';
import {LinearCombination} from './ops/linear_combination';
@@ -72,6 +72,9 @@ function emitOpFromNode(node: Node): Operation[] {
} else if (node instanceof LeakyReLUNode) {
return [new LeakyReLU(node.inputs[LeakyReLUNode.X],
node.output, node.alpha)];
} else if (node instanceof PReLUNode) {
return [new PReLU(node.inputs[PReLUNode.X], node.inputs[PReLUNode.ALPHA],
node.output)];
} else if (node instanceof EluNode) {
return [new Elu(node.inputs[EluNode.X], node.output)];
} else if (node instanceof TanHNode) {
33 changes: 33 additions & 0 deletions src/graph/ops/element_wise_activation.ts
@@ -115,3 +115,36 @@ export class Square extends ElementWiseActivation {
super(xTensor, yTensor, new EluFunc());
}
}

/**
* @hidden
*/
export class PReLU extends Operation {
constructor(
private xTensor: Tensor, private alphaTensor: Tensor,
private yTensor: Tensor) {
super();
}

feedForward(math: NDArrayMath, inferenceArrays: TensorArrayMap) {
const x = inferenceArrays.get(this.xTensor);
const alpha = inferenceArrays.get(this.alphaTensor);

math.scope((keep) => {
inferenceArrays.set(this.yTensor, keep(math.prelu(x, alpha)));
});
}

backProp(
math: NDArrayMath, inferenceArrays: TensorArrayMap,
gradientArrays: SummedTensorArrayMap) {
const x = inferenceArrays.get(this.xTensor);
const alpha = inferenceArrays.get(this.alphaTensor);
const dy = gradientArrays.get(this.yTensor);

math.scope(() => {
const dydx = math.preluDer(x, alpha);
gradientArrays.add(this.xTensor, math.elementWiseMul(dy, dydx));
});
}
}
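
Note that backProp above only accumulates a gradient for xTensor: alpha receives no gradient through this graph op, so alpha is effectively frozen during graph training, even though the tape types in src/math/backends/types/prelu.ts later in this diff declare a gradient slot for it. For illustration only, a hedged sketch of what an alpha gradient could look like (df/dalpha_i is x_i where x_i < 0, and 0 elsewhere); math.neg and math.step are assumed to behave as elsewhere in this library:

math.scope(() => {
  // step(-x) is 1 where -x > 0, i.e. 1 where x < 0, else 0.
  const negMask = math.step(math.neg(x));
  // Gradient w.r.t. alpha: dy * x, restricted to the negative side of x.
  const dAlpha = math.elementWiseMul(dy, math.elementWiseMul(x, negMask));
  gradientArrays.add(this.alphaTensor, dAlpha);  // hypothetical, not in this commit
});
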
38 changes: 35 additions & 3 deletions src/graph/ops/element_wise_activation_test.ts
@@ -21,7 +21,8 @@ import {Tensor} from '../graph';
import {SummedTensorArrayMap, TensorArrayMap} from '../tensor_array_map';

// tslint:disable-next-line:max-line-length
import {Elu, LeakyReLU, PReLU, ReLU, Sigmoid, Square, TanH} from './element_wise_activation';
import {expectArraysClose} from '../../test_util';

describe('Element wise activation', () => {
let math: NDArrayMathCPU;
@@ -78,7 +78,8 @@ describe('Element wise activation', () => {
op.feedForward(math, activations);

const y = activations.get(yTensor);
expectArraysClose(y.dataSync(),
new Float32Array([3, 0, -0.2, 2, 9, -1.0]));

// Backprop.
const dy = Array2D.new([2, 3], [1, 2, 3, 4, 5, 6]);
@@ -88,7 +90,37 @@

const dx = gradients.get(xTensor);

expectArraysClose(dx.dataSync(),
new Float32Array([1, 0, 0.6, 4, 5, 1.2]));
});

it('PReLU', () => {
const x = Array2D.new([2, 3], [3, 0, -1, 2, -9, -5]);
const alpha = Array2D.new([2, 3], [0.15, 0.15, 0.12, 0.3, 0.05, 0.01]);

const alphaTensor = new Tensor(x.shape);
xTensor = new Tensor(x.shape);
yTensor = new Tensor(x.shape);
activations.set(xTensor, x);
activations.set(alphaTensor, alpha);

const op = new PReLU(xTensor, alphaTensor, yTensor);
op.feedForward(math, activations);

const y = activations.get(yTensor);
expectArraysClose(y.dataSync(),
new Float32Array([3, 0, -0.12, 2, -0.45, -0.05]));

// Backprop.
const dy = Array2D.new([2, 3], [1, 2, 3, 4, 5, 6]);
gradients.add(yTensor, dy);

op.backProp(math, activations, gradients);

const dx = gradients.get(xTensor);

expectArraysClose(dx.dataSync(),
new Float32Array([1, 0, 0.36, 4, 0.25, 0.06]));
});

it('TanH', () => {
2 changes: 2 additions & 0 deletions src/math/backends/backend.ts
@@ -100,6 +100,8 @@ export interface MathBackend extends NDArrayStorage {
eluDer<T extends NDArray>(x: T): T;
selu<T extends NDArray>(x: T): T;
leakyRelu<T extends NDArray>(x: T, alpha: number): T;
prelu<T extends NDArray>(x: T, alpha: T): T;
preluDer<T extends NDArray>(x: T, alpha: T): T;

clip<T extends NDArray>(x: T, min: number, max: number): T;

32 changes: 32 additions & 0 deletions src/math/backends/backend_cpu.ts
@@ -652,6 +652,38 @@ export class MathBackendCPU implements MathBackend {
return NDArray.make(x.shape, {values: resultValues}) as T;
}

prelu<T extends NDArray>(x: T, alpha: T) {
const resultValues = new Float32Array(x.size);
const values = x.dataSync();
const alphas = alpha.dataSync();
for (let i = 0; i < values.length; i++) {
const v = values[i];
if (v >= 0) {
resultValues[i] = v;
} else {
resultValues[i] = alphas[i] * v;
}
}
return NDArray.make(x.shape, {values: resultValues}) as T;
}

preluDer<T extends NDArray>(x: T, alpha: T) {
const resultValues = new Float32Array(x.size);
const values = x.dataSync();
const alphas = alpha.dataSync();
for (let i = 0; i < values.length; i++) {
const v = values[i];
if (v > 0) {
resultValues[i] = 1;
} else if (v < 0) {
resultValues[i] = alphas[i];
} else {
resultValues[i] = v;
}
}
return NDArray.make(x.shape, {values: resultValues}) as T;
}

clip<T extends NDArray>(x: T, min: number, max: number): T {
const resultValues = new Float32Array(x.size);
const values = x.getValues();
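
A quick numeric spot-check of the two CPU kernels, reusing values from the unit test earlier in this diff (a sketch: math.prelu and math.preluDer are the NDArrayMath wrappers the graph op calls into, presumably added in one of the changed files not expanded on this page):

const math = new NDArrayMathCPU();
const x = Array1D.new([3, -1, -9]);
const alpha = Array1D.new([0.15, 0.12, 0.05]);
math.prelu(x, alpha).dataSync();    // Float32Array [3, -0.12, -0.45]
math.preluDer(x, alpha).dataSync(); // Float32Array [1, 0.12, 0.05]
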
11 changes: 11 additions & 0 deletions src/math/backends/backend_webgl.ts
@@ -490,6 +490,17 @@ export class MathBackendWebGL implements MathBackend {
return this.compileAndRun(program, [x]) as T;
}

prelu<T extends NDArray>(a: T, b: T): T {
const program = new BinaryOpProgram(binaryop_gpu.PRELU, a.shape, b.shape);
return this.compileAndRun(program, [a, b]) as T;
}

preluDer<T extends NDArray>(a: T, b: T): T {
const program = new BinaryOpProgram(binaryop_gpu.PRELU_DER,
a.shape, b.shape);
return this.compileAndRun(program, [a, b]) as T;
}

clip<T extends NDArray>(x: T, min: number, max: number): T {
const program = new ClipProgram(x.shape, min, max);
return this.compileAndRun(program, [x]) as T;
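
Both WebGL kernels reuse the existing BinaryOpProgram machinery rather than dedicated shaders: the PRELU and PRELU_DER snippets defined in binaryop_gpu.ts at the end of this diff become the per-element body of the generated fragment shader, with a and b bound to matching elements of inputs A (here x) and B (here alpha). The branching agrees with the CPU kernels, including a derivative of a (that is, 0.0) when a == 0.0.
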
9 changes: 9 additions & 0 deletions src/math/backends/kernel_registry.ts
@@ -29,6 +29,7 @@ import {SumInputConfig, SumNode} from './types/sum';
import {TopKIndicesInputConfig, TopKIndicesNode, TopKValuesInputConfig, TopKValuesNode} from './types/topk';
// tslint:disable-next-line:max-line-length
import {ClipInputConfig, ClipNode, LeakyReluInputConfig, LeakyReluNode, StepInputConfig, StepNode, TileInputConfig, TileNode, TransposeInputConfig, TransposeNode, UnaryInputConfig, UnaryNode} from './types/unary';
import {PReLUNode, PReLUInputConfig} from './types/prelu';

const KERNEL_METHODS: {
[kernel in keyof KernelConfigRegistry]: (
@@ -142,6 +143,12 @@
LeakyRelu: (backend: MathBackend, config: LeakyReluInputConfig<NDArray>) => {
return backend.leakyRelu(config.inputs.x, config.args.alpha);
},
PReLU: (backend: MathBackend, config: PReLUInputConfig<NDArray>) => {
return backend.prelu(config.inputs.x, config.inputs.alpha);
},
PReLUDer: (backend: MathBackend, config: PReLUInputConfig<NDArray>) => {
return backend.preluDer(config.inputs.x, config.inputs.alpha);
},
Elu: (backend: MathBackend, config: UnaryInputConfig<NDArray>) => {
return backend.elu(config.inputs.x);
},
@@ -296,6 +303,8 @@ export interface KernelConfigRegistry {
Square: UnaryNode<NDArray>;
Relu: UnaryNode<NDArray>;
LeakyRelu: LeakyReluNode<NDArray>;
PReLU: PReLUNode<NDArray>;
PReLUDer: PReLUNode<NDArray>;
Elu: UnaryNode<NDArray>;
EluDer: UnaryNode<NDArray>;
Selu: UnaryNode<NDArray>;
43 changes: 43 additions & 0 deletions src/math/backends/types/prelu.ts
@@ -0,0 +1,43 @@
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {NDArray} from '../../ndarray';
// tslint:disable-next-line:max-line-length
import {KernelInputConfig, KernelNode, TapeNodeInputArrays, TapeNodeInputGradientArrays} from '../tape_types';

// PReLU
export interface PReLUNode<T extends NDArray> extends KernelNode {
inputAndArgs: PReLUInputConfig<T>;
output: T;
gradient: (dy: T, y: T) => PReLUGradientInputArrays<T>;
}

export interface PReLUInputConfig<T extends NDArray> extends KernelInputConfig {
inputs: PReLUInputArrays<T>;
}

export interface PReLUInputArrays<T extends NDArray>
extends TapeNodeInputArrays {
x: T;
alpha: T;
}

export interface PReLUGradientInputArrays<T extends NDArray>
extends TapeNodeInputGradientArrays {
x: () => T;
alpha: () => T;
}
6 changes: 6 additions & 0 deletions src/math/backends/webgl/binaryop_gpu.ts
@@ -31,6 +31,12 @@ export const EQUAL = `
if (isNaN(b)) return b;
return float(a == b);
`;
export const PRELU = `
return (a >= 0.0) ? a : b * a;
`;
export const PRELU_DER = `
return (a > 0.0) ? 1.0 : ((a < 0.0) ? b : a);
`;

export class BinaryOpProgram implements GPGPUProgram {
variableNames = ['A', 'B'];
Expand Down
(2 more changed files not shown.)
