Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
Add inTopK op (#1734)
Browse files Browse the repository at this point in the history
FEATURE

This PR add inTopK op, which behaves the same way as [tf.math.in_top_k](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math/in_top_k#aliases) in TensorFlow. This op help develop further metrics which depend on inTopK operation, such as, topKCategoricalAccuracy (feature requested in [tensorflow/tfjs#27](tensorflow/tfjs#27) ), sparseTopKCategoricalAccuracy (feature requested in [tensorflow/tfjs#26](tensorflow/tfjs#26)). Relative PR [tensorflow/tfjs-layers#537](tensorflow/tfjs-layers#537)

This PR:
* Add new inTopK op to [src/ops](https://github.com/tensorflow/tfjs-core/tree/master/src/ops)
* Register inTopK in [src/ops/ops.ts](https://github.com/tensorflow/tfjs-core/blob/master/src/ops/ops.ts)
* Add inTopK kernel to backend
* Add shared inTopK implementation between webgl and cpu
* Add relative tests for inTopK

Reference:
* [TensorFlow in_top_k doc](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math/in_top_k#aliases)
* [TensorFlow in_top_k implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/in_topk_op.cc)
* [TensorFlow in_top_k test cases](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/nn_ops_test.cc#L442)
  • Loading branch information
syt123450 authored and Nikhil Thorat committed Aug 7, 2019
1 parent 8a6d4d5 commit e4d7607
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
throw new Error('Not yet implemented');
}

inTopK<T extends Tensor, U extends Tensor>(
predictions: T, targets: U, k: number): U {
throw new Error('Not yet implemented');
}

min(x: Tensor, axes: number[]): Tensor {
throw new Error('Not yet implemented');
}
Expand Down
13 changes: 13 additions & 0 deletions src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util';
import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend';
import * as backend_util from '../backend_util';
import * as complex_util from '../complex_util';
import {inTopKImpl} from '../inTopK_impl';
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
import {split} from '../split_shared';
import {tile} from '../tile_impl';
Expand Down Expand Up @@ -858,6 +859,18 @@ export class MathBackendCPU implements KernelBackend {
return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted);
}

inTopK<T extends Tensor, U extends Tensor>(
predictions: T, targets: U, k: number): U {
this.assertNotComplex([predictions, targets], 'inTopK');

const predictionsVals = this.readSync(predictions.dataId) as TypedArray;
const targetsVals = this.readSync(targets.dataId) as TypedArray;

return inTopKImpl(
predictionsVals, predictions.shape, targetsVals, targets.shape,
k) as U;
}

min(x: Tensor, axes: number[]): Tensor {
this.assertNotComplex(x, 'min');

Expand Down
55 changes: 55 additions & 0 deletions src/backends/inTopK_impl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

/** An implementation of the inTopK kernel shared between webgl and cpu. */

import {tensor} from '../ops/tensor_ops';
import {Tensor} from '../tensor';
import {TypedArray} from '../types';
import {getTypedArrayFromDType} from '../util';

export function inTopKImpl<T extends Tensor>(
predictionsVals: TypedArray, predictionsShape: number[],
targetsVals: TypedArray, targetsShape: number[], k: number
): T {
// Reshape predictionsVals into a 2d tensor [batch, lastDim]
// and look up topK along lastDim.
const lastDim = predictionsShape[predictionsShape.length - 1];
const [batch, size] = [predictionsVals.length / lastDim, lastDim];
const precision = getTypedArrayFromDType('bool', batch);

for (let b = 0; b < batch; b++) {
const offset = b * size;
const vals = predictionsVals.subarray(offset, offset + size);
const valAndInd: Array<{ value: number, index: number }> = [];
for (let i = 0; i < vals.length; i++) {
valAndInd.push({value: vals[i], index: i});
}
valAndInd.sort((a, b) => b.value - a.value);

precision[b] = 0;
for (let i = 0; i < k; i++) {
if (valAndInd[i].index === targetsVals[b]) {
precision[b] = 1;
break;
}
}
}

// Output precision has the same shape as targets.
return tensor(precision, targetsShape, 'bool') as T;
}
10 changes: 10 additions & 0 deletions src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} fr
import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend';
import * as backend_util from '../backend_util';
import {mergeRealAndImagArrays} from '../complex_util';
import {inTopKImpl} from '../inTopK_impl';
import {nonMaxSuppressionImpl} from '../non_max_suppression_impl';
import {split} from '../split_shared';
import {tile} from '../tile_impl';
Expand Down Expand Up @@ -1359,6 +1360,15 @@ export class MathBackendWebGL implements KernelBackend {
return topkImpl(xVals, x.shape, x.dtype as NumericDataType, k, sorted);
}

inTopK<T extends Tensor, U extends Tensor>(
predictions: T, targets: U, k: number): U {
const predictionsVals = predictions.dataSync();
const targetsVals = targets.dataSync();
return inTopKImpl(
predictionsVals, predictions.shape, targetsVals, targets.shape,
k) as U;
}

min(x: Tensor, axes: number[]): Tensor {
axis_util.assertAxesAreInnerMostDims('min', axes, x.rank);
const [outShape, reduceShape] =
Expand Down
75 changes: 75 additions & 0 deletions src/ops/inTopK.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {ENGINE} from '../engine';
import {NumericTensor, Tensor} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import {assert, assertShapesMatch} from '../util';

import {op} from './operation';

/**
* Says whether the targets are in the top K predictions.
*
* ```js
* const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
* const targets = tf.tensor1d([2, 0]);
* const precision = tf.inTopK(predictions, targets);
* precision.print();
* ```
* @param predictions 2-D or higher `tf.Tensor` with last dimension being
* at least `k`.
* @param targets 1-D or higher `tf.Tensor`.
* @param k Optional Number of top elements to look at for computing precision,
* default to 1.
*/
/** @doc {heading: 'Operations', subheading: 'Evaluation'} */
function inTopK_<T extends Tensor, U extends Tensor>(
predictions: T|TensorLike, targets: U|TensorLike, k = 1): U {
const $predictions = convertToTensor(predictions, 'predictions', 'inTopK');
const $targets = convertToTensor(targets, 'targets', 'inTopK');

assert(
$predictions.rank > 1,
() => 'inTopK() expects the predictions to be of rank 2 or higher, ' +
`but got ${$predictions.rank}`);
assert(
$predictions.rank - 1 === $targets.rank,
() => `predictions' rank should be 1 larger than ` +
`targets' rank, but got predictions' rank ` +
`${$predictions.rank} and targets' rank ${$targets.rank}`);
assertShapesMatch(
$predictions.shape.slice(0, $predictions.shape.length - 1),
$targets.shape,
`predictions's shape should be align with the targets' shape, ` +
'except the last dimension.');
const lastDim = $predictions.shape[$predictions.shape.length - 1];
assert(
k > 0 && k <= lastDim,
() => `'k' passed to inTopK() must be > 0 && <= the predictions' last ` +
`dimension (${lastDim}), but got ${k}`);

const precision = ENGINE.runKernel(
b =>
b.inTopK($predictions as NumericTensor, $targets as NumericTensor, k),
{$predictions, $targets});

return precision as U;
}

export const inTopK = op({inTopK_});
108 changes: 108 additions & 0 deletions src/ops/inTopK_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {expectArraysClose} from '../test_util';

import {tensor1d, tensor2d, tensor3d} from './tensor_ops';

describeWithFlags('inTopK', ALL_ENVS, async () => {
it('predictions 2d array, targets 1d array, with default k', async () => {
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
const targets = tensor1d([2, 0]);
const precision = tf.inTopK(predictions, targets);
expect(precision.shape).toEqual([2]);
expect(precision.dtype).toBe('bool');
expectArraysClose(await precision.data(), [1, 0]);
});

it('predictions 2d array, targets 1d array, with k=2', async () => {
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
const targets = tensor1d([2, 0]);
const k = 2;
const precision = tf.inTopK(predictions, targets, k);
expect(precision.shape).toEqual([2]);
expect(precision.dtype).toBe('bool');
expectArraysClose(await precision.data(), [1, 1]);
});

it('predictions 3d array, targets 2d array, with default k', async () => {
const predictions =
tensor3d([[[1, 5, 2], [4, 3, 6]], [[3, 2, 1], [1, 2, 3]]]);
const targets = tensor2d([[1, 2], [0, 1]]);
const precision = tf.inTopK(predictions, targets);
expect(precision.shape).toEqual([2, 2]);
expect(precision.dtype).toBe('bool');
expectArraysClose(await precision.data(), [1, 1, 1, 0]);
});

it('predictions 3d array, targets 2d array, with k=2', async () => {
const predictions =
tensor3d([[[1, 5, 2], [4, 3, 6]], [[3, 2, 1], [1, 2, 3]]]);
const targets = tensor2d([[1, 2], [0, 1]]);
const k = 2;
const precision = tf.inTopK(predictions, targets, k);
expect(precision.shape).toEqual([2, 2]);
expect(precision.dtype).toBe('bool');
expectArraysClose(await precision.data(), [1, 1, 1, 1]);
});

it('lower-index element count first, with default k', async () => {
const predictions = tensor2d([[1, 2, 2, 1]]);

const targets1 = tensor1d([1]);
const precision1 = tf.inTopK(predictions, targets1);
expect(precision1.shape).toEqual([1]);
expect(precision1.dtype).toBe('bool');
expectArraysClose(await precision1.data(), [1]);

const targets2 = tensor1d([2]);
const precision2 = tf.inTopK(predictions, targets2);
expect(precision2.shape).toEqual([1]);
expect(precision2.dtype).toBe('bool');
expectArraysClose(await precision2.data(), [0]);
});

it('accept tensor-like object, with default k', async () => {
const predictions = [[20, 10, 40, 30], [30, 50, -20, 10]];
const targets = [2, 0];
const precision = tf.inTopK(predictions, targets);
expect(precision.shape).toEqual([2]);
expect(precision.dtype).toBe('bool');
expectArraysClose(await precision.data(), [1, 0]);
});

it('throws when predictions_rank <2', () => {
const predictions = tensor1d([20, 10, 40, 30]);
const targets = [2];
expect(() => tf.inTopK(predictions, targets)).toThrowError();
});

it('throws when prediction_rank != targets_rank + 1', () => {
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
const targets = tensor2d([[0], [0]]);
expect(() => tf.inTopK(predictions, targets)).toThrowError();
});

it('throws when k > size of last dimension of predictions', () => {
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]);
const targets = tensor1d([2, 0]);
const k = 5;
expect(() => tf.inTopK(predictions, targets, k)).toThrowError();
});
});
1 change: 1 addition & 0 deletions src/ops/ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export * from './gather_nd';
export * from './diag';
export * from './dropout';
export * from './signal_ops';
export * from './inTopK';

export {op} from './operation';

Expand Down
1 change: 1 addition & 0 deletions src/tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import './ops/dropout_test';
import './ops/fused_test';
import './ops/gather_nd_test';
import './ops/image_ops_test';
import './ops/inTopK_test';
import './ops/linalg_ops_test';
import './ops/logical_ops_test';
import './ops/loss_ops_test';
Expand Down

0 comments on commit e4d7607

Please sign in to comment.