diff --git a/demos/benchmarks/pool_benchmarks.ts b/demos/benchmarks/pool_benchmarks.ts index d5808edf4e..844a225cf1 100644 --- a/demos/benchmarks/pool_benchmarks.ts +++ b/demos/benchmarks/pool_benchmarks.ts @@ -33,22 +33,19 @@ function getPoolingOp(option: string, math: NDArrayMath): ( strides: [number, number]|number, pad: 'valid'|'same'|number) => Array3D { switch (option) { case 'max': - return (x: Array3D, filterSize: [number, number] | number, - strides: [number, number] | number, - pad: 'valid' | 'same' | number) => { + return (x: Array3D, filterSize: [number, number]|number, + strides: [number, number]|number, pad: 'valid'|'same'|number) => { return math.maxPool(x, filterSize, strides, pad); }; case 'min': - return (x: Array3D, filterSize: [number, number] | number, - strides: [number, number] | number, - pad: 'valid' | 'same' | number) => { + return (x: Array3D, filterSize: [number, number]|number, + strides: [number, number]|number, pad: 'valid'|'same'|number) => { return math.minPool(x, filterSize, strides, pad); }; case 'avg': - return (x: Array3D, filterSize: [number, number] | number, - strides: [number, number] | number, - pad: 'valid' | 'same' | number) => { - return math.avgPool(x, filterSize, strides, pad); + return (x: Array3D, filterSize: [number, number]|number, + strides: [number, number]|number, pad: 'valid'|'same'|number) => { + return math.avgPool(x.asType('float32'), filterSize, strides, pad); }; default: throw new Error(`Not found such ops: ${option}`); diff --git a/demos/benchmarks/reduction_ops_benchmark.ts b/demos/benchmarks/reduction_ops_benchmark.ts index 291fb0ff1f..be479f5079 100644 --- a/demos/benchmarks/reduction_ops_benchmark.ts +++ b/demos/benchmarks/reduction_ops_benchmark.ts @@ -34,7 +34,7 @@ function getReductionOp(option: string, math: NDArrayMath): (input: NDArray) => case 'sum': return input => math.sum(input) as Scalar; case 'logSumExp': - return input => math.logSumExp(input); + return input => math.logSumExp(input) as Scalar; default: throw new Error(`Not found such ops: ${option}`); } diff --git a/package.json b/package.json index 5d1cb86d52..30f271e9a0 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "deeplearn", - "version": "0.4.1", + "version": "0.4.2", "description": "Hardware-accelerated JavaScript library for machine intelligence", "private": false, "main": "dist/index.js", diff --git a/src/math/math.ts b/src/math/math.ts index 5f98e3f229..fc84b7b20a 100644 --- a/src/math/math.ts +++ b/src/math/math.ts @@ -2426,11 +2426,10 @@ export class NDArrayMath implements NDArrayManager { * number. If none is provided, it will not round and error if the output * is of fractional size. */ - avgPool, - T2 extends NDArray<'float32', R>>( - x: T1, filterSize: [number, number]|number, + avgPool( + x: NDArray<'int32'|'float32', R>, filterSize: [number, number]|number, strides: [number, number]|number, pad: 'valid'|'same'|number, - dimRoundingMode?: 'floor'|'round'|'ceil'): T2 { + dimRoundingMode?: 'floor'|'round'|'ceil'): RankMap<'float32'>[R] { let x4D = x as NDArray as Array4D; let reshapedTo4D = false; if (x.rank === 3) { @@ -2458,10 +2457,10 @@ export class NDArrayMath implements NDArrayManager { const res = this.backendEngine.executeKernel( 'AvgPool', {inputs: {x: x4D}, args: {convInfo}}, gradients); if (reshapedTo4D) { - return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as NDArray as - T2; + return res.as3D(res.shape[1], res.shape[2], res.shape[3]) as + RankMap<'float32'>[R]; } - return res as NDArray as T2; + return res as RankMap<'float32'>[R]; }); } diff --git a/src/version.ts b/src/version.ts index 9615abfa31..7accc4d724 100644 --- a/src/version.ts +++ b/src/version.ts @@ -1,5 +1,5 @@ /** @license See the LICENSE file. */ // This code is auto-generated, do not modify this file! -const version = '0.4.1'; +const version = '0.4.2'; export {version};