Skip to content

Commit

Permalink
[WASM] Add Slice and Square kernels (tensorflow#2269)
Browse files Browse the repository at this point in the history
FEATURE

Add `Slice` and `Square` kernels. 

We do the slice in JS instead of C++ since we are copying memory around, which should be fast in JS. Also I added fast rank-specific implementations for tensors up to 4D, and a generic slow implementation for rank >= 5D.
  • Loading branch information
dsmilkov authored Oct 26, 2019
1 parent 5c117cf commit 6794d7c
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 4 deletions.
4 changes: 4 additions & 0 deletions tfjs-backend-wasm/src/backend_wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ export class BackendWasm extends KernelBackend {
this.wasm = null;
}

memory() {
return {unreliable: false};
}

makeOutput(shape: number[], dtype: DataType): TensorInfo {
const dataId = this.write(null /* values */, shape, dtype);
return {dataId, shape, dtype};
Expand Down
9 changes: 9 additions & 0 deletions tfjs-backend-wasm/src/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ tfjs_cc_library(
],
)

tfjs_cc_library(
name = "Square",
srcs = ["kernels/Square.cc"],
deps = [
":backend",
":unary",
],
)

tfjs_cc_library(
name = "Sub",
srcs = ["kernels/Sub.cc"],
Expand Down
38 changes: 38 additions & 0 deletions tfjs-backend-wasm/src/cc/kernels/Square.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ===========================================================================*/

#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif

#include "src/cc/backend.h"
#include "src/cc/unary.h"

namespace {
inline float square(float val) { return val * val; }
} // namespace

namespace tfjs {
namespace wasm {
// We use C-style API to interface with Javascript.
extern "C" {

#ifdef __EMSCRIPTEN__
EMSCRIPTEN_KEEPALIVE
#endif
void Square(int x_id, int out_id) { unary(x_id, out_id, square); }

} // extern "C"
} // namespace wasm
} // namespace tfjs
145 changes: 145 additions & 0 deletions tfjs-backend-wasm/src/kernels/Slice.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {backend_util, buffer, NamedAttrMap, NamedTensorInfoMap, registerKernel, slice_util, util} from '@tensorflow/tfjs-core';
import {TensorInfo} from '@tensorflow/tfjs-core';

import {BackendWasm} from '../backend_wasm';

interface SliceInputs extends NamedTensorInfoMap {
x: TensorInfo;
}

interface SliceAttrs extends NamedAttrMap {
begin: number[];
size: number[];
}

function slice(
args: {inputs: SliceInputs, attrs: SliceAttrs, backend: BackendWasm}) {
const {inputs: {x}, attrs: {begin, size}, backend} = args;
const isContinous = slice_util.isSliceContinous(x.shape, begin, size);
const {memoryOffset: xOffset} = backend.dataIdMap.get(x.dataId);
const xVals =
backend.typedArrayFromHeap(xOffset, x.dtype, util.sizeFromShape(x.shape));
const out = backend.makeOutput(size, x.dtype);
const {memoryOffset: outOffset} = backend.dataIdMap.get(out.dataId);
const outVals = backend.typedArrayFromHeap(
outOffset, out.dtype, util.sizeFromShape(out.shape));
const xStrides = util.computeStrides(x.shape);
if (isContinous) {
const flatOffset = slice_util.computeFlatOffset(begin, xStrides);
outVals.set(
xVals.subarray(flatOffset, flatOffset + util.sizeFromShape(size)));
return out;
}
const rank = x.shape.length;
if (rank === 2) {
slice2d(
xVals, xStrides[0], outVals, begin as [number, number],
size as [number, number]);
} else if (rank === 3) {
slice3d(
xVals, xStrides[0], xStrides[1], outVals,
begin as [number, number, number], size as [number, number, number]);
} else if (rank === 4) {
slice4d(
xVals, xStrides[0], xStrides[1], xStrides[2], outVals,
begin as [number, number, number, number],
size as [number, number, number, number]);
} else {
genericSliceSlow(xVals, x, outVals, begin, size);
}
return out;
}

function slice2d(
xVals: backend_util.TypedArray, xStride: number,
outVals: backend_util.TypedArray, begin: [number, number],
size: [number, number]): void {
let outOffset = 0;
const beginI = begin[0];
const beginJ = begin[1];
const endI = beginI + size[0];
for (let i = beginI; i < endI; i++) {
const xOffset = i * xStride + beginJ;
outVals.set(xVals.subarray(xOffset, xOffset + size[1]), outOffset);
outOffset += size[1];
}
}

function slice3d(
xVals: backend_util.TypedArray, xStride1: number, xStride2: number,
outVals: backend_util.TypedArray, begin: [number, number, number],
size: [number, number, number]): void {
let outOffset = 0;
const beginI = begin[0];
const beginJ = begin[1];
const beginK = begin[2];
const endI = beginI + size[0];
const endJ = beginJ + size[1];
for (let i = beginI; i < endI; i++) {
for (let j = beginJ; j < endJ; j++) {
const xOffset = i * xStride1 + j * xStride2 + beginK;
outVals.set(xVals.subarray(xOffset, xOffset + size[2]), outOffset);
outOffset += size[2];
}
}
}

function slice4d(
xVals: backend_util.TypedArray, xStride1: number, xStride2: number,
xStride3: number, outVals: backend_util.TypedArray,
begin: [number, number, number, number],
size: [number, number, number, number]): void {
let outOffset = 0;
const beginI = begin[0];
const beginJ = begin[1];
const beginK = begin[2];
const endI = beginI + size[0];
const endJ = beginJ + size[1];
const endK = beginK + size[2];
const beginL = begin[3];

for (let i = beginI; i < endI; i++) {
for (let j = beginJ; j < endJ; j++) {
for (let k = beginK; k < endK; k++) {
const xOffset = i * xStride1 + j * xStride2 + k * xStride3 + beginL;
outVals.set(xVals.subarray(xOffset, xOffset + size[3]), outOffset);
outOffset += size[3];
}
}
}
}

function genericSliceSlow(
xVals: backend_util.TypedArray, xInfo: TensorInfo,
outVals: backend_util.TypedArray, begin: number[], size: number[]): void {
const outBuf = buffer(size, xInfo.dtype, outVals);
const xBuf = buffer(xInfo.shape, xInfo.dtype, xVals);
for (let i = 0; i < outBuf.size; ++i) {
const loc = outBuf.indexToLoc(i);
const xLoc = loc.map((idx, j) => idx + begin[j]);
outVals[i] = xBuf.get(...xLoc) as number;
}
}

registerKernel({
kernelName: 'Slice',
backendName: 'wasm',
kernelFunc: slice,
});
19 changes: 19 additions & 0 deletions tfjs-backend-wasm/src/kernels/Square.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/**
* @license
* Copyright 2019 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {registerUnaryKernel} from './unary_kernel';
registerUnaryKernel('Square');
1 change: 1 addition & 0 deletions tfjs-backend-wasm/src/kernels/all_kernels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ import './Mul';
import './Prelu';
import './Reshape';
import './Sigmoid';
import './Slice';
import './Sub';
2 changes: 1 addition & 1 deletion tfjs-backend-wasm/src/setup_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const grepFilter = env.specFilter;
/** Tests that have these substrings in their name will be included. */
const INCLUDE_LIST: string[] = [
'add ', 'matmul ', 'prelu ', ' cast', 'sigmoid', 'abs ', 'sub ', 'mul ',
'div '
'div ', 'slice ', 'square '
];
/** Tests that have these substrings in their name will be excluded. */
const EXCLUDE_LIST: string[] = [
Expand Down
4 changes: 3 additions & 1 deletion tfjs-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import * as backend_util from './backends/backend_util';
import * as io from './io/io';
import * as math from './math';
import * as browser from './ops/browser';
import * as slice_util from './ops/slice_util';
import * as serialization from './serialization';
import {setOpHandler} from './tensor';
import * as tensor_util from './tensor_util';
Expand Down Expand Up @@ -97,7 +98,8 @@ export {
util,
backend_util,
webgl,
tensor_util
tensor_util,
slice_util
};

// Backend specific.
Expand Down
6 changes: 4 additions & 2 deletions tfjs-core/src/ops/slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,12 @@ function slice_<R extends Rank, T extends Tensor<R>>(
for (let i = 0; i < dy.rank; i++) {
paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);
}
return {$x: () => dy.pad(paddings)};
return {x: () => dy.pad(paddings)};
};
const attrs = {begin: begin_, size: size_};
return ENGINE.runKernelFunc(
backend => backend.slice($x, begin_, size_), {$x}, grad);
backend => backend.slice($x, begin_, size_), {x: $x}, grad, 'Slice',
attrs);
}

export const slice = op({slice_});
Expand Down

0 comments on commit 6794d7c

Please sign in to comment.