Skip to content

Commit

Permalink
Add support for loading quantized weights (tensorflow#965)
Browse files Browse the repository at this point in the history
FEATURE
PERF Quantizing weights reduces model size and improves model download time
  • Loading branch information
adarob authored and dsmilkov committed Apr 25, 2018
1 parent ce93412 commit 242f627
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 18 deletions.
59 changes: 49 additions & 10 deletions src/weights_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,20 @@ export interface WeightsManifestGroupConfig {
/** Describes one weight stored in a weights-manifest group file. */
export interface WeightsManifestEntry {
  // Unique name of the weight within the manifest.
  name: string;
  // Shape of the weight tensor.
  shape: number[];
  dtype: 'float32'|'int32';  // Dtype of the (unquantized) weights.
  quantization?: {
    // Information to dequantize the weights.
    scale: number,  // The scaling constant to multiply by.
    min: number,  // The (possibly nudged) minimum weight to add.
    dtype: 'uint16'|'uint8'  // The dtype of the quantized weights.
  };
}

const DTYPE_VALUE_SIZE_MAP: {[dtype: string]: number} = {
'float32': 4,
'int32': 4
'int32': 4,
'uint16': 2,
'uint8': 1
};

/**
Expand Down Expand Up @@ -68,7 +76,11 @@ export async function loadWeights(
manifest.forEach((manifestGroupConfig, groupIndex) => {
  let groupOffset = 0;
  manifestGroupConfig.weights.forEach(weightsEntry => {
    // Quantized entries are stored on disk with the (smaller) quantization
    // dtype, so the byte size must be computed from it rather than from the
    // logical (unquantized) dtype.
    const rawDtype = ('quantization' in weightsEntry) ?
        weightsEntry.quantization.dtype :
        weightsEntry.dtype;

    const weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] *
        util.sizeFromShape(weightsEntry.shape);

const enqueueWeightsForFetchingFn = () => {
Expand Down Expand Up @@ -161,14 +173,41 @@ export async function loadWeights(
weightsEntry.groupOffset + weightsEntry.sizeBytes);

let typedArray: Float32Array|Int32Array;
if (weightsEntry.manifestEntry.dtype === 'float32') {
typedArray = new Float32Array(byteBuffer);
} else if (weightsEntry.manifestEntry.dtype === 'int32') {
typedArray = new Int32Array(byteBuffer);

const dtype = weightsEntry.manifestEntry.dtype;

if ('quantization' in weightsEntry.manifestEntry) {
const quantization = weightsEntry.manifestEntry.quantization;
if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has unknown ` +
`quantization dtype ${quantization.dtype}.`);
}
const quantizedArray = (quantization.dtype === 'uint8') ?
new Uint8Array(byteBuffer) :
new Uint16Array(byteBuffer);
if (dtype === 'float32') {
typedArray = Float32Array.from(
quantizedArray, v => v * quantization.scale + quantization.min);
} else if (dtype === 'int32') {
typedArray = Int32Array.from(
quantizedArray,
v => Math.round(v * quantization.scale + quantization.min));
} else {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has a dtype not ` +
`supported by quantization: ${dtype}`);
}
} else {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has unknown dtype ` +
`${weightsEntry.manifestEntry.dtype}.`);
if (dtype === 'float32') {
typedArray = new Float32Array(byteBuffer);
} else if (dtype === 'int32') {
typedArray = new Int32Array(byteBuffer);
} else {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has unknown dtype ` +
`${dtype}.`);
}
}

const weightName = weightsEntry.manifestEntry.name;
Expand Down
129 changes: 121 additions & 8 deletions src/weights_loader_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@
* =============================================================================
*/
import * as tf from './index';
import {describeWithFlags} from './jasmine_util';
import {CPU_ENVS, expectArraysClose, expectArraysEqual} from './test_util';
import {WeightsManifestConfig} from './weights_loader';

describeWithFlags('loadWeights', CPU_ENVS, () => {
// Stubs out window.fetch so that each requested path resolves to the given
// buffer; tests therefore never touch the network. The scraped diff had left
// both the pre- and post-commit versions of this helper in place; this is the
// post-commit version (typed-array inputs added for the quantization tests).
const setupFakeWeightFiles = (fileBufferMap: {
  [filename: string]: Float32Array|Int32Array|ArrayBuffer|Uint8Array|
      Uint16Array
}) => {
  spyOn(window, 'fetch').and.callFake((path: string) => {
    return new Response(fileBufferMap[path]);
  });
};

it('1 group, 1 weight, 1 requested weight', done => {
setupFakeWeightFiles({'./weightfile0': new Float32Array([1, 2, 3])});
Expand Down Expand Up @@ -465,4 +466,116 @@ describeWithFlags('loadWeights', CPU_ENVS, () => {
.then(done)
.catch(done.fail);
});

// Shared body for the uint8/uint16 quantization tests: serves one fake file
// holding a quantized float32 weight and a quantized int32 weight, then
// verifies both come back dequantized as `v * scale + min`.
const quantizationTest =
    (quantizationDtype: 'uint8'|'uint16', done: DoneFn) => {
      // Pick the typed-array constructor matching the quantized storage.
      const QuantizedArray =
          quantizationDtype === 'uint8' ? Uint8Array : Uint16Array;
      setupFakeWeightFiles(
          {'./weightfile0': new QuantizedArray([0, 48, 255, 0, 48, 255])});

      const manifest: WeightsManifestConfig = [{
        paths: ['weightfile0'],
        weights: [
          {
            name: 'weight0',
            dtype: 'float32',
            shape: [3],
            quantization:
                {min: -1, scale: 0.1, dtype: quantizationDtype}
          },
          {
            name: 'weight1',
            dtype: 'int32',
            shape: [3],
            quantization:
                {min: -1, scale: 0.1, dtype: quantizationDtype}
          }
        ]
      }];

      const namesToFetch = ['weight0', 'weight1'];
      tf.loadWeights(manifest, './', namesToFetch)
          .then(weights => {
            // Both weights live in one group file, so one fetch suffices.
            expect((window.fetch as jasmine.Spy).calls.count()).toBe(1);
            expect(Object.keys(weights).length).toEqual(namesToFetch.length);

            const floatWeight = weights['weight0'];
            expect(floatWeight.dtype).toEqual('float32');
            expect(floatWeight.shape).toEqual([3]);
            expectArraysClose(floatWeight, [-1, 3.8, 24.5]);

            const intWeight = weights['weight1'];
            expect(intWeight.dtype).toEqual('int32');
            expect(intWeight.shape).toEqual([3]);
            expectArraysEqual(intWeight, [-1, 4, 25]);
          })
          .then(done)
          .catch(done.fail);
    };

// Run the shared quantization scenario once per supported quantized dtype.
it('quantized weights (uint8)', done => quantizationTest('uint8', done));

it('quantized weights (uint16)', done => quantizationTest('uint16', done));

it('2 groups, 1 quantized, 1 unquantized', done => {
  // Group 0 is uint8-quantized; group 1 is plain float32.
  setupFakeWeightFiles({
    './weightfile0': new Uint8Array([0, 48, 255, 0, 48, 255]),
    './weightfile1': new Float32Array([6, 7, 8, 9])
  });

  const manifest: WeightsManifestConfig = [
    {
      paths: ['weightfile0'],
      weights: [
        {
          name: 'weight0',
          dtype: 'float32',
          shape: [3],
          quantization: {min: -1, scale: 0.1, dtype: 'uint8'}
        },
        {
          name: 'weight1',
          dtype: 'int32',
          shape: [3],
          quantization: {min: -1, scale: 0.1, dtype: 'uint8'}
        }
      ]
    },
    {
      paths: ['weightfile1'],
      weights: [
        {name: 'weight2', dtype: 'float32', shape: [3, 1]},
        {name: 'weight3', dtype: 'float32', shape: []}
      ]
    }
  ];

  tf.loadWeights(manifest, './', ['weight0', 'weight2'])
      .then(weights => {
        // Both groups need to be fetched.
        expect((window.fetch as jasmine.Spy).calls.count()).toBe(2);
        expect(Object.keys(weights).length).toEqual(2);

        const quantizedWeight = weights['weight0'];
        expect(quantizedWeight.dtype).toEqual('float32');
        expect(quantizedWeight.shape).toEqual([3]);
        expectArraysClose(quantizedWeight, [-1, 3.8, 24.5]);

        const plainWeight = weights['weight2'];
        expect(plainWeight.dtype).toEqual('float32');
        expect(plainWeight.shape).toEqual([3, 1]);
        expectArraysClose(plainWeight, [6, 7, 8]);
      })
      .then(done)
      .catch(done.fail);
});
});

0 comments on commit 242f627

Please sign in to comment.