Skip to content

Commit

Permalink
Move Variable to ndarray.ts to avoid circular dep problem. (tenso…
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmilkov authored Jan 4, 2018
1 parent db4d80e commit 1ff2ef3
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 85 deletions.
3 changes: 1 addition & 2 deletions src/math/math.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ import * as broadcast_util from './broadcast_util';
import * as concat_util from './concat_util';
import * as conv_util from './conv_util';
// tslint:disable-next-line:max-line-length
import {Array1D, Array2D, Array3D, Array4D, DataType, DataTypeMap, NDArray, Rank, RankMap, Scalar} from './ndarray';
import {Array1D, Array2D, Array3D, Array4D, DataType, DataTypeMap, NDArray, Rank, RankMap, Scalar, Variable} from './ndarray';
import * as slice_util from './slice_util';
import {SumTypes} from './types';
import {Variable} from './variable';

export interface LSTMCell {
(data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D];
Expand Down
61 changes: 61 additions & 0 deletions src/math/ndarray.ts
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,67 @@ export class Array4D<D extends DataType = DataType> extends NDArray<D, '4'> {
}
}

export class Variable<D extends DataType = DataType, R extends Rank = Rank>
extends NDArray<D, R> {
private static nextVarId = 0;
name: string;

/**
* Private constructor since we can not add logic before calling super().
* Instead, we expose static `Variable.variable` method below, which will be
* added to global namespace.
*/
private constructor(
initialValue: NDArray<D, R>, public trainable = true, name?: string) {
super(
initialValue.shape, initialValue.dtype, null /* values */,
initialValue.dataId);
this.name = name;
if (this.name == null) {
this.name = Variable.nextVarId.toString();
Variable.nextVarId++;
}
ENV.math.registerVariable(this);
}

/**
* Creates a new variable with the provided initial value.
*
* @param initialValue An ndarray.
* @param trainable If true, optimizers are allowed to update it.
* @param name Name of the variable. Defaults to a unique id.
* @param dtype If set, initialValue will be converted to the given type.
*/
static variable<D extends DataType, R extends Rank>(
initialValue: NDArray<D, R>, trainable = true, name?: string,
dtype?: D): Variable<D, R> {
if (dtype != null && dtype !== initialValue.dtype) {
initialValue = initialValue.asType(dtype);
}
return new Variable(initialValue, trainable, name);
}

/** Assign a new array to this variable. The old array will be disposed. */
assign(newValue: NDArray<D, R>): void {
if (newValue.dtype !== this.dtype) {
throw new Error(
`dtype of the new value (${newValue.dtype}) and ` +
`previous value (${this.dtype}) must match`);
}
if (!util.arraysEqual(newValue.shape, this.shape)) {
throw new Error(
`shape of the new value (${newValue.shape}) and ` +
`previous value (${this.shape}) must match`);
}
this.math.disposeData(this.dataId);
this.dataId = newValue.dataId;
ENV.math.register(this);
}
}

const variable = Variable.variable;
export {variable};

function copyTypedArray<D extends DataType>(
array: DataTypeMap[D]|number[]|boolean[], dtype: D): DataTypeMap[D] {
if (dtype == null || dtype === 'float32') {
Expand Down
81 changes: 0 additions & 81 deletions src/math/variable.ts

This file was deleted.

4 changes: 2 additions & 2 deletions src/math/variable_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import * as test_util from '../test_util';
import {MathTests} from '../test_util';
import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar} from './ndarray';
import {variable, Variable} from './variable';
// tslint:disable-next-line:max-line-length
import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar, variable, Variable} from './ndarray';

const tests: MathTests = it => {
it('simple assign', math => {
Expand Down

0 comments on commit 1ff2ef3

Please sign in to comment.