From e4d7607de9350510df27c7faa890fdbc6a010ae2 Mon Sep 17 00:00:00 2001 From: syt123450 Date: Wed, 7 Aug 2019 13:23:43 -0700 Subject: [PATCH] Add inTopK op (#1734) FEATURE This PR add inTopK op, which behaves the same way as [tf.math.in_top_k](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math/in_top_k#aliases) in TensorFlow. This op help develop further metrics which depend on inTopK operation, such as, topKCategoricalAccuracy (feature requested in [tensorflow/tfjs#27](https://github.com/tensorflow/tfjs/issues/27) ), sparseTopKCategoricalAccuracy (feature requested in [tensorflow/tfjs#26](https://github.com/tensorflow/tfjs/issues/26)). Relative PR [tensorflow/tfjs-layers#537](https://github.com/tensorflow/tfjs-layers/pull/537) This PR: * Add new inTopK op to [src/ops](https://github.com/tensorflow/tfjs-core/tree/master/src/ops) * Register inTopK in [src/ops/ops.ts](https://github.com/tensorflow/tfjs-core/blob/master/src/ops/ops.ts) * Add inTopK kernel to backend * Add shared inTopK implementation between webgl and cpu * Add relative tests for inTopK Reference: * [TensorFlow in_top_k doc](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math/in_top_k#aliases) * [TensorFlow in_top_k implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/in_topk_op.cc) * [TensorFlow in_top_k test cases](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/nn_ops_test.cc#L442) --- src/backends/backend.ts | 5 ++ src/backends/cpu/backend_cpu.ts | 13 ++++ src/backends/inTopK_impl.ts | 55 ++++++++++++++ src/backends/webgl/backend_webgl.ts | 10 +++ src/ops/inTopK.ts | 75 +++++++++++++++++++ src/ops/inTopK_test.ts | 108 ++++++++++++++++++++++++++++ src/ops/ops.ts | 1 + src/tests.ts | 1 + 8 files changed, 268 insertions(+) create mode 100644 src/backends/inTopK_impl.ts create mode 100644 src/ops/inTopK.ts create mode 100644 src/ops/inTopK_test.ts diff --git a/src/backends/backend.ts b/src/backends/backend.ts index 0a6cc981d1..f1a64e4d13 100644 --- a/src/backends/backend.ts +++ b/src/backends/backend.ts @@ -241,6 +241,11 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { throw new Error('Not yet implemented'); } + inTopK( + predictions: T, targets: U, k: number): U { + throw new Error('Not yet implemented'); + } + min(x: Tensor, axes: number[]): Tensor { throw new Error('Not yet implemented'); } diff --git a/src/backends/cpu/backend_cpu.ts b/src/backends/cpu/backend_cpu.ts index f217ff3143..132d00a502 100644 --- a/src/backends/cpu/backend_cpu.ts +++ b/src/backends/cpu/backend_cpu.ts @@ -41,6 +41,7 @@ import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend'; import * as backend_util from '../backend_util'; import * as complex_util from '../complex_util'; +import {inTopKImpl} from '../inTopK_impl'; import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; import {split} from '../split_shared'; import {tile} from '../tile_impl'; @@ -858,6 +859,18 @@ export class MathBackendCPU implements KernelBackend { return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted); } + inTopK( + predictions: T, targets: U, k: number): U { + this.assertNotComplex([predictions, targets], 'inTopK'); + + const predictionsVals = this.readSync(predictions.dataId) as TypedArray; + const targetsVals = this.readSync(targets.dataId) as TypedArray; + + return inTopKImpl( + predictionsVals, predictions.shape, targetsVals, targets.shape, + k) as U; + } + min(x: Tensor, axes: number[]): Tensor { this.assertNotComplex(x, 'min'); diff --git a/src/backends/inTopK_impl.ts b/src/backends/inTopK_impl.ts new file mode 100644 index 0000000000..c662d62b3b --- /dev/null +++ b/src/backends/inTopK_impl.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +/** An implementation of the inTopK kernel shared between webgl and cpu. */ + +import {tensor} from '../ops/tensor_ops'; +import {Tensor} from '../tensor'; +import {TypedArray} from '../types'; +import {getTypedArrayFromDType} from '../util'; + +export function inTopKImpl( + predictionsVals: TypedArray, predictionsShape: number[], + targetsVals: TypedArray, targetsShape: number[], k: number +): T { + // Reshape predictionsVals into a 2d tensor [batch, lastDim] + // and look up topK along lastDim. + const lastDim = predictionsShape[predictionsShape.length - 1]; + const [batch, size] = [predictionsVals.length / lastDim, lastDim]; + const precision = getTypedArrayFromDType('bool', batch); + + for (let b = 0; b < batch; b++) { + const offset = b * size; + const vals = predictionsVals.subarray(offset, offset + size); + const valAndInd: Array<{ value: number, index: number }> = []; + for (let i = 0; i < vals.length; i++) { + valAndInd.push({value: vals[i], index: i}); + } + valAndInd.sort((a, b) => b.value - a.value); + + precision[b] = 0; + for (let i = 0; i < k; i++) { + if (valAndInd[i].index === targetsVals[b]) { + precision[b] = 1; + break; + } + } + } + + // Output precision has the same shape as targets. + return tensor(precision, targetsShape, 'bool') as T; +} \ No newline at end of file diff --git a/src/backends/webgl/backend_webgl.ts b/src/backends/webgl/backend_webgl.ts index 801a76bb39..4699df63d1 100644 --- a/src/backends/webgl/backend_webgl.ts +++ b/src/backends/webgl/backend_webgl.ts @@ -44,6 +44,7 @@ import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} fr import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend'; import * as backend_util from '../backend_util'; import {mergeRealAndImagArrays} from '../complex_util'; +import {inTopKImpl} from '../inTopK_impl'; import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; import {split} from '../split_shared'; import {tile} from '../tile_impl'; @@ -1359,6 +1360,15 @@ export class MathBackendWebGL implements KernelBackend { return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted); } + inTopK( + predictions: T, targets: U, k: number): U { + const predictionsVals = predictions.dataSync(); + const targetsVals = targets.dataSync(); + return inTopKImpl( + predictionsVals, predictions.shape, targetsVals, targets.shape, + k) as U; + } + min(x: Tensor, axes: number[]): Tensor { axis_util.assertAxesAreInnerMostDims('min', axes, x.rank); const [outShape, reduceShape] = diff --git a/src/ops/inTopK.ts b/src/ops/inTopK.ts new file mode 100644 index 0000000000..2c97f469b2 --- /dev/null +++ b/src/ops/inTopK.ts @@ -0,0 +1,75 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {NumericTensor, Tensor} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import {assert, assertShapesMatch} from '../util'; + +import {op} from './operation'; + +/** + * Says whether the targets are in the top K predictions. + * + * ```js + * const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); + * const targets = tf.tensor1d([2, 0]); + * const precision = tf.inTopK(predictions, targets); + * precision.print(); + * ``` + * @param predictions 2-D or higher `tf.Tensor` with last dimension being + * at least `k`. + * @param targets 1-D or higher `tf.Tensor`. + * @param k Optional Number of top elements to look at for computing precision, + * default to 1. + */ +/** @doc {heading: 'Operations', subheading: 'Evaluation'} */ +function inTopK_( + predictions: T|TensorLike, targets: U|TensorLike, k = 1): U { + const $predictions = convertToTensor(predictions, 'predictions', 'inTopK'); + const $targets = convertToTensor(targets, 'targets', 'inTopK'); + + assert( + $predictions.rank > 1, + () => 'inTopK() expects the predictions to be of rank 2 or higher, ' + + `but got ${$predictions.rank}`); + assert( + $predictions.rank - 1 === $targets.rank, + () => `predictions' rank should be 1 larger than ` + + `targets' rank, but got predictions' rank ` + + `${$predictions.rank} and targets' rank ${$targets.rank}`); + assertShapesMatch( + $predictions.shape.slice(0, $predictions.shape.length - 1), + $targets.shape, + `predictions's shape should be align with the targets' shape, ` + + 'except the last dimension.'); + const lastDim = $predictions.shape[$predictions.shape.length - 1]; + assert( + k > 0 && k <= lastDim, + () => `'k' passed to inTopK() must be > 0 && <= the predictions' last ` + + `dimension (${lastDim}), but got ${k}`); + + const precision = ENGINE.runKernel( + b => + b.inTopK($predictions as NumericTensor, $targets as NumericTensor, k), + {$predictions, $targets}); + + return precision as U; +} + +export const inTopK = op({inTopK_}); diff --git a/src/ops/inTopK_test.ts b/src/ops/inTopK_test.ts new file mode 100644 index 0000000000..a13416fbed --- /dev/null +++ b/src/ops/inTopK_test.ts @@ -0,0 +1,108 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +import {tensor1d, tensor2d, tensor3d} from './tensor_ops'; + +describeWithFlags('inTopK', ALL_ENVS, async () => { + it('predictions 2d array, targets 1d array, with default k', async () => { + const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); + const targets = tensor1d([2, 0]); + const precision = tf.inTopK(predictions, targets); + expect(precision.shape).toEqual([2]); + expect(precision.dtype).toBe('bool'); + expectArraysClose(await precision.data(), [1, 0]); + }); + + it('predictions 2d array, targets 1d array, with k=2', async () => { + const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); + const targets = tensor1d([2, 0]); + const k = 2; + const precision = tf.inTopK(predictions, targets, k); + expect(precision.shape).toEqual([2]); + expect(precision.dtype).toBe('bool'); + expectArraysClose(await precision.data(), [1, 1]); + }); + + it('predictions 3d array, targets 2d array, with default k', async () => { + const predictions = + tensor3d([[[1, 5, 2], [4, 3, 6]], [[3, 2, 1], [1, 2, 3]]]); + const targets = tensor2d([[1, 2], [0, 1]]); + const precision = tf.inTopK(predictions, targets); + expect(precision.shape).toEqual([2, 2]); + expect(precision.dtype).toBe('bool'); + expectArraysClose(await precision.data(), [1, 1, 1, 0]); + }); + + it('predictions 3d array, targets 2d array, with k=2', async () => { + const predictions = + tensor3d([[[1, 5, 2], [4, 3, 6]], [[3, 2, 1], [1, 2, 3]]]); + const targets = tensor2d([[1, 2], [0, 1]]); + const k = 2; + const precision = tf.inTopK(predictions, targets, k); + expect(precision.shape).toEqual([2, 2]); + expect(precision.dtype).toBe('bool'); + expectArraysClose(await precision.data(), [1, 1, 1, 1]); + }); + + it('lower-index element count first, with default k', async () => { + const predictions = tensor2d([[1, 2, 2, 1]]); + + const targets1 = tensor1d([1]); + const precision1 = tf.inTopK(predictions, targets1); + expect(precision1.shape).toEqual([1]); + expect(precision1.dtype).toBe('bool'); + expectArraysClose(await precision1.data(), [1]); + + const targets2 = tensor1d([2]); + const precision2 = tf.inTopK(predictions, targets2); + expect(precision2.shape).toEqual([1]); + expect(precision2.dtype).toBe('bool'); + expectArraysClose(await precision2.data(), [0]); + }); + + it('accept tensor-like object, with default k', async () => { + const predictions = [[20, 10, 40, 30], [30, 50, -20, 10]]; + const targets = [2, 0]; + const precision = tf.inTopK(predictions, targets); + expect(precision.shape).toEqual([2]); + expect(precision.dtype).toBe('bool'); + expectArraysClose(await precision.data(), [1, 0]); + }); + + it('throws when predictions_rank <2', () => { + const predictions = tensor1d([20, 10, 40, 30]); + const targets = [2]; + expect(() => tf.inTopK(predictions, targets)).toThrowError(); + }); + + it('throws when prediction_rank != targets_rank + 1', () => { + const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); + const targets = tensor2d([[0], [0]]); + expect(() => tf.inTopK(predictions, targets)).toThrowError(); + }); + + it('throws when k > size of last dimension of predictions', () => { + const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); + const targets = tensor1d([2, 0]); + const k = 5; + expect(() => tf.inTopK(predictions, targets, k)).toThrowError(); + }); +}); diff --git a/src/ops/ops.ts b/src/ops/ops.ts index 0c7ba4f8c9..790b9a6bac 100644 --- a/src/ops/ops.ts +++ b/src/ops/ops.ts @@ -48,6 +48,7 @@ export * from './gather_nd'; export * from './diag'; export * from './dropout'; export * from './signal_ops'; +export * from './inTopK'; export {op} from './operation'; diff --git a/src/tests.ts b/src/tests.ts index fe5a93961c..27326c21c5 100644 --- a/src/tests.ts +++ b/src/tests.ts @@ -61,6 +61,7 @@ import './ops/dropout_test'; import './ops/fused_test'; import './ops/gather_nd_test'; import './ops/image_ops_test'; +import './ops/inTopK_test'; import './ops/linalg_ops_test'; import './ops/logical_ops_test'; import './ops/loss_ops_test';