diff --git a/src/io/weights_loader.ts b/src/io/weights_loader.ts index 46beae003a..ea70de03d6 100644 --- a/src/io/weights_loader.ts +++ b/src/io/weights_loader.ts @@ -15,10 +15,10 @@ * ============================================================================= */ -import {tensor} from '../ops/ops'; import {NamedTensorMap} from '../tensor_types'; -import {TypedArray} from '../types'; import * as util from '../util'; + +import {decodeWeights} from './io_utils'; import {DTYPE_VALUE_SIZE_MAP, WeightsManifestConfig, WeightsManifestEntry} from './types'; /** @@ -158,61 +158,15 @@ export async function loadWeights( } const weightsEntries = groupWeightsToFetch[i]; - weightsEntries.forEach(weightsEntry => { const byteBuffer = groupBuffer.slice( weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes); - - let typedArray: TypedArray; - - const dtype = weightsEntry.manifestEntry.dtype; - - if ('quantization' in weightsEntry.manifestEntry) { - const quantization = weightsEntry.manifestEntry.quantization; - if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') { - throw new Error( - `Weight ${weightsEntry.manifestEntry.name} has unknown ` + - `quantization dtype ${quantization.dtype}.`); - } - const quantizedArray = (quantization.dtype === 'uint8') ? - new Uint8Array(byteBuffer) : - new Uint16Array(byteBuffer); - if (dtype === 'float32') { - typedArray = Float32Array.from( - quantizedArray, v => v * quantization.scale + quantization.min); - } else if (dtype === 'int32') { - typedArray = Int32Array.from( - quantizedArray, - v => Math.round(v * quantization.scale + quantization.min)); - } else { - throw new Error( - `Weight ${weightsEntry.manifestEntry.name} has a dtype not ` + - `supported by quantization: ${dtype}`); - } - } else { - if (dtype === 'float32') { - typedArray = new Float32Array(byteBuffer); - } else if (dtype === 'int32') { - typedArray = new Int32Array(byteBuffer); - } else if (dtype === 'bool') { - typedArray = new Uint8Array(byteBuffer); - } else { - throw new Error( - `Weight ${weightsEntry.manifestEntry.name} has unknown dtype ` + - `${dtype}.`); - } - } - - const weightName = weightsEntry.manifestEntry.name; - if (weightsTensorMap[weightName] != null) { - throw new Error( - `Duplicate weight with name ${weightName}. ` + - `Please make sure weights names are unique in the manifest JSON.`); + const nameToTensorMap = + decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); + for (const name in nameToTensorMap) { + weightsTensorMap[name] = nameToTensorMap[name]; } - weightsTensorMap[weightName] = tensor( - typedArray, weightsEntry.manifestEntry.shape, - weightsEntry.manifestEntry.dtype); }); bufferIndexOffset += numBuffers;