Skip to content

Commit

Permalink
Refactor tf.io.loadWeights() to use tf.io.decodeWeights() (tensorflow…
Browse files Browse the repository at this point in the history
…#1236)

Reduces code duplication

DEV
  • Loading branch information
caisq authored Aug 20, 2018
1 parent 11e6c8a commit 64d458c
Showing 1 changed file with 6 additions and 52 deletions.
58 changes: 6 additions & 52 deletions src/io/weights_loader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
* =============================================================================
*/

import {tensor} from '../ops/ops';
import {NamedTensorMap} from '../tensor_types';
import {TypedArray} from '../types';
import * as util from '../util';

import {decodeWeights} from './io_utils';
import {DTYPE_VALUE_SIZE_MAP, WeightsManifestConfig, WeightsManifestEntry} from './types';

/**
Expand Down Expand Up @@ -158,61 +158,15 @@ export async function loadWeights(
}

const weightsEntries = groupWeightsToFetch[i];

weightsEntries.forEach(weightsEntry => {
const byteBuffer = groupBuffer.slice(
weightsEntry.groupOffset,
weightsEntry.groupOffset + weightsEntry.sizeBytes);

let typedArray: TypedArray;

const dtype = weightsEntry.manifestEntry.dtype;

if ('quantization' in weightsEntry.manifestEntry) {
const quantization = weightsEntry.manifestEntry.quantization;
if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has unknown ` +
`quantization dtype ${quantization.dtype}.`);
}
const quantizedArray = (quantization.dtype === 'uint8') ?
new Uint8Array(byteBuffer) :
new Uint16Array(byteBuffer);
if (dtype === 'float32') {
typedArray = Float32Array.from(
quantizedArray, v => v * quantization.scale + quantization.min);
} else if (dtype === 'int32') {
typedArray = Int32Array.from(
quantizedArray,
v => Math.round(v * quantization.scale + quantization.min));
} else {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has a dtype not ` +
`supported by quantization: ${dtype}`);
}
} else {
if (dtype === 'float32') {
typedArray = new Float32Array(byteBuffer);
} else if (dtype === 'int32') {
typedArray = new Int32Array(byteBuffer);
} else if (dtype === 'bool') {
typedArray = new Uint8Array(byteBuffer);
} else {
throw new Error(
`Weight ${weightsEntry.manifestEntry.name} has unknown dtype ` +
`${dtype}.`);
}
}

const weightName = weightsEntry.manifestEntry.name;
if (weightsTensorMap[weightName] != null) {
throw new Error(
`Duplicate weight with name ${weightName}. ` +
`Please make sure weights names are unique in the manifest JSON.`);
const nameToTensorMap =
decodeWeights(byteBuffer, [weightsEntry.manifestEntry]);
for (const name in nameToTensorMap) {
weightsTensorMap[name] = nameToTensorMap[name];
}
weightsTensorMap[weightName] = tensor(
typedArray, weightsEntry.manifestEntry.shape,
weightsEntry.manifestEntry.dtype);
});

bufferIndexOffset += numBuffers;
Expand Down

0 comments on commit 64d458c

Please sign in to comment.