forked from tensorflow/tfjs-core
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FEATURE This PR add inTopK op, which behaves the same way as [tf.math.in_top_k](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math/in_top_k#aliases) in TensorFlow. This op help develop further metrics which depend on inTopK operation, such as, topKCategoricalAccuracy (feature requested in [tensorflow/tfjs#27](tensorflow/tfjs#27) ), sparseTopKCategoricalAccuracy (feature requested in [tensorflow/tfjs#26](tensorflow/tfjs#26)). Relative PR [tensorflow/tfjs-layers#537](tensorflow/tfjs-layers#537) This PR: * Add new inTopK op to [src/ops](https://github.com/tensorflow/tfjs-core/tree/master/src/ops) * Register inTopK in [src/ops/ops.ts](https://github.com/tensorflow/tfjs-core/blob/master/src/ops/ops.ts) * Add inTopK kernel to backend * Add shared inTopK implementation between webgl and cpu * Add relative tests for inTopK Reference: * [TensorFlow in_top_k doc](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math/in_top_k#aliases) * [TensorFlow in_top_k implementation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/in_topk_op.cc) * [TensorFlow in_top_k test cases](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/nn_ops_test.cc#L442)
- Loading branch information
Showing
8 changed files
with
268 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/** | ||
* @license | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ============================================================================= | ||
*/ | ||
|
||
/** An implementation of the inTopK kernel shared between webgl and cpu. */ | ||
|
||
import {tensor} from '../ops/tensor_ops'; | ||
import {Tensor} from '../tensor'; | ||
import {TypedArray} from '../types'; | ||
import {getTypedArrayFromDType} from '../util'; | ||
|
||
export function inTopKImpl<T extends Tensor>( | ||
predictionsVals: TypedArray, predictionsShape: number[], | ||
targetsVals: TypedArray, targetsShape: number[], k: number | ||
): T { | ||
// Reshape predictionsVals into a 2d tensor [batch, lastDim] | ||
// and look up topK along lastDim. | ||
const lastDim = predictionsShape[predictionsShape.length - 1]; | ||
const [batch, size] = [predictionsVals.length / lastDim, lastDim]; | ||
const precision = getTypedArrayFromDType('bool', batch); | ||
|
||
for (let b = 0; b < batch; b++) { | ||
const offset = b * size; | ||
const vals = predictionsVals.subarray(offset, offset + size); | ||
const valAndInd: Array<{ value: number, index: number }> = []; | ||
for (let i = 0; i < vals.length; i++) { | ||
valAndInd.push({value: vals[i], index: i}); | ||
} | ||
valAndInd.sort((a, b) => b.value - a.value); | ||
|
||
precision[b] = 0; | ||
for (let i = 0; i < k; i++) { | ||
if (valAndInd[i].index === targetsVals[b]) { | ||
precision[b] = 1; | ||
break; | ||
} | ||
} | ||
} | ||
|
||
// Output precision has the same shape as targets. | ||
return tensor(precision, targetsShape, 'bool') as T; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/** | ||
* @license | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ============================================================================= | ||
*/ | ||
|
||
import {ENGINE} from '../engine'; | ||
import {NumericTensor, Tensor} from '../tensor'; | ||
import {convertToTensor} from '../tensor_util_env'; | ||
import {TensorLike} from '../types'; | ||
import {assert, assertShapesMatch} from '../util'; | ||
|
||
import {op} from './operation'; | ||
|
||
/** | ||
* Says whether the targets are in the top K predictions. | ||
* | ||
* ```js | ||
* const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); | ||
* const targets = tf.tensor1d([2, 0]); | ||
* const precision = tf.inTopK(predictions, targets); | ||
* precision.print(); | ||
* ``` | ||
* @param predictions 2-D or higher `tf.Tensor` with last dimension being | ||
* at least `k`. | ||
* @param targets 1-D or higher `tf.Tensor`. | ||
* @param k Optional Number of top elements to look at for computing precision, | ||
* default to 1. | ||
*/ | ||
/** @doc {heading: 'Operations', subheading: 'Evaluation'} */ | ||
function inTopK_<T extends Tensor, U extends Tensor>( | ||
predictions: T|TensorLike, targets: U|TensorLike, k = 1): U { | ||
const $predictions = convertToTensor(predictions, 'predictions', 'inTopK'); | ||
const $targets = convertToTensor(targets, 'targets', 'inTopK'); | ||
|
||
assert( | ||
$predictions.rank > 1, | ||
() => 'inTopK() expects the predictions to be of rank 2 or higher, ' + | ||
`but got ${$predictions.rank}`); | ||
assert( | ||
$predictions.rank - 1 === $targets.rank, | ||
() => `predictions' rank should be 1 larger than ` + | ||
`targets' rank, but got predictions' rank ` + | ||
`${$predictions.rank} and targets' rank ${$targets.rank}`); | ||
assertShapesMatch( | ||
$predictions.shape.slice(0, $predictions.shape.length - 1), | ||
$targets.shape, | ||
`predictions's shape should be align with the targets' shape, ` + | ||
'except the last dimension.'); | ||
const lastDim = $predictions.shape[$predictions.shape.length - 1]; | ||
assert( | ||
k > 0 && k <= lastDim, | ||
() => `'k' passed to inTopK() must be > 0 && <= the predictions' last ` + | ||
`dimension (${lastDim}), but got ${k}`); | ||
|
||
const precision = ENGINE.runKernel( | ||
b => | ||
b.inTopK($predictions as NumericTensor, $targets as NumericTensor, k), | ||
{$predictions, $targets}); | ||
|
||
return precision as U; | ||
} | ||
|
||
export const inTopK = op({inTopK_}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
/** | ||
* @license | ||
* Copyright 2018 Google LLC. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ============================================================================= | ||
*/ | ||
|
||
import * as tf from '../index'; | ||
import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; | ||
import {expectArraysClose} from '../test_util'; | ||
|
||
import {tensor1d, tensor2d, tensor3d} from './tensor_ops'; | ||
|
||
describeWithFlags('inTopK', ALL_ENVS, async () => { | ||
it('predictions 2d array, targets 1d array, with default k', async () => { | ||
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); | ||
const targets = tensor1d([2, 0]); | ||
const precision = tf.inTopK(predictions, targets); | ||
expect(precision.shape).toEqual([2]); | ||
expect(precision.dtype).toBe('bool'); | ||
expectArraysClose(await precision.data(), [1, 0]); | ||
}); | ||
|
||
it('predictions 2d array, targets 1d array, with k=2', async () => { | ||
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); | ||
const targets = tensor1d([2, 0]); | ||
const k = 2; | ||
const precision = tf.inTopK(predictions, targets, k); | ||
expect(precision.shape).toEqual([2]); | ||
expect(precision.dtype).toBe('bool'); | ||
expectArraysClose(await precision.data(), [1, 1]); | ||
}); | ||
|
||
it('predictions 3d array, targets 2d array, with default k', async () => { | ||
const predictions = | ||
tensor3d([[[1, 5, 2], [4, 3, 6]], [[3, 2, 1], [1, 2, 3]]]); | ||
const targets = tensor2d([[1, 2], [0, 1]]); | ||
const precision = tf.inTopK(predictions, targets); | ||
expect(precision.shape).toEqual([2, 2]); | ||
expect(precision.dtype).toBe('bool'); | ||
expectArraysClose(await precision.data(), [1, 1, 1, 0]); | ||
}); | ||
|
||
it('predictions 3d array, targets 2d array, with k=2', async () => { | ||
const predictions = | ||
tensor3d([[[1, 5, 2], [4, 3, 6]], [[3, 2, 1], [1, 2, 3]]]); | ||
const targets = tensor2d([[1, 2], [0, 1]]); | ||
const k = 2; | ||
const precision = tf.inTopK(predictions, targets, k); | ||
expect(precision.shape).toEqual([2, 2]); | ||
expect(precision.dtype).toBe('bool'); | ||
expectArraysClose(await precision.data(), [1, 1, 1, 1]); | ||
}); | ||
|
||
it('lower-index element count first, with default k', async () => { | ||
const predictions = tensor2d([[1, 2, 2, 1]]); | ||
|
||
const targets1 = tensor1d([1]); | ||
const precision1 = tf.inTopK(predictions, targets1); | ||
expect(precision1.shape).toEqual([1]); | ||
expect(precision1.dtype).toBe('bool'); | ||
expectArraysClose(await precision1.data(), [1]); | ||
|
||
const targets2 = tensor1d([2]); | ||
const precision2 = tf.inTopK(predictions, targets2); | ||
expect(precision2.shape).toEqual([1]); | ||
expect(precision2.dtype).toBe('bool'); | ||
expectArraysClose(await precision2.data(), [0]); | ||
}); | ||
|
||
it('accept tensor-like object, with default k', async () => { | ||
const predictions = [[20, 10, 40, 30], [30, 50, -20, 10]]; | ||
const targets = [2, 0]; | ||
const precision = tf.inTopK(predictions, targets); | ||
expect(precision.shape).toEqual([2]); | ||
expect(precision.dtype).toBe('bool'); | ||
expectArraysClose(await precision.data(), [1, 0]); | ||
}); | ||
|
||
it('throws when predictions_rank <2', () => { | ||
const predictions = tensor1d([20, 10, 40, 30]); | ||
const targets = [2]; | ||
expect(() => tf.inTopK(predictions, targets)).toThrowError(); | ||
}); | ||
|
||
it('throws when prediction_rank != targets_rank + 1', () => { | ||
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); | ||
const targets = tensor2d([[0], [0]]); | ||
expect(() => tf.inTopK(predictions, targets)).toThrowError(); | ||
}); | ||
|
||
it('throws when k > size of last dimension of predictions', () => { | ||
const predictions = tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); | ||
const targets = tensor1d([2, 0]); | ||
const k = 5; | ||
expect(() => tf.inTopK(predictions, targets, k)).toThrowError(); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters