Skip to content

Commit

Permalink
implement Pad on webgl
Browse files Browse the repository at this point in the history
  • Loading branch information
milhidaka committed Feb 15, 2022
1 parent 4db7739 commit 4dc051e
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/descriptor_runner/core/operatorTable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export function instantiateOperator(
// Special operators
switch (opType) {
case "Flatten":
case "Pad":
case "Reshape":
case "Squeeze":
case "Transpose":
Expand Down
45 changes: 45 additions & 0 deletions src/descriptor_runner/operators/base/pad11.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { OperatorImpl } from "../operatorImpl";
import { Tensor } from "../../interface/core/tensor";
import { onnx } from "onnx-proto";
import { getAttrString } from "../operatorUtil";
import { Backend } from "../../interface/core/constants";
import { CPUTensor } from "../../interface/backend/cpu/cpuTensor";

type PadMode = "constant" | "reflect" | "edge";

/**
 * Backend-independent base of the ONNX Pad operator for opset 11.
 * The opset 2 form of Pad is not compatible with this implementation.
 *
 * Reads the `mode` attribute and derives the output shape from the pads
 * tensor; subclasses provide the backend-specific execution.
 */
export abstract class Pad11 extends OperatorImpl {
  /** Padding mode taken from the `mode` attribute. */
  mode!: PadMode;

  /**
   * Reads the `mode` attribute, falling back to "constant" when it is absent.
   *
   * @param attribute Attributes of the operator node.
   * @throws Error when `mode` is not "constant", "reflect" or "edge".
   */
  initialize(attribute: onnx.IAttributeProto[]): void {
    super.initialize(attribute);
    const mode = getAttrString(attribute, "mode", "constant");
    // Validate instead of casting: an unknown mode would otherwise only
    // surface later as broken backend code.
    switch (mode) {
      case "constant":
      case "reflect":
      case "edge":
        this.mode = mode;
        break;
      default:
        throw new Error(`Pad: unsupported mode "${mode}"`);
    }
  }

  /**
   * Declares where each input must live: the first input on this operator's
   * backend, every following input (pads and the optional constant value,
   * as consumed by the backend `run` implementations) on the CPU.
   *
   * @param nInputs Number of inputs given to the operator.
   * @param nOutputs Number of outputs (unused).
   * @returns One backend requirement per input.
   */
  getTensorBackendRequirement(
    nInputs: number,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    nOutputs: number
  ): (Backend | null)[] {
    return nInputs === 2
      ? [this.backend, "cpu"]
      : [this.backend, "cpu", "cpu"];
  }

  /**
   * Computes the padded output shape.
   *
   * `pads` is read as `[begin_0, ..., begin_{n-1}, end_0, ..., end_{n-1}]`
   * where `n` is the rank of `input`.
   *
   * @param input Tensor to be padded.
   * @param padTensor CPU tensor holding the `2 * n` pad amounts.
   * @returns The output shape and the pad amounts as a plain array.
   * @throws Error when `padTensor` does not hold exactly `2 * n` values, or
   *   when the pads would make an output dimension negative.
   */
  protected calcShape(
    input: Tensor,
    padTensor: CPUTensor
  ): { outputShape: number[]; pads: number[] } {
    const pads: number[] = Array.from(padTensor.data);
    const rank = input.ndim;
    if (pads.length !== rank * 2) {
      // Without this check the missing entries are `undefined` and the
      // resulting dimensions silently become NaN.
      throw new Error(
        `Pad: pads must have ${rank * 2} elements for a rank-${rank} input, got ${pads.length}`
      );
    }
    const outputShape: number[] = [];
    for (let axis = 0; axis < rank; axis++) {
      const size = input.dims[axis] + pads[axis] + pads[axis + rank];
      if (size < 0) {
        throw new Error(
          `Pad: pads produce a negative size (${size}) on axis ${axis}`
        );
      }
      outputShape.push(size);
    }
    return { outputShape, pads };
  }
}
24 changes: 5 additions & 19 deletions src/descriptor_runner/operators/cpu/operators/standard/pad11.ts
Original file line number Diff line number Diff line change
@@ -1,42 +1,28 @@
import { DataArrayTypes } from "../../../../interface/core/constants";
import { OperatorImpl } from "../../../operatorImpl";
import { WebDNNCPUContext } from "../../../../interface/backend/cpu/cpuContext";
import { Tensor } from "../../../../interface/core/tensor";
import { OperatorEntry } from "../../../../interface/core/operator";
import { onnx } from "onnx-proto";
import { getAttrString } from "../../../operatorUtil";

type PadMode = "constant" | "reflect" | "edge";
import { Pad11 } from "../../../base/pad11";

/*
 * Opset 11
 * Not compatible with opset 2
 */
class Pad11 extends OperatorImpl {
mode!: PadMode;

class CPUPad11 extends Pad11 {
constructor() {
super("cpu");
}

initialize(attribute: onnx.IAttributeProto[]): void {
super.initialize(attribute);
this.mode = getAttrString(attribute, "mode", "constant") as PadMode;
}

async run(context: WebDNNCPUContext, inputs: Tensor[]): Promise<Tensor[]> {
context.assertsCPUTensorArray(inputs);
const input = inputs[0],
pads = Array.from(inputs[1].data),
constantValueTensor = inputs[2];
const [input, shapeTensor, constantValueTensor] = inputs;
const { outputShape, pads } = this.calcShape(input, shapeTensor);
let constantValue = 0;
if (constantValueTensor) {
constantValue = constantValueTensor.data[0];
}
const outputShape: number[] = [];
for (let i = 0; i < input.ndim; i++) {
outputShape.push(input.dims[i] + pads[i] + pads[i + input.ndim]);
}

// edge:
// [0,1,2,3] -> pad (3,3) -> [0,0,0,*0,1,2,3*,3,3,3]
Expand Down Expand Up @@ -1271,7 +1257,7 @@ export function getOpEntries(): OperatorEntry[] {
opType: "Pad",
backend: "cpu",
opsetMin: 11,
factory: () => new Pad11(),
factory: () => new CPUPad11(),
},
];
}
2 changes: 2 additions & 0 deletions src/descriptor_runner/operators/webgl/opEntriesStandard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { getOpEntries as getOpEntriesoperatorsstandardgemm } from "./operators/s
import { getOpEntries as getOpEntriesoperatorsstandardglobalaveragepool } from "./operators/standard/globalaveragepool";
import { getOpEntries as getOpEntriesoperatorsstandardmatmul } from "./operators/standard/matmul";
import { getOpEntries as getOpEntriesoperatorsstandardmaxpool } from "./operators/standard/maxpool";
import { getOpEntries as getOpEntriesoperatorsstandardpad11 } from "./operators/standard/pad11";
import { getOpEntries as getOpEntriesoperatorsstandardreduce } from "./operators/standard/reduce";
import { getOpEntries as getOpEntriesoperatorsstandardreshape5 } from "./operators/standard/reshape5";
import { getOpEntries as getOpEntriesoperatorsstandardsoftmax } from "./operators/standard/softmax";
Expand All @@ -34,6 +35,7 @@ export function getOpEntries(): OperatorEntry[] {
entries.push(...getOpEntriesoperatorsstandardglobalaveragepool());
entries.push(...getOpEntriesoperatorsstandardmatmul());
entries.push(...getOpEntriesoperatorsstandardmaxpool());
entries.push(...getOpEntriesoperatorsstandardpad11());
entries.push(...getOpEntriesoperatorsstandardreduce());
entries.push(...getOpEntriesoperatorsstandardreshape5());
entries.push(...getOpEntriesoperatorsstandardsoftmax());
Expand Down
150 changes: 150 additions & 0 deletions src/descriptor_runner/operators/webgl/operators/standard/pad11.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import { Tensor } from "../../../../interface/core/tensor";
import { OperatorEntry } from "../../../../interface/core/operator";
import { Pad11 } from "../../../base/pad11";
import {
WebDNNWebGLContext,
WebGLUniformItem,
} from "../../../../interface/backend/webgl/webglContext";
import {
shaderGenHeader,
shaderGenOutput,
shaderGenTensorNDGet,
shaderGenTensorNDGetUniformItem,
shaderGenTensorOutputCoordsWithReturn,
shaderGenTensorOutputUniform,
shaderGenTensorOutputUniformItem,
} from "../../shaderHelper";
import { arange } from "../../../../util";

/**
 * WebGL implementation of the ONNX Pad operator for opset 11.
 * The opset 2 form of Pad is not compatible with this implementation.
 *
 * A fragment shader is generated for each (rank, mode) pair. For every
 * output element it subtracts the leading pad from the output coordinate,
 * then either substitutes the constant (mode "constant") or remaps the
 * coordinate into the input range (modes "edge" and "reflect").
 */
class WebGLPad11 extends Pad11 {
  constructor() {
    super("webgl");
  }

  /**
   * Pads `inputs[0]` on the GPU.
   *
   * @param context WebGL context used to allocate the output and run the kernel.
   * @param inputs `[data, pads, constantValue?]`; `data` must be a WebGL
   *   tensor, the other two CPU tensors.
   * @returns A single float32 output tensor with the padded shape.
   * @throws Error when the mode is not supported, or when "reflect" is asked
   *   to pad an axis whose input size is 1 or less.
   */
  async run(context: WebDNNWebGLContext, inputs: Tensor[]): Promise<Tensor[]> {
    const [input, shapeTensor, constantValueTensor] = inputs;
    context.assertsWebGLTensor(input);
    context.cpuContext.assertsCPUTensor(shapeTensor);
    const { outputShape: outShape, pads } = this.calcShape(input, shapeTensor);
    let constantValue = 0;
    if (constantValueTensor) {
      context.cpuContext.assertsCPUTensor(constantValueTensor);
      constantValue = constantValueTensor.data[0];
    }
    const ndim = outShape.length;

    if (this.mode === "reflect") {
      // The shader reduces indices modulo (inShape * 2 - 2); for a size-1
      // axis that is a division by zero, so reject it before reaching the GPU.
      for (let dim = 0; dim < ndim; dim++) {
        const padded = pads[dim] > 0 || pads[dim + ndim] > 0;
        if (padded && input.dims[dim] <= 1) {
          throw new Error(
            `Pad: reflect mode cannot pad axis ${dim} of size ${input.dims[dim]}`
          );
        }
      }
    }

    // Generated before the output is allocated so that an unsupported mode
    // fails without touching GPU resources.
    const { indexAdjuster, valueGetter } = this.generateSamplingCode(ndim);

    const output = context.emptyTensor(outShape, "float32");
    const kernelName = `pad_${ndim}_${this.mode}`;
    const dims = arange(ndim);
    const padUniforms = dims.map((dim) => `uniform int pad${dim};`).join("");
    const inShapeUniforms = dims
      .map((dim) => `uniform int inShape${dim};`)
      .join("");
    const constantUniform =
      this.mode === "constant" ? "uniform float padConstant;" : "";
    // ti<dim>: coordinate in the input that corresponds to the output coordinate.
    const minusPad = dims
      .map((dim) => `int ti${dim} = tex_output_${dim} - pad${dim};`)
      .join("");
    const kernelSource = `${shaderGenHeader(context.webgl2)}
int pad_mod(int x, int y) {
int z = x / y;
return x - z * y;
}
${padUniforms}
${constantUniform}
${inShapeUniforms}
${shaderGenTensorOutputUniform(ndim)}
${shaderGenTensorNDGet("tex_input", input.ndim, context.webgl2)}
void main() {
${shaderGenTensorOutputCoordsWithReturn(ndim)}
${minusPad}
${indexAdjuster}
float s;
${valueGetter}
${shaderGenOutput("s", context.webgl2)}
return;
}
`;
    context.addKernel(kernelName, kernelSource);

    const uniforms: WebGLUniformItem[] = [
      ...shaderGenTensorNDGetUniformItem(
        "tex_input",
        input.strides,
        input,
        context.webgl2
      ),
      ...shaderGenTensorOutputUniformItem(outShape, output, context.webgl2),
    ];
    for (let dim = 0; dim < ndim; dim++) {
      uniforms.push({ name: `pad${dim}`, value: pads[dim], type: "int" });
      uniforms.push({
        name: `inShape${dim}`,
        value: input.dims[dim],
        type: "int",
      });
    }
    if (this.mode === "constant") {
      uniforms.push({
        name: "padConstant",
        value: constantValue,
        type: "float",
      });
    }
    await context.runKernel(
      kernelName,
      [{ tensor: input, name: "tex_input" }],
      output,
      uniforms
    );
    return [output];
  }

  /**
   * Builds the mode-specific GLSL fragments.
   *
   * @param ndim Rank of the tensors.
   * @returns `indexAdjuster`, which remaps the `ti<dim>` coordinates into the
   *   input range (empty for "constant"), and `valueGetter`, which assigns
   *   the sampled value to `s`.
   * @throws Error when the mode is not one of the supported modes.
   */
  private generateSamplingCode(ndim: number): {
    indexAdjuster: string;
    valueGetter: string;
  } {
    const dims = arange(ndim);
    const texInputIdxs = dims.map((dim) => `ti${dim}`).join(",");
    const sampleInput = `s = get_tex_input(${texInputIdxs});`;
    switch (this.mode) {
      case "constant": {
        // Coordinates outside the input take the constant instead of a sample.
        const outOfBoundCond = dims
          .map((dim) => `ti${dim} < 0 || ti${dim} >= inShape${dim}`)
          .join("||");
        return {
          indexAdjuster: "",
          valueGetter: `if (${outOfBoundCond}) {s = padConstant;} else {s = get_tex_input(${texInputIdxs});}`,
        };
      }
      case "edge":
        // Clamp each coordinate to [0, inShape - 1].
        return {
          indexAdjuster: dims
            .map(
              (dim) =>
                `if (ti${dim} < 0) {ti${dim} = 0;} else if (ti${dim} >= inShape${dim}) {ti${dim} = inShape${dim} - 1;}`
            )
            .join(""),
          valueGetter: sampleInput,
        };
      case "reflect":
        // Fold the coordinate with period (inShape * 2 - 2), then mirror the
        // upper half of the period back into the input range.
        return {
          indexAdjuster: dims
            .map(
              (dim) =>
                `if (ti${dim} < 0) {ti${dim} = pad_mod(-ti${dim}, inShape${dim} * 2 - 2); if (ti${dim} >= inShape${dim}) {ti${dim} = inShape${dim} * 2 - ti${dim} - 2;}} else if (ti${dim} >= inShape${dim}) {ti${dim} = pad_mod(ti${dim}, inShape${dim} * 2 - 2); if (ti${dim} >= inShape${dim}) {ti${dim} = inShape${dim} * 2 - ti${dim} - 2;}}`
            )
            .join(""),
          valueGetter: sampleInput,
        };
      default: {
        // Previously an unknown mode left both fragments unassigned and the
        // text "undefined" ended up in the shader source.
        const unsupported: never = this.mode;
        throw new Error(`Pad: unsupported mode "${String(unsupported)}"`);
      }
    }
  }
}

/**
 * Lists the operator entries provided by this module: the WebGL "Pad"
 * operator, available from opset 11.
 *
 * @returns A one-element array whose factory creates a new `WebGLPad11`.
 */
export function getOpEntries(): OperatorEntry[] {
  const padEntry: OperatorEntry = {
    opType: "Pad",
    backend: "webgl",
    opsetMin: 11,
    factory: () => new WebGLPad11(),
  };
  return [padEntry];
}
2 changes: 2 additions & 0 deletions test/model_test/make_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,10 @@ def __init__(self, pad, mode="constant", constant=0.0):
self.constant = constant

def forward(self, x):
x = F.relu(x) # Pad is specified to run without moving to the GPU when its input tensor is on the CPU, so first move the tensor to the GPU
return F.pad(x, pad=self.pad, mode=self.mode, value=self.constant)


# Module that takes the maximum over the last axis, keeping the reduced axis (keepdim=True).
class ReduceMax(nn.Module):
def forward(self, x):
# torch.max with a dim argument returns a (values, indices) pair; [0] selects the values.
return torch.max(x, -1, keepdim=True)[0]
Expand Down

0 comments on commit 4dc051e

Please sign in to comment.