Skip to content

Commit

Permalink
Add tf.addN(), allow bool weights and improve tslint (tensorflow#1190)
Browse files Browse the repository at this point in the history
- Add tf.addN() following [tf.add_n](https://www.tensorflow.org/api_docs/python/tf/add_n) in Python. This is needed when converting the `ssd_mobilenetv2_coco` model.
- Add ability to load weights of dtype `bool`
- Improve tslint rule so `imports {}` and `exports {}` are not subject to 80 width line length
- Also fix circular dep which was disabled on some files due to `//tslint:disable`

FEATURE
  • Loading branch information
dsmilkov authored Jul 26, 2018
1 parent d2530ad commit 8049594
Show file tree
Hide file tree
Showing 99 changed files with 390 additions and 311 deletions.
14 changes: 4 additions & 10 deletions integration_tests/benchmarks/math-benchmark-run-groups.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,12 @@
* =============================================================================
*/

// tslint:disable-next-line:max-line-length
import {BatchNormalization3DCPUBenchmark, BatchNormalization3DGPUBenchmark} from './batchnormalization3d_benchmark';
import {BenchmarkRun, BenchmarkRunGroup} from './benchmark';
// tslint:disable-next-line:max-line-length
import {ConvGPUBenchmark, ConvParams, DepthwiseConvParams, RegularConvParams} from './conv_benchmarks';
// tslint:disable-next-line:max-line-length
import {MatmulCPUBenchmark, MatmulGPUBenchmark} from './matmul_benchmarks';
// tslint:disable-next-line:max-line-length
import {PoolBenchmarkParams, PoolCPUBenchmark, PoolGPUBenchmark} from './pool_benchmarks';
// tslint:disable-next-line:max-line-length
import {ReductionOpsCPUBenchmark, ReductionOpsGPUBenchmark} from './reduction_ops_benchmark';
// tslint:disable-next-line:max-line-length
import {UnaryOpsCPUBenchmark, UnaryOpsGPUBenchmark} from './unary_ops_benchmark';

export function getRunGroups(): BenchmarkRunGroup[] {
Expand Down Expand Up @@ -105,10 +99,10 @@ export function getRunGroups(): BenchmarkRunGroup[] {
max: 1024,
stepToSizeTransformation: (step: number) => Math.max(1, step),
options: [
'log', 'exp', 'neg', 'ceil', 'floor', 'log1p', 'sqrt', 'square',
'abs', 'relu', 'elu', 'selu', 'leakyRelu', 'prelu', 'sigmoid',
'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'sinh', 'cosh',
'tanh', 'step'
'log', 'exp', 'neg', 'ceil', 'floor', 'log1p', 'sqrt',
'square', 'abs', 'relu', 'elu', 'selu', 'leakyRelu', 'prelu',
'sigmoid', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan',
'sinh', 'cosh', 'tanh', 'step'
],
selectedOption: 'log',
stepSize: 64,
Expand Down
2 changes: 1 addition & 1 deletion models/knn_image_classifier/knn_image_classifier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
* =============================================================================
*/
// tslint:disable-next-line:max-line-length

import * as dl from 'deeplearn';
import {Tensor1D, Tensor2D, Tensor3D} from 'deeplearn';
import {SqueezeNet} from 'deeplearn-squeezenet';
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"rollup-plugin-uglify": "~3.0.0",
"shelljs": "~0.7.8",
"ts-node": "~7.0.0",
"tslint": "~5.8.0",
"tslint": "~5.11.0",
"tslint-no-circular-imports": "~0.5.0",
"typescript": "2.9.2",
"yalc": "~1.0.0-pre.21"
Expand Down
2 changes: 1 addition & 1 deletion src/debug_mode_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import * as tf from './index';
import {describeWithFlags} from './jasmine_util';
import {convertToTensor} from './tensor_util';
import {convertToTensor} from './tensor_util_env';
import {ALL_ENVS, expectArraysClose} from './test_util';

describeWithFlags('debug on', ALL_ENVS, () => {
Expand Down
3 changes: 0 additions & 3 deletions src/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@

import {BackendTimingInfo, KernelBackend} from './kernels/backend';
import {Profiler} from './profiler';
// tslint:disable-next-line:max-line-length
import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode} from './tape';
// tslint:disable-next-line:max-line-length
import {DataId, Tensor, Tensor3D, Variable} from './tensor';
// tslint:disable-next-line:max-line-length
import {NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer, isTensorInList} from './tensor_util';
import {TypedArray} from './types';
Expand Down
1 change: 0 additions & 1 deletion src/engine_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import * as tf from './index';
import {describeWithFlags} from './jasmine_util';
// tslint:disable-next-line:max-line-length
import {ALL_ENVS, expectArraysClose, expectArraysEqual, expectNumbersClose, WEBGL_ENVS} from './test_util';

describeWithFlags('fromPixels + regular math op', WEBGL_ENVS, () => {
Expand Down
2 changes: 0 additions & 2 deletions src/environment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
*/

import * as device_util from './device_util';

import {Engine, MemoryInfo, ScopeFn, TimingInfo} from './engine';
// tslint:disable-next-line:max-line-length
import {Features, getFeaturesFromURL, getWebGLDisjointQueryTimerVersion, isChrome, isDownloadFloatTextureEnabled, isRenderToFloatTextureEnabled, isWebGLFenceEnabled, isWebGLVersionEnabled} from './environment_util';
import {KernelBackend} from './kernels/backend';
import {setTensorTracker, Tensor, TensorTracker} from './tensor';
Expand Down
9 changes: 8 additions & 1 deletion src/environment_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import * as device_util from './device_util';
import {ENV, Environment} from './environment';
import {Features} from './environment_util';
import {Features, getQueryParams} from './environment_util';
import {describeWithFlags} from './jasmine_util';
import {KernelBackend} from './kernels/backend';
import {MathBackendCPU} from './kernels/backend_cpu';
Expand Down Expand Up @@ -275,3 +275,10 @@ describe('Backend', () => {
ENV.removeBackend('custom');
});
});

describe('environment_util.getQueryParams', () => {
it('basic', () => {
expect(getQueryParams('?a=1&b=hi&f=animal'))
.toEqual({'a': '1', 'b': 'hi', 'f': 'animal'});
});
});
16 changes: 14 additions & 2 deletions src/environment_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
* =============================================================================
*/

import {getQueryParams} from './util';

export interface Features {
// Whether to enable debug mode.
'DEBUG'?: boolean;
Expand Down Expand Up @@ -278,3 +276,17 @@ function createFloatTextureAndBindToFramebuffer(
gl.framebufferTexture2D(
gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
}

export function getQueryParams(queryString: string): {[key: string]: string} {
const params = {};
queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => {
decodeParam(params, t[0], t[1]);
return t.join('=');
});
return params;
}

function decodeParam(
params: {[key: string]: string}, name: string, value?: string) {
params[decodeURIComponent(name)] = decodeURIComponent(value || '');
}
1 change: 0 additions & 1 deletion src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/

import {Environment} from './environment';
// tslint:disable-next-line:max-line-length
export {customGrad, grad, grads, valueAndGrad, valueAndGrads, variableGrads} from './gradients';

export const tidy = Environment.tidy;
Expand Down
1 change: 0 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ export {MomentumOptimizer} from './optimizers/momentum_optimizer';
export {Optimizer} from './optimizers/optimizer';
export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
export {SGDOptimizer} from './optimizers/sgd_optimizer';
// tslint:disable-next-line:max-line-length
export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer, variable, Variable} from './tensor';
export {NamedTensorMap} from './tensor_types';
export {DataType, Rank, ShapeMap} from './types';
Expand Down
4 changes: 0 additions & 4 deletions src/io/browser_files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,11 @@
* user-selected files in browser.
*/

// tslint:disable:max-line-length
import {ENV} from '../environment';

import {basename, concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';

// tslint:enable:max-line-length

const DEFAULT_FILE_NAME_PREFIX = 'model';
const DEFAULT_JSON_EXTENSION_NAME = '.json';
const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';
Expand Down
1 change: 0 additions & 1 deletion src/io/browser_files_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import * as tf from '../index';
import {describeWithFlags} from '../jasmine_util';
import {BROWSER_ENVS} from '../test_util';
// tslint:disable-next-line:max-line-length
import {browserDownloads, BrowserDownloads, browserDownloadsRouter} from './browser_files';
import {WeightsManifestConfig, WeightsManifestEntry} from './types';

Expand Down
7 changes: 1 addition & 6 deletions src/io/browser_http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,10 @@
*/

import {assert} from '../util';

// tslint:disable:max-line-length
import {concatenateArrayBuffers, getModelArtifactsInfoForJSON} from './io_utils';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeightsAsArrayBuffer} from './weights_loader';
// tslint:enable:max-line-length

export class BrowserHTTPRequest implements IOHandler {
protected readonly path: string;
Expand Down Expand Up @@ -173,7 +170,6 @@ export const httpRequestRouter: IORouter = (url: string) => {
IORouterRegistry.registerSaveRouter(httpRequestRouter);
IORouterRegistry.registerLoadRouter(httpRequestRouter);

// tslint:disable:max-line-length
/**
* Creates an IOHandler subtype that sends model artifacts to HTTP server.
*
Expand All @@ -184,7 +180,7 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
* - A JSON file consisting of `modelTopology` and `weightsManifest`.
* - A binary weights file consisting of the concatenated weight values.
* These files are in the same format as the one generated by
* [tensorflowjs_converter](https://js.tensorflow.org/tutorials/import-keras.html).
* [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html).
*
* The following code snippet exemplifies the client-side code that uses this
* function:
Expand Down Expand Up @@ -309,7 +305,6 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
* If `requestInit` has a `body`, an Error will be thrown.
* @returns An instance of `IOHandler`.
*/
// tslint:enable:max-line-length
export function browserHTTPRequest(
path: string, requestInit?: RequestInit): IOHandler {
return new BrowserHTTPRequest(path, requestInit);
Expand Down
2 changes: 0 additions & 2 deletions src/io/indexed_db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
* =============================================================================
*/

// tslint:disable:max-line-length
import {ENV} from '../environment';
import {getModelArtifactsInfoForJSON} from './io_utils';
import {ModelStoreManagerRegistry} from './model_management';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelStoreManager, SaveResult} from './types';
// tslint:enable:max-line-length

const DATABASE_NAME = 'tensorflowjs';
const DATABASE_VERSION = 1;
Expand Down
1 change: 0 additions & 1 deletion src/io/indexed_db_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import * as tf from '../index';
import {describeWithFlags} from '../jasmine_util';
import {BROWSER_ENVS, expectArrayBuffersEqual} from '../test_util';
// tslint:disable-next-line:max-line-length
import {browserIndexedDB, BrowserIndexedDB, BrowserIndexedDBManager, deleteDatabase, indexedDBRouter} from './indexed_db';

describeWithFlags('IndexedDB', BROWSER_ENVS, () => {
Expand Down
3 changes: 0 additions & 3 deletions src/io/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
* =============================================================================
*/

// tslint:disable:max-line-length
// Importing local_storage and indexed_db is necessary for the routers to be
// registered.
import './indexed_db';
Expand All @@ -36,8 +35,6 @@ const getLoadHandlers = IORouterRegistry.getLoadHandlers;

export {copyModel, listModels, moveModel, removeModel} from './model_management';

// tslint:enable:max-line-length

export {
browserFiles,
browserHTTPRequest,
Expand Down
2 changes: 0 additions & 2 deletions src/io/io_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {TypedArray} from '../types';
import {sizeFromShape} from '../util';

// tslint:disable-next-line:max-line-length
import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, WeightsManifestEntry} from './types';

/**
Expand Down
2 changes: 0 additions & 2 deletions src/io/io_utils_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import * as tf from '../index';
import {scalar, tensor1d, tensor2d} from '../ops/ops';
import {NamedTensorMap} from '../tensor_types';
import {expectArraysEqual} from '../test_util';

// tslint:disable-next-line:max-line-length
import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength} from './io_utils';

describe('concatenateTypedArrays', () => {
Expand Down
4 changes: 0 additions & 4 deletions src/io/local_storage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@
* =============================================================================
*/

// tslint:disable:max-line-length
import {ENV} from '../environment';
import {assert} from '../util';

import {arrayBufferToBase64String, base64StringToArrayBuffer, getModelArtifactsInfoForJSON} from './io_utils';
import {ModelStoreManagerRegistry} from './model_management';
import {IORouter, IORouterRegistry} from './router_registry';
import {IOHandler, ModelArtifacts, ModelArtifactsInfo, ModelStoreManager, SaveResult} from './types';

// tslint:enable:max-line-length

const PATH_SEPARATOR = '/';
const PATH_PREFIX = 'tensorflowjs_models';
const INFO_SUFFIX = 'info';
Expand Down
1 change: 0 additions & 1 deletion src/io/local_storage_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import * as tf from '../index';
import {describeWithFlags} from '../jasmine_util';
import {BROWSER_ENVS} from '../test_util';
import {arrayBufferToBase64String, base64StringToArrayBuffer} from './io_utils';
// tslint:disable-next-line:max-line-length
import {browserLocalStorage, BrowserLocalStorage, BrowserLocalStorageManager, localStorageRouter, purgeLocalStorageArtifacts} from './local_storage';

describeWithFlags('LocalStorage', BROWSER_ENVS, () => {
Expand Down
2 changes: 0 additions & 2 deletions src/io/passthrough.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
* IOHandlers that pass through the in-memory ModelArtifacts format.
*/

// tslint:disable:max-line-length
import {IOHandler, ModelArtifacts, SaveResult, WeightsManifestEntry} from './types';
// tslint:enable:max-line-length

class PassthroughLoader implements IOHandler {
constructor(
Expand Down
7 changes: 4 additions & 3 deletions src/io/weights_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
* =============================================================================
*/

// tslint:disable:max-line-length
import {tensor} from '../ops/ops';
import {NamedTensorMap} from '../tensor_types';
import {TypedArray} from '../types';
import * as util from '../util';
import {DTYPE_VALUE_SIZE_MAP, WeightsManifestConfig, WeightsManifestEntry} from './types';
// tslint:enable:max-line-length

/**
* Reads binary weights data from a number of URLs.
Expand Down Expand Up @@ -165,7 +164,7 @@ export async function loadWeights(
weightsEntry.groupOffset,
weightsEntry.groupOffset + weightsEntry.sizeBytes);

let typedArray: Float32Array|Int32Array;
let typedArray: TypedArray;

const dtype = weightsEntry.manifestEntry.dtype;

Expand Down Expand Up @@ -196,6 +195,8 @@ export async function loadWeights(
typedArray = new Float32Array(byteBuffer);
} else if (dtype === 'int32') {
typedArray = new Int32Array(byteBuffer);
} else if (dtype === 'bool') {
typedArray = new Uint8Array(byteBuffer);
} else {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has unknown dtype ` +
Expand Down
Loading

0 comments on commit 8049594

Please sign in to comment.