forked from tensorflow/tfjs
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FEATURE Add `Slice` and `Square` kernels. We do the slice in JS instead of C++ since we are copying memory around, which should be fast in JS. Also I added fast rank-specific implementations for tensors up to 4D, and a generic slow implementation for rank >= 5D.
- Loading branch information
Showing
9 changed files
with
224 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* Copyright 2019 Google Inc. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ===========================================================================*/ | ||
|
||
#ifdef __EMSCRIPTEN__ | ||
#include <emscripten.h> | ||
#endif | ||
|
||
#include "src/cc/backend.h" | ||
#include "src/cc/unary.h" | ||
|
||
namespace { | ||
inline float square(float val) { return val * val; } | ||
} // namespace | ||
|
||
namespace tfjs { | ||
namespace wasm { | ||
// We use C-style API to interface with Javascript. | ||
extern "C" { | ||
|
||
#ifdef __EMSCRIPTEN__ | ||
EMSCRIPTEN_KEEPALIVE | ||
#endif | ||
void Square(int x_id, int out_id) { unary(x_id, out_id, square); } | ||
|
||
} // extern "C" | ||
} // namespace wasm | ||
} // namespace tfjs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
/** | ||
* @license | ||
* Copyright 2019 Google Inc. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ============================================================================= | ||
*/ | ||
|
||
import {backend_util, buffer, NamedAttrMap, NamedTensorInfoMap, registerKernel, slice_util, util} from '@tensorflow/tfjs-core'; | ||
import {TensorInfo} from '@tensorflow/tfjs-core'; | ||
|
||
import {BackendWasm} from '../backend_wasm'; | ||
|
||
interface SliceInputs extends NamedTensorInfoMap { | ||
x: TensorInfo; | ||
} | ||
|
||
interface SliceAttrs extends NamedAttrMap { | ||
begin: number[]; | ||
size: number[]; | ||
} | ||
|
||
function slice( | ||
args: {inputs: SliceInputs, attrs: SliceAttrs, backend: BackendWasm}) { | ||
const {inputs: {x}, attrs: {begin, size}, backend} = args; | ||
const isContinous = slice_util.isSliceContinous(x.shape, begin, size); | ||
const {memoryOffset: xOffset} = backend.dataIdMap.get(x.dataId); | ||
const xVals = | ||
backend.typedArrayFromHeap(xOffset, x.dtype, util.sizeFromShape(x.shape)); | ||
const out = backend.makeOutput(size, x.dtype); | ||
const {memoryOffset: outOffset} = backend.dataIdMap.get(out.dataId); | ||
const outVals = backend.typedArrayFromHeap( | ||
outOffset, out.dtype, util.sizeFromShape(out.shape)); | ||
const xStrides = util.computeStrides(x.shape); | ||
if (isContinous) { | ||
const flatOffset = slice_util.computeFlatOffset(begin, xStrides); | ||
outVals.set( | ||
xVals.subarray(flatOffset, flatOffset + util.sizeFromShape(size))); | ||
return out; | ||
} | ||
const rank = x.shape.length; | ||
if (rank === 2) { | ||
slice2d( | ||
xVals, xStrides[0], outVals, begin as [number, number], | ||
size as [number, number]); | ||
} else if (rank === 3) { | ||
slice3d( | ||
xVals, xStrides[0], xStrides[1], outVals, | ||
begin as [number, number, number], size as [number, number, number]); | ||
} else if (rank === 4) { | ||
slice4d( | ||
xVals, xStrides[0], xStrides[1], xStrides[2], outVals, | ||
begin as [number, number, number, number], | ||
size as [number, number, number, number]); | ||
} else { | ||
genericSliceSlow(xVals, x, outVals, begin, size); | ||
} | ||
return out; | ||
} | ||
|
||
function slice2d( | ||
xVals: backend_util.TypedArray, xStride: number, | ||
outVals: backend_util.TypedArray, begin: [number, number], | ||
size: [number, number]): void { | ||
let outOffset = 0; | ||
const beginI = begin[0]; | ||
const beginJ = begin[1]; | ||
const endI = beginI + size[0]; | ||
for (let i = beginI; i < endI; i++) { | ||
const xOffset = i * xStride + beginJ; | ||
outVals.set(xVals.subarray(xOffset, xOffset + size[1]), outOffset); | ||
outOffset += size[1]; | ||
} | ||
} | ||
|
||
function slice3d( | ||
xVals: backend_util.TypedArray, xStride1: number, xStride2: number, | ||
outVals: backend_util.TypedArray, begin: [number, number, number], | ||
size: [number, number, number]): void { | ||
let outOffset = 0; | ||
const beginI = begin[0]; | ||
const beginJ = begin[1]; | ||
const beginK = begin[2]; | ||
const endI = beginI + size[0]; | ||
const endJ = beginJ + size[1]; | ||
for (let i = beginI; i < endI; i++) { | ||
for (let j = beginJ; j < endJ; j++) { | ||
const xOffset = i * xStride1 + j * xStride2 + beginK; | ||
outVals.set(xVals.subarray(xOffset, xOffset + size[2]), outOffset); | ||
outOffset += size[2]; | ||
} | ||
} | ||
} | ||
|
||
function slice4d( | ||
xVals: backend_util.TypedArray, xStride1: number, xStride2: number, | ||
xStride3: number, outVals: backend_util.TypedArray, | ||
begin: [number, number, number, number], | ||
size: [number, number, number, number]): void { | ||
let outOffset = 0; | ||
const beginI = begin[0]; | ||
const beginJ = begin[1]; | ||
const beginK = begin[2]; | ||
const endI = beginI + size[0]; | ||
const endJ = beginJ + size[1]; | ||
const endK = beginK + size[2]; | ||
const beginL = begin[3]; | ||
|
||
for (let i = beginI; i < endI; i++) { | ||
for (let j = beginJ; j < endJ; j++) { | ||
for (let k = beginK; k < endK; k++) { | ||
const xOffset = i * xStride1 + j * xStride2 + k * xStride3 + beginL; | ||
outVals.set(xVals.subarray(xOffset, xOffset + size[3]), outOffset); | ||
outOffset += size[3]; | ||
} | ||
} | ||
} | ||
} | ||
|
||
function genericSliceSlow( | ||
xVals: backend_util.TypedArray, xInfo: TensorInfo, | ||
outVals: backend_util.TypedArray, begin: number[], size: number[]): void { | ||
const outBuf = buffer(size, xInfo.dtype, outVals); | ||
const xBuf = buffer(xInfo.shape, xInfo.dtype, xVals); | ||
for (let i = 0; i < outBuf.size; ++i) { | ||
const loc = outBuf.indexToLoc(i); | ||
const xLoc = loc.map((idx, j) => idx + begin[j]); | ||
outVals[i] = xBuf.get(...xLoc) as number; | ||
} | ||
} | ||
|
||
registerKernel({ | ||
kernelName: 'Slice', | ||
backendName: 'wasm', | ||
kernelFunc: slice, | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/** | ||
* @license | ||
* Copyright 2019 Google Inc. All Rights Reserved. | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* ============================================================================= | ||
*/ | ||
|
||
import {registerUnaryKernel} from './unary_kernel'; | ||
registerUnaryKernel('Square'); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,4 +27,5 @@ import './Mul'; | |
import './Prelu'; | ||
import './Reshape'; | ||
import './Sigmoid'; | ||
import './Slice'; | ||
import './Sub'; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters