Skip to content

Commit

Permalink
pnnx convert nn.RMSNorm F.rms_norm (#5628)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Aug 14, 2024
1 parent c46278d commit eb6e084
Show file tree
Hide file tree
Showing 10 changed files with 331 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ set(pnnx_pass_level1_SRCS
pass_level1/nn_ReplicationPad1d.cpp
pass_level1/nn_ReplicationPad2d.cpp
pass_level1/nn_ReplicationPad3d.cpp
pass_level1/nn_RMSNorm.cpp
pass_level1/nn_RNN.cpp
pass_level1/nn_RReLU.cpp
pass_level1/nn_SELU.cpp
Expand Down Expand Up @@ -163,6 +164,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/F_prelu.cpp
pass_level2/F_relu.cpp
pass_level2/F_relu6.cpp
pass_level2/F_rms_norm.cpp
pass_level2/F_rrelu.cpp
pass_level2/F_scaled_dot_product_attention.cpp
pass_level2/F_selu.cpp
Expand Down Expand Up @@ -383,6 +385,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_static_layernorm.cpp
pass_level5/fuse_static_linear.cpp
pass_level5/fuse_static_prelu.cpp
pass_level5/fuse_static_rmsnorm.cpp
pass_level5/normalize_einsum_equation.cpp
pass_level5/unroll_rnn_op.cpp
)
Expand Down
51 changes: 51 additions & 0 deletions tools/pnnx/src/pass_level1/nn_RMSNorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"

namespace pnnx {

// Level-1 module pass that captures a torch.nn.RMSNorm module as a single
// nn.RMSNorm operator.
class RMSNorm : public FuseModulePass
{
public:
    // TorchScript qualified type name of the module this pass matches.
    const char* match_type_str() const
    {
        return "__torch__.torch.nn.modules.normalization.RMSNorm";
    }

    // Operator type name emitted for the matched module.
    const char* type_str() const
    {
        return "nn.RMSNorm";
    }

    // Fills the operator's params from the aten::rms_norm node found in the
    // module graph and stores the module weight (when present) as an attribute.
    void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
    {
        const torch::jit::Node* rmsn = find_node_by_kind(graph, "aten::rms_norm");

        op->params["normalized_shape"] = rmsn->namedInput("normalized_shape");
        op->params["eps"] = rmsn->namedInput("eps");

        // RMSNorm carries a weight only, never a bias, so the affine flag must
        // depend on the weight alone. Also requiring a "bias" attribute would
        // report elementwise_affine=False for every module while still
        // exporting the weight below.
        const bool has_weight = mod.hasattr("weight");
        op->params["elementwise_affine"] = has_weight;

        if (has_weight)
        {
            op->attrs["weight"] = mod.attr("weight").toTensor();
        }
    }
};

REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(RMSNorm)

} // namespace pnnx
43 changes: 43 additions & 0 deletions tools/pnnx/src/pass_level2/F_rms_norm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

// Level-2 rewriter that turns a captured aten::rms_norm node into F.rms_norm.
class F_rms_norm : public GraphRewriterPass
{
public:
    // Pattern: aten::rms_norm taking input, normalized_shape and weight
    // operands plus a constant eps whose value is captured as %eps.
    const char* match_pattern_graph() const
    {
        const char* const pattern = R"PNNXIR(7767517
6 5
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 weight
pnnx.Input input_2 0 1 normalized_shape
prim::Constant op_0 0 1 eps value=%eps
aten::rms_norm op_1 4 1 input normalized_shape weight eps out
pnnx.Output output 1 0 out
)PNNXIR";
        return pattern;
    }

    // Operator type name given to the matched subgraph.
    const char* type_str() const
    {
        const char* const name = "F.rms_norm";
        return name;
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_rms_norm, 10)

} // namespace pnnx
2 changes: 2 additions & 0 deletions tools/pnnx/src/pass_level5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include "pass_level5/fuse_static_layernorm.h"
#include "pass_level5/fuse_static_linear.h"
#include "pass_level5/fuse_static_prelu.h"
#include "pass_level5/fuse_static_rmsnorm.h"
#include "pass_level5/normalize_einsum_equation.h"
#include "pass_level4/dead_code_elimination.h"
#include "pass_level4/canonicalize.h"
Expand Down Expand Up @@ -102,6 +103,7 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons
fuse_static_groupnorm(g);
fuse_static_instancenorm(g);
fuse_static_layernorm(g);
fuse_static_rmsnorm(g);

fuse_static_conv(g);
fuse_static_convtranspose(g);
Expand Down
57 changes: 57 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_static_rmsnorm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_static_rmsnorm.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

// Rewriter that folds an F.rms_norm whose weight comes from a pnnx.Attribute
// into a single nn.RMSNorm operator carrying that weight.
class fuse_static_Frmsnorm_pass : public GraphRewriterPass
{
public:
    // Pattern: F.rms_norm(input, weight) where weight is produced by a
    // pnnx.Attribute; normalized_shape and eps are captured as parameters.
    const char* match_pattern_graph() const
    {
        const char* const pattern = R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Attribute op_weight 0 1 weight @data
F.rms_norm op_0 2 1 input weight out normalized_shape=%normalized_shape eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
        return pattern;
    }

    // Replacement: nn.RMSNorm reusing the captured parameters, with the
    // attribute data attached as its weight.
    const char* replace_pattern_graph() const
    {
        const char* const replacement = R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.RMSNorm rmsn 1 1 input out normalized_shape=%normalized_shape eps=%eps elementwise_affine=True @weight=%op_weight.data
pnnx.Output output 1 0 out
)PNNXIR";
        return replacement;
    }
};

