[CPU] SoftMax cache (openvinotoolkit#9480)
* [CPUCache]SoftMax cache

* [CpuCache]fix bf16 tests

* [CPUCache]apply review comments

* [CPUCache]fix compilation
zhangYiIntel authored Jan 10, 2022
1 parent af105b8 commit c1206ef
Showing 2 changed files with 278 additions and 241 deletions.
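In short, prepareParams() no longer builds a oneDNN softmax primitive from scratch on every call: it packs the input memory descriptor, the selected implementation type and the softmax axis into a SoftmaxKey and asks the node's runtime cache (getOrCreate) for a matching primitive, passing a builder lambda that is only invoked on a cache miss. Below is a minimal, self-contained sketch of that key-plus-builder pattern; Key, KeyHasher, PrimitiveCache and the builder here are illustrative stand-ins for this sketch, not the plugin's actual cache classes.

#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <utility>

// Hypothetical key: mirrors the idea of SoftmaxKey (hash() + operator==), but over
// plain fields instead of a dnnl memory descriptor.
struct Key {
    int implType;
    std::size_t axis;

    std::size_t hash() const {
        // simple hash combine; the real code uses dnnl::impl::hash_combine / get_md_hash
        std::size_t seed = std::hash<int>()(implType);
        seed ^= std::hash<std::size_t>()(axis) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }
    bool operator==(const Key& rhs) const {
        return implType == rhs.implType && axis == rhs.axis;
    }
};

struct KeyHasher {
    std::size_t operator()(const Key& k) const { return k.hash(); }
};

// Toy stand-in for the node's runtime cache: return the cached value on a hit,
// otherwise call the builder once and remember the result.
template <typename Value>
class PrimitiveCache {
public:
    using ValuePtr = std::shared_ptr<Value>;
    using Builder = std::function<ValuePtr(const Key&)>;

    std::pair<ValuePtr, bool> getOrCreate(const Key& key, const Builder& builder) {
        auto it = cache_.find(key);
        if (it != cache_.end())
            return {it->second, true};   // hit: reuse the existing object
        ValuePtr value = builder(key);   // miss: build it (may return nullptr on failure)
        if (value)
            cache_.emplace(key, value);
        return {value, false};
    }

private:
    std::unordered_map<Key, ValuePtr, KeyHasher> cache_;
};

int main() {
    PrimitiveCache<int> cache;
    auto builder = [](const Key&) { return std::make_shared<int>(42); };  // stands in for primitive creation

    auto first = cache.getOrCreate({/*implType*/ 1, /*axis*/ 3}, builder);
    auto second = cache.getOrCreate({1, 3}, builder);                     // same key -> same object
    std::cout << std::boolalpha << (first.first == second.first) << std::endl;  // prints: true
}

In the actual change the builder returns nullptr when no suitable primitive descriptor is found, and the node throws after getOrCreate comes back empty, which keeps the error message attached to the node name.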
84 changes: 65 additions & 19 deletions src/plugins/intel_cpu/src/nodes/mkldnn_softmax_node.cpp
@@ -10,11 +10,45 @@
 #include <memory_desc/cpu_memory_desc_utils.h>
 #include <ngraph/opsets/opset1.hpp>
 #include "memory_desc/dnnl_blocked_memory_desc.h"
+#include <common/primitive_hashing_utils.hpp>
 
 using namespace mkldnn;
 using namespace MKLDNNPlugin;
 using namespace InferenceEngine;
 
+namespace {
+struct SoftmaxKey {
+    DnnlMemoryDescCPtr inp0;
+    impl_desc_type implType;
+    size_t axis;
+
+    size_t hash() const;
+    bool operator==(const SoftmaxKey& rhs) const;
+};
+
+size_t SoftmaxKey::hash() const {
+    using namespace dnnl::impl;
+    using namespace dnnl::impl::primitive_hashing;
+
+    size_t seed = 0;
+
+    seed = hash_combine(seed, get_md_hash(inp0->getDnnlDesc().data));
+    seed = hash_combine(seed, implType);
+    seed = hash_combine(seed, axis);
+    return seed;
+}
+
+bool SoftmaxKey::operator==(const SoftmaxKey& rhs) const {
+    bool retVal = true;
+    if (inp0 != rhs.inp0) {
+        retVal = retVal && inp0 && rhs.inp0 && inp0->getDnnlDesc() == rhs.inp0->getDnnlDesc();
+    }
+
+    retVal = retVal && implType == rhs.implType && axis == rhs.axis;
+    return retVal;
+}
+}  // namespace
+
 bool MKLDNNSoftMaxNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
     try {
         if (!std::dynamic_pointer_cast<const ngraph::opset1::Softmax>(op)) {
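The equality check in SoftmaxKey::operator== above first compares the descriptor pointers themselves and only falls back to comparing the underlying dnnl descriptors when they are distinct objects, so two nodes that build identical descriptors independently still map to the same cached primitive. A minimal sketch of that fallback, with a hypothetical Desc type standing in for DnnlMemoryDesc:

#include <cassert>
#include <memory>
#include <string>

// Hypothetical descriptor type; only the content comparison matters for this sketch.
struct Desc {
    std::string layout;
    bool operator==(const Desc& rhs) const { return layout == rhs.layout; }
};
using DescCPtr = std::shared_ptr<const Desc>;

// Mirrors the pointer-first, content-second comparison used in SoftmaxKey::operator==.
bool sameDesc(const DescCPtr& a, const DescCPtr& b) {
    if (a == b)
        return true;               // same object (or both empty)
    return a && b && *a == *b;     // different objects: compare the descriptor contents
}

int main() {
    auto d1 = std::make_shared<const Desc>(Desc{"nchw:f32"});
    auto d2 = std::make_shared<const Desc>(Desc{"nchw:f32"});
    assert(sameDesc(d1, d1));      // identical pointer
    assert(sameDesc(d1, d2));      // distinct pointers, equal content
}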
@@ -108,32 +142,44 @@ void MKLDNNSoftMaxNode::createDescriptor(const std::vector<MemoryDescPtr> &input
 
 void MKLDNNSoftMaxNode::prepareParams() {
     auto inpDesc = getParentEdgeAt(0)->getMemory().GetDescWithType<DnnlMemoryDesc>();
-    const auto& in_candidate = inpDesc->getDnnlDesc();
-    MKLDNNDescriptor desc(std::shared_ptr<softmax_forward::desc>(
-        new softmax_forward::desc(prop_kind::forward_scoring, in_candidate, axis)));
+    const NodeDesc* selected_pd = getSelectedPrimitiveDescriptor();
 
-    const NodeDesc *selected_pd = getSelectedPrimitiveDescriptor();
     if (selected_pd == nullptr)
         IE_THROW() << "Preferable primitive descriptor is not set for node " << getName() << ".";
 
-    softmax_forward::primitive_desc prim_desc;
-    primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(getEngine());
-
-    while (itpd) {
-        impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
-        if (impl_type == selected_pd->getImplementationType() ||
-            // At least for oneDNN v2.4 the softmax primitive is optimized for the cases where the dimension of the softmax axis is physically dense.
-            // There could be situations where it is not possible to detect the optimized case in advance in case of dynamic shapes, but
-            // in runtime the shape could be suitable for the optimized implementation, so we have to select the optimized one.
-            (ref_any == selected_pd->getImplementationType() && (impl_type & jit))) {
-            prim_desc = itpd.get();
-            break;
-        }
-        if (!itpd.next_impl())
-            IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
-    }
-
-    prim.reset(new softmax_forward(prim_desc));
+    SoftmaxKey key = {inpDesc, selected_pd->getImplementationType(), axis};
+    auto engine = getEngine();
+    auto builder = [&engine](const SoftmaxKey& key) -> std::shared_ptr<mkldnn::primitive> {
+        softmax_forward::primitive_desc prim_desc;
+        MKLDNNDescriptor desc(std::shared_ptr<softmax_forward::desc>(
+            new softmax_forward::desc(prop_kind::forward_scoring, key.inp0->getDnnlDesc(), key.axis)));
+        primitive_desc_iterator itpd = desc.createPrimitiveDescriptorIterator(engine);
+
+        while (itpd) {
+            impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str());
+            if (impl_type == key.implType ||
+                // At least for oneDNN v2.4 the softmax primitive is optimized for the cases where the dimension of the
+                // softmax axis is physically dense. There could be situations where it is not possible to detect the
+                // optimized case in advance in case of dynamic shapes, but in runtime the shape could be suitable for
+                // the optimized implementation, so we have to select the optimized one.
+                (ref_any == key.implType && (impl_type & jit))) {
+                prim_desc = itpd.get();
+                break;
+            }
+            if (!itpd.next_impl())
+                return nullptr;
+        }
+        return std::make_shared<softmax_forward>(prim_desc);
+    };
+
+    auto cache = getRuntimeCache();
+    auto result = cache->getOrCreate(key, builder);
+
+    if (!result.first) {
+        IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
+    }
+
+    prim = result.first;
 
     auto src = getParentEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
     auto dst = getChildEdgesAtPort(0)[0]->getMemoryPtr()->GetPrimitive();
