Skip to content

Commit

Permalink
Merge pull request #954 from mil-tokyo/instancenorm
Browse files Browse the repository at this point in the history
implement InstanceNormalization
  • Loading branch information
milhidaka authored Feb 16, 2022
2 parents a1f331c + 6e2fdc8 commit 64666e5
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/descriptor_runner/operators/cpu/opEntriesStandard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { getOpEntries as getOpEntriesoperatorsstandardflatten } from "./operator
import { getOpEntries as getOpEntriesoperatorsstandardgather } from "./operators/standard/gather";
import { getOpEntries as getOpEntriesoperatorsstandardgemm } from "./operators/standard/gemm";
import { getOpEntries as getOpEntriesoperatorsstandardglobalaveragepool } from "./operators/standard/globalaveragepool";
import { getOpEntries as getOpEntriesoperatorsstandardinstancenormalization } from "./operators/standard/instancenormalization";
import { getOpEntries as getOpEntriesoperatorsstandardmatmul } from "./operators/standard/matmul";
import { getOpEntries as getOpEntriesoperatorsstandardmaxpool } from "./operators/standard/maxpool";
import { getOpEntries as getOpEntriesoperatorsstandardpad11 } from "./operators/standard/pad11";
Expand Down Expand Up @@ -48,6 +49,7 @@ export function getOpEntries(): OperatorEntry[] {
entries.push(...getOpEntriesoperatorsstandardgather());
entries.push(...getOpEntriesoperatorsstandardgemm());
entries.push(...getOpEntriesoperatorsstandardglobalaveragepool());
entries.push(...getOpEntriesoperatorsstandardinstancenormalization());
entries.push(...getOpEntriesoperatorsstandardmatmul());
entries.push(...getOpEntriesoperatorsstandardmaxpool());
entries.push(...getOpEntriesoperatorsstandardpad11());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { onnx } from "onnx-proto";
import { OperatorImpl } from "../../../operatorImpl";
import { getAttrInt } from "../../../operatorUtil";
import { WebDNNCPUContext } from "../../../../interface/backend/cpu/cpuContext";
import { Tensor } from "../../../../interface/core/tensor";
import { OperatorEntry } from "../../../../interface/core/operator";
import { arrayProd } from "../../../../util";

/**
 * CPU implementation of the ONNX InstanceNormalization operator.
 *
 * For every (batch, channel) pair, all elements after the first two
 * dimensions are normalized with the mean and variance of that slice, then
 * multiplied by the per-channel `scale` and shifted by the per-channel `bias`.
 */
class InstanceNormalization extends OperatorImpl {
  /** Value added to the variance before taking the square root. */
  epsilon!: number;

  constructor() {
    super("cpu");
  }

  /**
   * Reads the operator attributes.
   *
   * @param attribute - Attributes of the ONNX node. `epsilon` falls back to
   *   1e-5 when it is absent.
   */
  initialize(attribute: onnx.IAttributeProto[]): void {
    super.initialize(attribute);
    // ONNX declares epsilon as a FLOAT attribute, so its value lives in the
    // `f` field; an integer accessor cannot represent values such as 1e-5.
    const epsilonAttr = attribute.find((attr) => attr.name === "epsilon");
    this.epsilon = epsilonAttr && epsilonAttr.f != null ? epsilonAttr.f : 1e-5;
  }

  /**
   * Normalizes `input` per (batch, channel) slice.
   *
   * @param context - CPU backend context used to allocate the output tensor.
   * @param inputs - `[input, scale, bias]`; `scale` and `bias` are indexed by
   *   channel.
   * @returns A single tensor with the same dims and data type as `input`.
   */
  async run(context: WebDNNCPUContext, inputs: Tensor[]): Promise<Tensor[]> {
    context.assertsCPUTensorArray(inputs);
    const [input, scale, bias] = inputs;
    // Number of elements normalized together (everything after batch, channel).
    const reductionLength = arrayProd(input.dims.slice(2)),
      output = context.emptyTensor(input.dims, input.dataType),
      dI = input.data,
      dO = output.data,
      dS = scale.data,
      dB = bias.data;
    const [dimBatch, dimCh] = input.dims;
    const [strideBatch, strideCh] = input.strides;
    for (let batch = 0; batch < dimBatch; batch++) {
      for (let ch = 0; ch < dimCh; ch++) {
        const ofs = batch * strideBatch + ch * strideCh;
        // Single pass over the slice accumulating sum and sum of squares.
        let sum = 0.0;
        let sqsum = 0.0;
        for (let r = 0; r < reductionLength; r++) {
          const v = dI[ofs + r];
          sum += v;
          sqsum += v * v;
        }
        const mean = sum / reductionLength;
        // E[x^2] - E[x]^2 can become slightly negative through rounding;
        // clamp it so the square root below never yields NaN.
        const variance = Math.max(sqsum / reductionLength - mean * mean, 0);
        const invstd = 1 / Math.sqrt(variance + this.epsilon);
        // Fold normalization and affine transform into one multiply-add:
        // (v - mean) * invstd * scale + bias == v * chscale + chbias.
        const chscale = dS[ch] * invstd;
        const chbias = -mean * chscale + dB[ch];
        for (let r = 0; r < reductionLength; r++) {
          dO[ofs + r] = dI[ofs + r] * chscale + chbias;
        }
      }
    }
    return [output];
  }
}

/**
 * Lists the operator entries provided by this module: a single
 * "InstanceNormalization" entry for the "cpu" backend, starting at opset 1.
 */
export function getOpEntries(): OperatorEntry[] {
  const instanceNormalizationEntry: OperatorEntry = {
    opType: "InstanceNormalization",
    backend: "cpu",
    opsetMin: 1,
    factory: () => new InstanceNormalization(),
  };
  return [instanceNormalizationEntry];
}
2 changes: 2 additions & 0 deletions src/descriptor_runner/operators/webgl/opEntriesStandard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { getOpEntries as getOpEntriesoperatorsstandardconvtranspose } from "./op
import { getOpEntries as getOpEntriesoperatorsstandardflatten } from "./operators/standard/flatten";
import { getOpEntries as getOpEntriesoperatorsstandardgemm } from "./operators/standard/gemm";
import { getOpEntries as getOpEntriesoperatorsstandardglobalaveragepool } from "./operators/standard/globalaveragepool";
import { getOpEntries as getOpEntriesoperatorsstandardinstancenormalization } from "./operators/standard/instancenormalization";
import { getOpEntries as getOpEntriesoperatorsstandardmatmul } from "./operators/standard/matmul";
import { getOpEntries as getOpEntriesoperatorsstandardmaxpool } from "./operators/standard/maxpool";
import { getOpEntries as getOpEntriesoperatorsstandardpad11 } from "./operators/standard/pad11";
Expand All @@ -33,6 +34,7 @@ export function getOpEntries(): OperatorEntry[] {
entries.push(...getOpEntriesoperatorsstandardflatten());
entries.push(...getOpEntriesoperatorsstandardgemm());
entries.push(...getOpEntriesoperatorsstandardglobalaveragepool());
entries.push(...getOpEntriesoperatorsstandardinstancenormalization());
entries.push(...getOpEntriesoperatorsstandardmatmul());
entries.push(...getOpEntriesoperatorsstandardmaxpool());
entries.push(...getOpEntriesoperatorsstandardpad11());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import { onnx } from "onnx-proto";
import { OperatorImpl } from "../../../operatorImpl";
import { arrayProd, getAttrInt } from "../../../operatorUtil";
import {
shaderGenHeader,
shaderGenOutput,
shaderGenOutputVec4,
shaderGenTensorNDGet,
shaderGenTensorNDGetUniformItem,
shaderGenTensorNDGetVec4,
shaderGenTensorOutputCoordsWithReturn,
shaderGenTensorOutputUniform,
shaderGenTensorOutputUniformItem,
} from "../../shaderHelper";
import {
WebDNNWebGLContext,
WebGLUniformItem,
} from "../../../../interface/backend/webgl/webglContext";
import { Tensor } from "../../../../interface/core/tensor";
import { WebGLTensor } from "../../../../interface/backend/webgl/webglTensor";
import { OperatorEntry } from "../../../../interface/core/operator";

// Opset 1
/**
 * WebGL implementation of the ONNX InstanceNormalization operator.
 *
 * Runs two kernels: the first computes a per-(batch, channel) scale and bias
 * that already include the normalization statistics, the second applies them
 * to every element of the input.
 */
export class InstanceNormalization extends OperatorImpl {
  /** Value added to the variance before taking the inverse square root. */
  epsilon!: number;

  constructor() {
    super("webgl");
  }

  /**
   * Reads the operator attributes.
   *
   * @param attribute - Attributes of the ONNX node. `epsilon` falls back to
   *   1e-5 when it is absent.
   */
  initialize(attribute: onnx.IAttributeProto[]): void {
    super.initialize(attribute);
    // ONNX declares epsilon as a FLOAT attribute, so its value lives in the
    // `f` field; an integer accessor cannot represent values such as 1e-5.
    const epsilonAttr = attribute.find((attr) => attr.name === "epsilon");
    this.epsilon = epsilonAttr && epsilonAttr.f != null ? epsilonAttr.f : 1e-5;
  }

  /**
   * Normalizes `input` per (batch, channel) slice.
   *
   * @param context - WebGL backend context; must be a WebGL2 context.
   * @param inputs - `[input, scale, bias]`.
   * @returns A single tensor with the same dims and data type as `input`.
   * @throws Error when the context is not WebGL2.
   */
  async run(context: WebDNNWebGLContext, inputs: Tensor[]): Promise<Tensor[]> {
    context.assertsWebGLTensorArray(inputs);
    const [input, scale, bias] = inputs;
    if (!context.webgl2) {
      // Not supported because it is difficult to output the two statistics
      // (mean, std) from a single kernel on WebGL1.
      throw new Error("InstanceNormalization: WebGL1 is not supported");
    }

    // Number of elements normalized together (everything after batch, channel).
    const reductionLength = arrayProd(input.dims.slice(2));
    const [dimBatch, dimCh] = input.dims;

    // Compute statistics: one vec4 (4 floats) per (batch, channel) pair.
    const statsTensor = context.emptyTensor(
      [dimBatch * dimCh * 4],
      "float32",
      { dimPerPixel: 4 }
    );
    try {
      await this.calcStat(
        context,
        dimBatch,
        dimCh,
        reductionLength,
        this.epsilon,
        input,
        scale,
        bias,
        statsTensor
      );

      // Compute the result.
      const output = context.emptyTensor(input.dims, input.dataType);
      await this.calcOutput2(
        context,
        dimBatch,
        dimCh,
        reductionLength,
        input,
        statsTensor,
        output
      );
      return [output];
    } finally {
      // Release the intermediate tensor even when a kernel throws.
      statsTensor.dispose();
    }
  }

  /**
   * Runs the statistics kernel.
   *
   * For each (batch, channel) pair the kernel writes a vec4 whose first
   * component is `scale * invstd` and whose second component is
   * `bias - mean * scale * invstd`; the remaining two components are zero.
   */
  private async calcStat(
    context: WebDNNWebGLContext,
    batchLength: number,
    chLength: number,
    reductionLength: number,
    epsilon: number,
    input: WebGLTensor,
    scale: WebGLTensor,
    bias: WebGLTensor,
    statsTensor: WebGLTensor
  ) {
    // reductionLength is baked into the shader as a #define, so the kernel
    // name includes it.
    const kernelName = `instancenormalization_stats_${reductionLength}`,
      kernelSource = `${shaderGenHeader(context.webgl2)}
#define reductionLength ${reductionLength}
uniform float epsilon;
${shaderGenTensorOutputUniform(2)}
${shaderGenTensorNDGet("tex_input", 3, context.webgl2)}
${shaderGenTensorNDGet("tex_scale", 1, context.webgl2)}
${shaderGenTensorNDGet("tex_bias", 1, context.webgl2)}
void main() {
  ${shaderGenTensorOutputCoordsWithReturn(2)}
  float s_sum = 0.0;
  float s_sqsum = 0.0;
  for (int i = 0; i < reductionLength; i++) {
    float v = get_tex_input(tex_output_0, tex_output_1, i);
    s_sum += v;
    s_sqsum += v * v;
  }
  float s_mean = s_sum / float(reductionLength);
  // E[x^2] - E[x]^2 may round below zero; clamp before adding epsilon.
  float s_var = max(s_sqsum / float(reductionLength) - s_mean * s_mean, 0.0) + epsilon;
  float s_invstd = inversesqrt(s_var);
  float s_scale = get_tex_scale(tex_output_1) * s_invstd;
  float s_bias = -s_mean * s_scale + get_tex_bias(tex_output_1);
  vec4 s = vec4(s_scale, s_bias, 0.0, 0.0);
  ${shaderGenOutputVec4("s", context.webgl2)}
  return;
}
`;
    context.addKernel(kernelName, kernelSource);
    const uniforms: WebGLUniformItem[] = [
      // The input is addressed as (batch, channel, flattened remainder).
      ...shaderGenTensorNDGetUniformItem(
        "tex_input",
        [chLength * reductionLength, reductionLength, 1],
        input,
        context.webgl2
      ),
      ...shaderGenTensorNDGetUniformItem(
        "tex_scale",
        scale.strides,
        scale,
        context.webgl2
      ),
      ...shaderGenTensorNDGetUniformItem(
        "tex_bias",
        bias.strides,
        bias,
        context.webgl2
      ),
      ...shaderGenTensorOutputUniformItem(
        [batchLength, chLength],
        statsTensor,
        context.webgl2
      ),
      { name: "epsilon", value: epsilon, type: "float" },
    ];
    await context.runKernel(
      kernelName,
      [
        { tensor: input, name: "tex_input" },
        { tensor: scale, name: "tex_scale" },
        { tensor: bias, name: "tex_bias" },
      ],
      statsTensor,
      uniforms
    );
  }

  /**
   * Runs the output kernel: every input element is multiplied by the first
   * component of its (batch, channel) statistics vec4 and shifted by the
   * second component.
   */
  private async calcOutput2(
    context: WebDNNWebGLContext,
    batchLength: number,
    chLength: number,
    reductionLength: number,
    input: WebGLTensor,
    statsTensor: WebGLTensor,
    output: WebGLTensor
  ) {
    const kernelName = `instancenormalization_output`,
      kernelSource = `${shaderGenHeader(context.webgl2)}
${shaderGenTensorOutputUniform(3)}
${shaderGenTensorNDGet("tex_input", 3, context.webgl2)}
${shaderGenTensorNDGetVec4("tex_stats", 2, context.webgl2)}
void main() {
  ${shaderGenTensorOutputCoordsWithReturn(3)}
  vec4 m = get_vec4_tex_stats(tex_output_0, tex_output_1);
  float v = get_tex_input(tex_output_0, tex_output_1, tex_output_2);
  float s = v * m.r + m.g;
  ${shaderGenOutput("s", context.webgl2)}
  return;
}
`;
    context.addKernel(kernelName, kernelSource);
    const uniforms: WebGLUniformItem[] = [
      // The input is addressed as (batch, channel, flattened remainder).
      ...shaderGenTensorNDGetUniformItem(
        "tex_input",
        [chLength * reductionLength, reductionLength, 1],
        input,
        context.webgl2
      ),
      ...shaderGenTensorNDGetUniformItem(
        "tex_stats",
        [chLength, 1],
        statsTensor,
        context.webgl2
      ),
      ...shaderGenTensorOutputUniformItem(
        [batchLength, chLength, reductionLength],
        output,
        context.webgl2
      ),
    ];
    await context.runKernel(
      kernelName,
      [
        { tensor: input, name: "tex_input" },
        { tensor: statsTensor, name: "tex_stats" },
      ],
      output,
      uniforms
    );
  }
}

/**
 * Lists the operator entries provided by this module: a single
 * "InstanceNormalization" entry for the "webgl" backend, starting at opset 1.
 */
export function getOpEntries(): OperatorEntry[] {
  const instanceNormalizationEntry: OperatorEntry = {
    opType: "InstanceNormalization",
    backend: "webgl",
    opsetMin: 1,
    factory: () => new InstanceNormalization(),
  };
  return [instanceNormalizationEntry];
}
12 changes: 12 additions & 0 deletions test/model_test/make_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ def forward(self, x):
return torch.split(x, self.split_size_or_sections, dim=self.dim)


class InstanceNorm(nn.Module):
    """Affine ``torch.nn.InstanceNorm2d`` whose bias and weight are replaced by uniform random values."""

    def __init__(self, num_features, eps) -> None:
        super().__init__()
        layer = torch.nn.InstanceNorm2d(num_features, eps, affine=True)
        # Bias is drawn before weight, so the random numbers are consumed in
        # the same order as assigning bias first and weight second.
        for param in (layer.bias, layer.weight):
            param.data = torch.rand(*param.data.shape)
        self.instance_norm = layer

    def forward(self, x):
        """Apply the wrapped instance-normalization layer to ``x``."""
        return self.instance_norm(x)


def dump_expected(directory, arrays_dict):
casted_arrays_dict = {}
for k, array in arrays_dict.items():
Expand Down Expand Up @@ -717,6 +728,7 @@ def main():
dump("split1", Split([2, 3, 5, 7, 60-2-3-5-7], -1), [(3, 4, 5, 60)])
dump("split2", Split([2, 3, 5, 7, 40-2-3-5-7], 1), [(3, 40, 5, 6)])
dump("split3", Split(4, 0), [(30, 4, 5, 6)])
dump("instancenorm1", InstanceNorm(4, 0.01), [(3, 4, 5, 6)])
output_list()


Expand Down

0 comments on commit 64666e5

Please sign in to comment.