// Applies the F.rms_norm + static weight -> nn.RMSNorm rewrite to the graph.
void fuse_static_rmsnorm(Graph& graph)
{
    // counter handed to the rewriter by reference, starting from zero
    int opindex = 0;
    fuse_static_Frmsnorm_pass rewriter;

    pnnx_graph_rewrite(graph, &rewriter, opindex);
}

} // namespace pnnx
21 changes: 21 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_static_rmsnorm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

// Rewrites F.rms_norm operators whose weight is supplied by a pnnx.Attribute
// into nn.RMSNorm operators that hold the weight themselves. Modifies graph in place.
void fuse_static_rmsnorm(Graph& graph);

} // namespace pnnx
4 changes: 4 additions & 0 deletions tools/pnnx/src/pass_ncnn/solve_batch_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ static bool is_known_operator_with_batch_index_0(const Operator* op)
"F.group_norm",
"F.instance_norm",
"F.interpolate",
"F.layer_norm",
"F.linear",
"F.local_response_norm",
"F.lp_pool1d",
Expand All @@ -56,6 +57,7 @@ static bool is_known_operator_with_batch_index_0(const Operator* op)
"F.pixel_shuffle",
"F.pixel_unshuffle",
"F.prelu",
"F.rms_norm",
"F.scaled_dot_product_attention",
"F.unfold",
"F.upsample_bilinear",
Expand Down Expand Up @@ -91,6 +93,7 @@ static bool is_known_operator_with_batch_index_0(const Operator* op)
"nn.InstanceNorm2d",
"nn.InstanceNorm3d",
"nn.LocalResponseNorm",
"nn.LayerNorm",
"nn.LPPool1d",
"nn.LPPool2d",
"nn.MaxPool1d",
Expand All @@ -104,6 +107,7 @@ static bool is_known_operator_with_batch_index_0(const Operator* op)
"nn.ReplicationPad1d",
"nn.ReplicationPad2d",
"nn.ReplicationPad3d",
"nn.RMSNorm",
"nn.Softmax2d",
"nn.Unfold",
"nn.Upsample",
Expand Down
2 changes: 2 additions & 0 deletions tools/pnnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pnnx_add_test(F_pixel_unshuffle)
pnnx_add_test(F_prelu)
pnnx_add_test(F_relu)
pnnx_add_test(F_relu6)
pnnx_add_test(F_rms_norm)
pnnx_add_test(F_rrelu)
pnnx_add_test(F_scaled_dot_product_attention)
pnnx_add_test(F_selu)
Expand Down Expand Up @@ -145,6 +146,7 @@ pnnx_add_test(nn_ReLU6)
pnnx_add_test(nn_ReplicationPad1d)
pnnx_add_test(nn_ReplicationPad2d)
pnnx_add_test(nn_ReplicationPad3d)
pnnx_add_test(nn_RMSNorm)
pnnx_add_test(nn_RNN)
pnnx_add_test(nn_RReLU)
pnnx_add_test(nn_SELU)
Expand Down
77 changes: 77 additions & 0 deletions tools/pnnx/tests/test_F_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

class Model(nn.Module):
    """Applies chains of F.rms_norm to three inputs, using caller-supplied
    weights, no weight, and weights owned by the module."""

    def __init__(self):
        super().__init__()

        # Creation order is preserved so the random draws happen in the same sequence.
        self.w3 = nn.Parameter(torch.rand(24))
        self.w4 = nn.Parameter(torch.rand(12, 16))
        self.w5 = nn.Parameter(torch.rand(24))

    def forward(self, x, y, z, w0, w1, w2):
        # x: external weight, then no weight, then the module's own weight
        x = F.rms_norm(x, (24,), weight=w0)
        x = F.rms_norm(x, (12, 24), weight=None)
        x = F.rms_norm(x, (24,), weight=self.w3)

        # y: explicit eps without weight, then two-dimensional normalized shapes
        y = F.rms_norm(y, (16,), weight=None, eps=1e-3)
        y = F.rms_norm(y, (12, 16), weight=w1)
        y = F.rms_norm(y, (12, 16), weight=self.w4)

        # z: includes a three-dimensional normalized shape with explicit eps
        z = F.rms_norm(z, (24,), weight=w2)
        z = F.rms_norm(z, (12, 16, 24), weight=None, eps=1e-2)
        z = F.rms_norm(z, (24,), weight=self.w5)
        return x, y, z

def test():
    """Trace Model, convert it with pnnx, and compare pnnx inference to eager output.

    Returns True when all three outputs are exactly equal. On torch older
    than 2.4 the check is skipped and reported as passing.
    """
    if version.parse(torch.__version__) < version.parse('2.4'):
        return True

    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 12, 24)
    y = torch.rand(2, 3, 12, 16)
    z = torch.rand(1, 10, 12, 16, 24)
    w0 = torch.rand(24)
    w1 = torch.rand(12, 16)
    w2 = torch.rand(24)
    inputs = (x, y, z, w0, w1, w2)

    # reference results from eager execution
    a0, a1, a2 = net(*inputs)

    # export torchscript
    traced = torch.jit.trace(net, inputs)
    traced.save("test_F_rms_norm.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_F_rms_norm.pt inputshape=[1,12,24],[2,3,12,16],[1,10,12,16,24],[24],[12,16],[24]")

    # pnnx inference, the module is generated by the conversion step above
    import test_F_rms_norm_pnnx
    b0, b1, b2 = test_F_rms_norm_pnnx.test_inference()

    pairs = ((a0, b0), (a1, b1), (a2, b2))
    return all(torch.equal(expected, actual) for expected, actual in pairs)

if __name__ == "__main__":
    # exit status 0 when the test passes, 1 otherwise
    exit(0 if test() else 1)
Loading

0 comments on commit eb6e084

Please sign in to comment.