Skip to content

Commit

Permalink
Add variable() to chain API. (tensorflow#730)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikhil Thorat authored Feb 15, 2018
1 parent 8ea6dd8 commit 611838b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,11 @@ export class Tensor<R extends Rank = Rank> {
return ops.localResponseNormalization(
this, radius, bias, alpha, beta, normRegion);
}

variable(trainable = true, name?: string, dtype?: DataType): Variable<R> {
this.throwIfDisposed();
return Variable.variable(this, trainable, name, dtype);
}
}

/** @doclink Tensor */
Expand Down
16 changes: 15 additions & 1 deletion src/variable_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ describeWithFlags('variable', ALL_ENVS, () => {
expectArraysClose(v, [4, 5, 6]);
});

it('simple chain assign', () => {
const v = dl.tensor1d([1, 2, 3]).variable();
expectArraysClose(v, [1, 2, 3]);

v.assign(dl.tensor1d([4, 5, 6]));
expectArraysClose(v, [4, 5, 6]);
});

it('default names are unique', () => {
const v = variable(dl.tensor1d([1, 2, 3]));
expect(v.name).not.toBeNull();
Expand All @@ -50,13 +58,19 @@ describeWithFlags('variable', ALL_ENVS, () => {
.toThrowError();
});

it('math ops can take variables', () => {
it('ops can take variables', () => {
const value = dl.tensor1d([1, 2, 3]);
const v = variable(value);
const res = dl.sum(v);
expectArraysClose(res, [6]);
});

it('chained variables works', () => {
const v = dl.tensor1d([1, 2, 3]).variable();
const res = dl.sum(v);
expectArraysClose(res, [6]);
});

it('variables are not affected by tidy', () => {
let v: Variable<Rank.R1>;
expect(dl.memory().numTensors).toBe(0);
Expand Down

0 comments on commit 611838b

Please sign in to comment.