Skip to content

Commit

Permalink
Finish gradient for tf.pow (tensorflow#954)
Browse files Browse the repository at this point in the history
  • Loading branch information
jgartman authored and dsmilkov committed Apr 17, 2018
1 parent 184d339 commit 3656e27
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 20 deletions.
95 changes: 89 additions & 6 deletions src/ops/arithmetic_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -566,34 +566,42 @@ describeWithFlags('pow', ALL_ENVS, () => {
const b = tf.scalar(2, 'int32');
const dy = tf.scalar(3);

const grad = tf.grad(a => tf.pow(a, b));
const da = grad(a, dy);
const grads = tf.grads((a, b) => tf.pow(a, b));
const [da, db] = grads([a, b], dy);

expect(da.shape).toEqual(a.shape);
expect(da.dtype).toEqual('float32');
expectArraysClose(da, [2 * 5 * 3]);

expect(db.shape).toEqual(b.shape);
expect(db.dtype).toEqual('float32');
expectArraysClose(db, [3 * Math.pow(5, 2) * Math.log(5)]);
});

it('gradients: Scalar ^ Scalar fractional exponent', () => {
  // y = a^b with a = 4, b = 1.5; checks the gradient w.r.t. BOTH inputs.
  const a = tf.scalar(4.0);
  const b = tf.scalar(1.5);
  const dy = tf.scalar(3.0);

  // NOTE(review): the scraped diff interleaved the pre-commit single-input
  // `tf.grad(a => ...)` lines with the post-commit `tf.grads` lines; only
  // the post-commit two-input form is kept here.
  const grads = tf.grads((a, b) => tf.pow(a, b));
  const [da, db] = grads([a, b], dy);

  // d/da a^b = b * a^(b-1); chain rule multiplies by dy.
  expect(da.shape).toEqual(a.shape);
  expect(da.dtype).toEqual('float32');
  expectArraysClose(da, [1.5 * Math.pow(4, 0.5) * 3]);

  // d/db a^b = a^b * ln(a).
  expect(db.shape).toEqual(b.shape);
  expect(db.dtype).toEqual('float32');
  expectArraysClose(db, [3.0 * Math.pow(4, 1.5) * Math.log(4.0)]);
});

it('gradients: Tensor ^ Tensor', () => {
const a = tf.tensor1d([-1, .5, 2]);
const b = tf.tensor1d([3, 2, -1], 'int32');
const dy = tf.tensor1d([1, 5, 10]);

const grad = tf.grad(a => tf.pow(a, b));
const da = grad(a, dy);
const grads = tf.grads((a, b) => tf.pow(a, b));
const [da, db] = grads([a, b], dy);

expect(da.shape).toEqual(a.shape);
expect(da.dtype).toEqual('float32');
Expand All @@ -604,6 +612,81 @@ describeWithFlags('pow', ALL_ENVS, () => {
-1 * Math.pow(2, -2) * 10
],
1e-1);

expect(db.shape).toEqual(b.shape);
expect(db.dtype).toEqual('float32');
expectArraysClose(db, [
NaN, 5 * Math.pow(.5, 2) * Math.log(.5),
10 * Math.pow(2, -1) * Math.log(2)
]);
});

it('gradient: scalar / Tensor1D', () => {
  // Broadcast a scalar base over a 1-D exponent and check both gradients.
  const base = tf.scalar(2);
  const exps = tf.tensor1d([3, 4, 5]);
  const dy = tf.tensor1d([6, 7, 8]);

  const gradFns = tf.grads((x, p) => tf.pow(x, p));
  const [dBase, dExps] = gradFns([base, exps], dy);

  expect(dBase.shape).toEqual(base.shape);
  expect(dBase.dtype).toEqual('float32');
  // dL/dx = sum_i dy_i * p_i * x^(p_i - 1), reduced over the broadcast axis.
  const expectedDBase =
      6 * 3 * Math.pow(2, 2) + 7 * 4 * Math.pow(2, 3) + 8 * 5 * Math.pow(2, 4);
  expectArraysClose(dBase, [expectedDBase]);

  expect(dExps.shape).toEqual(exps.shape);
  expect(dExps.dtype).toEqual('float32');
  // dL/dp_i = dy_i * x^(p_i) * ln(x).
  expectArraysClose(dExps, [
    6 * Math.pow(2, 3) * Math.log(2), 7 * Math.pow(2, 4) * Math.log(2),
    8 * Math.pow(2, 5) * Math.log(2)
  ]);
});

it('gradient: Tensor2D / scalar', () => {
  // 2x2 base raised to a scalar exponent; db reduces over every axis.
  const base = tf.tensor2d([[2, 3], [4, 5]], [2, 2]);
  const expo = tf.scalar(2);
  const dy = tf.tensor2d([[6, 7], [8, 9]], [2, 2]);

  const gradFns = tf.grads((x, p) => tf.pow(x, p));
  const [dBase, dExpo] = gradFns([base, expo], dy);

  expect(dBase.shape).toEqual(base.shape);
  expect(dBase.dtype).toEqual('float32');
  // dL/dx_ij = dy_ij * p * x_ij^(p - 1).
  expectArraysClose(dBase, [
    6 * 2 * Math.pow(2, 1), 7 * 2 * Math.pow(3, 1), 8 * 2 * Math.pow(4, 1),
    9 * 2 * Math.pow(5, 1)
  ]);

  expect(dExpo.shape).toEqual(expo.shape);
  expect(dExpo.dtype).toEqual('float32');
  // Scalar exponent: dy_ij * x_ij^p * ln(x_ij) summed over all elements.
  const expectedDExpo = 6 * Math.pow(2, 2) * Math.log(2) +
      7 * Math.pow(3, 2) * Math.log(3) + 8 * Math.pow(4, 2) * Math.log(4) +
      9 * Math.pow(5, 2) * Math.log(5);
  expectArraysClose(dExpo, [expectedDExpo]);
});

it('gradient: Tensor2D / Tensor2D w/ broadcast', () => {
  // Column base [2,1] broadcast against a full [2,2] exponent; da must be
  // reduced back to the base's shape.
  const base = tf.tensor2d([3, 4], [2, 1]);
  const exps = tf.tensor2d([[2, 3], [4, 5]], [2, 2]);
  const dy = tf.tensor2d([[6, 7], [8, 9]], [2, 2]);

  const gradFns = tf.grads((x, p) => tf.pow(x, p));
  const [dBase, dExps] = gradFns([base, exps], dy);

  expect(dBase.shape).toEqual(base.shape);
  expect(dBase.dtype).toEqual('float32');
  // Each row of the base receives the sum of its two broadcast columns:
  // dL/dx_i = sum_j dy_ij * p_ij * x_i^(p_ij - 1).
  const expectedDBase = [
    6 * 2 * Math.pow(3, 1) + 7 * 3 * Math.pow(3, 2),
    8 * 4 * Math.pow(4, 3) + 9 * 5 * Math.pow(4, 4)
  ];
  expectArraysClose(dBase, expectedDBase);

  expect(dExps.shape).toEqual(exps.shape);
  expect(dExps.dtype).toEqual('float32');
  // dL/dp_ij = dy_ij * x_i^(p_ij) * ln(x_i); no reduction needed here.
  expectArraysClose(dExps, [
    6 * Math.pow(3, 2) * Math.log(3), 7 * Math.pow(3, 3) * Math.log(3),
    8 * Math.pow(4, 4) * Math.log(4), 9 * Math.pow(4, 5) * Math.log(4)
  ]);
});
});

