[CPU] Enabled MatMul+Transpose transformations and reduced MatMul inference overheads (openvinotoolkit#6570)
a-sidorova authored Jul 28, 2021
1 parent 1471095 commit 1aa58b4
Showing 4 changed files with 92 additions and 43 deletions.
2 changes: 2 additions & 0 deletions inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
@@ -58,6 +58,7 @@
 #include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
 #include <transformations/op_conversions/convert_nms_to_nms_ie_internal.hpp>
 #include <transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp>
+#include <transformations/smart_reshape/matmul_sr.hpp>
 #include <transformations/convert_precision.hpp>
 #include <transformations/init_node_info.hpp>
 #include <transformations/rt_info/fused_names_attribute.hpp>
@@ -167,6 +168,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
     manager.register_pass<ngraph::pass::ConvertNMS3ToNMS5>();
     manager.register_pass<ngraph::pass::ConvertNMS4ToNMS5>();
     manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
+    manager.register_pass<ngraph::pass::TransposeMatMul>();
     manager.register_pass<ngraph::pass::ConstantFolding>();

     if (useLpt) {
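Note: ngraph::pass::TransposeMatMul (declared in transformations/smart_reshape/matmul_sr.hpp) folds a Transpose feeding a MatMul input into the MatMul's transpose_a/transpose_b attribute, so the CPU plugin can delegate the transposition to the GEMM kernel instead of executing a separate Transpose node. A minimal sketch of the pattern being rewritten, using the {1, 2, 3} x {1, 10, 3} shapes from the new test case below (the helper name make_example is illustrative, not part of the commit):

#include <ngraph/ngraph.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/smart_reshape/matmul_sr.hpp>

std::shared_ptr<ngraph::Function> make_example() {
    using namespace ngraph;
    auto a = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 2, 3});
    auto b = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 10, 3});
    auto order = opset1::Constant::create(element::i64, Shape{3}, {0, 2, 1});
    auto b_t = std::make_shared<opset1::Transpose>(b, order);  // {1, 3, 10}
    auto mm = std::make_shared<opset1::MatMul>(a, b_t);        // {1, 2, 10}
    return std::make_shared<Function>(NodeVector{mm}, ParameterVector{a, b});
}

// After the pass runs, the Transpose is gone and the MatMul carries
// transpose_b = true instead:
//   ngraph::pass::Manager manager;
//   manager.register_pass<ngraph::pass::TransposeMatMul>();
//   manager.run_passes(make_example());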
96 changes: 57 additions & 39 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.cpp
@@ -55,8 +55,8 @@ MKLDNNMatMulNode::MKLDNNMatMulNode(const std::shared_ptr<ngraph::Node>& op, cons
         errorPrefix = "Gemm node with name '" + getName() + "'";

         const auto matMul = std::dynamic_pointer_cast<const ngraph::opset1::MatMul>(op);
-        alpha = 1;
-        beta = 1;
+        alpha = 1.f;
+        beta = 0.f;
         transposeA = matMul->get_transpose_a();
         transposeB = matMul->get_transpose_b();
     } else {
@@ -179,6 +179,34 @@ void MKLDNNMatMulNode::createPrimitive() {
         IE_THROW() << errorPrefix << " did not allocate input memory";
     if (getSelectedPrimitiveDescriptor() == nullptr)
         IE_THROW() << errorPrefix << " did not set preferable primitive descriptor";
+
+    auto inDims0 = src0MemPtr->GetDims();
+    auto outDims = dstMemPtr->GetDims();
+
+    params.src0_mem_ptr = src0MemPtr;
+    params.src1_mem_ptr = src1MemPtr;
+    params.dst_mem_ptr = dstMemPtr;
+
+    params.ndims = outDims.size();
+
+    params.MB1 = 1;
+    params.MB2 = outDims.size() > 3 ? outDims[params.ndims - 3] : 1;
+
+    params.M = outDims[yAxis];
+    params.N = outDims[xAxis];
+    params.K = transposeA ? inDims0[yAxis] : inDims0[xAxis];
+
+    params.transa = transposeA ? 'T' : 'N';
+    params.transb = transposeB ? 'T' : 'N';
+
+    params.lda = transposeA ? params.M : params.K;
+    params.ldb = transposeB ? params.K : params.N;
+    params.ldc = params.N;
+
+    params.shift1 = params.M * params.N * params.MB2;
+    params.shift2 = params.M * params.N;
+
+    runtimePrecision = getParentEdgeAt(0)->getDesc().getPrecision();
 }

 inline void process_gemm(char transa, char transb, int M, int N, int K, float alpha, const float *A, int lda,
@@ -212,67 +240,57 @@ inline void process_gemm(char transa, char transb, int M, int N, int K, float al
 }

 template<typename T0, typename T1>
-void MKLDNNMatMulNode::process_data() {
-    auto inDims0 = getParentEdgeAt(0)->getDims();
-    auto inDims1 = getParentEdgeAt(1)->getDims();
-    auto outDims = getChildEdgeAt(0)->getDims();
-
-    auto& srcMemory0 = getParentEdgeAt(0)->getMemory();
-    auto& srcMemory1 = getParentEdgeAt(1)->getMemory();
-    auto& dstMemory0 = getChildEdgeAt(0)->getMemory();
-
-    const T0 *src0_ptr = reinterpret_cast<const T0*>(srcMemory0.GetPtr());
-    const T1 *src1_ptr = reinterpret_cast<const T1*>(srcMemory1.GetData());
-    float *dst_ptr = reinterpret_cast<float*>(dstMemory0.GetData());
-
-    int MB1 = outDims.ndims() == 4 ? batchToProcess() : 1;
-    int MB2 = outDims.ndims() == 3 ? batchToProcess() : outDims.ndims() > 3 ? outDims[outDims.ndims() - 3] : 1;
-    int M = outDims[yAxis];
-    int N = outDims[xAxis];
-    int K = transposeA ? inDims0[yAxis] : inDims0[xAxis];
-
-    const char transa = transposeA ? 'T' : 'N';
-    const char transb = transposeB ? 'T' : 'N';
-
-    int lda = transposeA ? M : K;
-    int ldb = transposeB ? K : N;
-    int ldc = N;
-
-    beta = 0.f;
+inline void MKLDNNMatMulNode::process_data() {
+    const T0* src0_ptr = reinterpret_cast<const T0*>(params.src0_mem_ptr->GetPtr());
+    const T1* src1_ptr = reinterpret_cast<const T1*>(params.src1_mem_ptr->GetPtr());
+    float* dst_ptr = reinterpret_cast<float*>(params.dst_mem_ptr->GetPtr());
+
+    const int MB = batchToProcess();
+    if (params.ndims == 4) {
+        params.MB1 = MB;
+    } else if (params.ndims == 3) {
+        params.shift1 = params.shift1 * MB / params.MB2;
+        params.MB2 = MB;
+    }

-    for (int b1 = 0; b1 < MB1; b1++) {
+    for (int b1 = 0; b1 < params.MB1; ++b1) {
         const T0 *a_ptr = src0_ptr;
         const T1 *b_ptr = src1_ptr;
         float *d_ptr = dst_ptr;

-        for (int b2 = 0; b2 < MB2; b2++) {
-            process_gemm(transa, transb, M, N, K, alpha, a_ptr, lda, b_ptr, ldb, beta, d_ptr, ldc);
+        for (int b2 = 0; b2 < params.MB2; ++b2) {
+            process_gemm(params.transa, params.transb, params.M, params.N, params.K,
+                         alpha, a_ptr, params.lda, b_ptr, params.ldb, beta, d_ptr, params.ldc);

             a_ptr += aOffsets[0];
             b_ptr += bOffsets[0];
-            d_ptr += M * N;
+            d_ptr += params.shift2;
         }

         src0_ptr += aOffsets[1];
         src1_ptr += bOffsets[1];
-        dst_ptr += MB2 * M * N;
+        dst_ptr += params.shift1;
     }
 }

 void MKLDNNMatMulNode::execute(mkldnn::stream strm) {
-    switch (getParentEdgeAt(0)->getDesc().getPrecision()) {
-        case Precision::FP32:
+    switch (runtimePrecision) {
+        case Precision::FP32: {
             process_data<float, float>();
             break;
-        case Precision::BF16:
+        }
+        case Precision::BF16: {
             process_data<uint16_t, uint16_t>();
             break;
-        case Precision::I8:
+        }
+        case Precision::I8: {
             process_data<int8_t, int8_t>();
             break;
-        case Precision::U8:
+        }
+        case Precision::U8: {
             process_data<uint8_t, int8_t>();
             break;
+        }
         default:
             IE_THROW() << errorPrefix << " has incorrect precision on first input";
     }
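Note: the refactoring above targets per-inference overhead. Previously every execute() call re-derived M, N, K, the transpose characters, leading dimensions, and batch counts from edge dimensions; now createPrimitive() computes them once into the cached params struct, and execute() only resolves the dynamic batch and dispatches. The leading dimensions follow the row-major convention, where a stored matrix's stride is its number of columns, hence lda = transposeA ? M : K. A reference loop showing the semantics of one process_gemm call under these conventions (an illustrative sketch only; the plugin dispatches to optimized GEMM kernels):

// C = alpha * op(A) * op(B) + beta * C on row-major buffers.
void gemm_ref(char transa, char transb, int M, int N, int K, float alpha,
              const float* A, int lda, const float* B, int ldb,
              float beta, float* C, int ldc) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k) {
                const float a = (transa == 'T') ? A[k * lda + m] : A[m * lda + k];
                const float b = (transb == 'T') ? B[n * ldb + k] : B[k * ldb + n];
                acc += a * b;
            }
            // BLAS convention: when beta == 0, C is write-only.
            C[m * ldc + n] = (beta == 0.f) ? alpha * acc
                                           : alpha * acc + beta * C[m * ldc + n];
        }
    }
}

The constructor change (beta initialized to 0.f rather than 1.f) also removes the old per-call "beta = 0.f;" reset in process_data(), since this node always overwrites its output.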
33 changes: 30 additions & 3 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.h
@@ -28,8 +28,8 @@ class MKLDNNMatMulNode : public MKLDNNNode {
     static bool isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept;

 private:
-    float alpha = 1.0f;
-    float beta = 1.0f;
+    float alpha = 1.f;
+    float beta = 0.f;
     bool transposeA = false;
     bool transposeB = false;

@@ -40,9 +40,36 @@
     std::vector<int> bOffsets;
     std::vector<int> cOffsets;

-    template<typename T0, typename T1> void process_data();
+    InferenceEngine::Precision runtimePrecision;
+
+    template<typename T0, typename T1> inline void process_data();

     std::string errorPrefix;
+
+    struct {
+        MKLDNNMemoryPtr src0_mem_ptr = nullptr;
+        MKLDNNMemoryPtr src1_mem_ptr = nullptr;
+        MKLDNNMemoryPtr dst_mem_ptr = nullptr;
+
+        char transa = 'N';
+        char transb = 'N';
+
+        int MB1 = 1;
+        int MB2 = 1;
+
+        int M = 0;
+        int N = 0;
+        int K = 0;
+
+        int lda = 0;
+        int ldb = 0;
+        int ldc = 0;
+
+        int shift1 = 0;
+        int shift2 = 0;
+
+        size_t ndims = 0;
+    } params;
 };

 } // namespace MKLDNNPlugin
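Note: shift1 and shift2 are the cached output-pointer strides for the two batch levels walked in process_data(): shift2 = M * N steps to the next matrix inside the inner batch, and shift1 = MB2 * M * N steps to the next outer-batch slice. A small worked example under an assumed output shape (the values are hypothetical, not from the commit):

// Hypothetical output dims {MB1, MB2, M, N} = {2, 4, 5, 6}:
//   params.shift2 = M * N       = 30;   // advance d_ptr between GEMM calls
//   params.shift1 = MB2 * M * N = 120;  // advance dst_ptr between outer batches
// For the 3D dynamic-batch case, execute() rescales before use:
//   params.shift1 = params.shift1 * MB / params.MB2;  // keeps shift1 == MB * M * N
//   params.MB2 = MB;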
@@ -18,6 +18,8 @@ const std::vector<ShapeRelatedParams> shapeRelatedParams = {
     { { {1, 4, 5, 6}, false }, { {1, 4, 6, 4}, false } },
     { { {4, 5, 6}, false }, { {6, 3}, false } },
     { { {9, 9, 9}, false }, { {9, 9}, false } },
+    { { {1, 2, 3}, false }, { {1, 10, 3}, true } },
+    { { {1, 2, 3}, false }, { {1, 3, 10}, false } },
     { { {1, 2, 3}, false }, { {1, 1, 3, 2}, false } },
     { { {1, 3, 2, 4}, false }, { {2, 1, 4, 2}, false } },
     { { {2, 1, 2, 4}, false }, { {1, 3, 4, 2}, false } },
@@ -30,7 +32,7 @@ const std::vector<ShapeRelatedParams> shapeRelatedParams = {
     { { {2, 2, 1, 3}, false }, { {3}, false } },
     { { {1, 5}, false }, { {5, 1}, false } },
     { { {5, 1}, true }, { {5, 1}, false } },
-    { { {1, 5}, false }, { {1, 5}, true } },
+    { { {1, 5}, false }, { {10, 5}, true } },
     { { {1, 5}, false }, { {5}, false } },
     { { {5}, false }, { {5, 1}, false } },
     { { {5}, false }, { {5}, false } },
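Note: the added shapes exercise the path enabled by TransposeMatMul: { { {1, 2, 3}, false }, { {1, 10, 3}, true } } multiplies a 1x2x3 tensor by the transpose of a 1x10x3 tensor and must match the plain { { {1, 2, 3}, false }, { {1, 3, 10}, false } } case, both producing a 1x2x10 output. A hypothetical shape-inference helper for these 3D cases (equal batch dims, so no broadcasting is involved; not part of the test suite):

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

std::vector<size_t> matmul_shape(std::vector<size_t> a, std::vector<size_t> b,
                                 bool transpose_a, bool transpose_b) {
    if (transpose_a) std::swap(a[1], a[2]);  // a is {B, M, K} after the swap
    if (transpose_b) std::swap(b[1], b[2]);  // b is {B, K, N} after the swap
    assert(a[0] == b[0] && a[2] == b[1]);    // batch and inner dims must agree
    return {a[0], a[1], b[2]};               // {B, M, N}
}

// matmul_shape({1, 2, 3}, {1, 10, 3}, false, true)  -> {1, 2, 10}
// matmul_shape({1, 2, 3}, {1, 3, 10}, false, false) -> {1, 2, 10}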