Expand Down
36 changes: 22 additions & 14 deletions src/ops/binary_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,26 +190,34 @@ export class BinaryOps {
@doc({heading: 'Operations', subheading: 'Arithmetic'})
@operation
static pow<T extends Tensor>(base: T, exp: Tensor): T {
broadcast_util.assertAndGetBroadcastShape(base.shape, exp.shape);
const outShape =
broadcast_util.assertAndGetBroadcastShape(base.shape, exp.shape);
base = base.cast(upcastType(base.dtype, exp.dtype));
exp = exp.cast(upcastType(base.dtype, exp.dtype));

const grad = (dy: Tensor) => {
if (!util.arraysEqual(base.shape, exp.shape) &&
!util.isScalarShape(exp.shape)) {
throw new Error(
`Gradient of pow not yet supported for broadcasted shapes.`);
}
const grad = (dy: Tensor, saved: Tensor[]) => {
const [y] = saved;
const derBase = () => {
const expFloat = exp.toFloat();
const dx =
expFloat.mul(base.toFloat().pow(expFloat.sub(scalar(1)))) as T;
return dy.mulStrict(dx) as T;
let res = dy.mul(exp.toFloat().mul(y.div(base)));
const reduceAxes =
broadcast_util.getReductionAxes(base.shape, outShape);
if (reduceAxes.length > 0) {
res = res.sum(reduceAxes);
}
return res.reshape(base.shape) as T;
};
const derExp = () => {
let res = dy.mul(y.mul(base.log()).toFloat());
const reduceAxes = broadcast_util.getReductionAxes(exp.shape, outShape);
if (reduceAxes.length > 0) {
res = res.sum(reduceAxes);
}
return res.reshape(exp.shape);
};
return {base: derBase};
return {base: derBase, exp: derExp};
};
return ENV.engine.runKernel(
backend => backend.pow(base, exp), {base}, grad) as T;
(backend, save) => save(backend.pow(base, exp)), {base, exp},
grad) as T;
}

/**
Expand Down

0 comments on commit 3656e27

Please sign in to comment.