Skip to content

Commit

Permalink
Removed template from base (openvinotoolkit#4045)
Browse files Browse the repository at this point in the history
Co-authored-by: Ilya Lavrenov <[email protected]>
  • Loading branch information
apankratovantonp and ilya-lavrenov authored Feb 4, 2021
1 parent e80e5e7 commit 945da5f
Show file tree
Hide file tree
Showing 14 changed files with 46 additions and 60 deletions.
2 changes: 1 addition & 1 deletion docs/template_plugin/src/template_executable_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ InferenceEngine::IInferRequest::Ptr TemplatePlugin::ExecutableNetwork::CreateInf
auto internalRequest = CreateInferRequestImpl(_networkInputs, _networkOutputs);
auto asyncThreadSafeImpl = std::make_shared<TemplateAsyncInferRequest>(std::static_pointer_cast<TemplateInferRequest>(internalRequest),
_taskExecutor, _plugin->_waitExecutor, _callbackExecutor);
asyncRequest.reset(new InferenceEngine::InferRequestBase<TemplateAsyncInferRequest>(asyncThreadSafeImpl),
asyncRequest.reset(new InferenceEngine::InferRequestBase(asyncThreadSafeImpl),
[](InferenceEngine::IInferRequest *p) { p->Release(); });
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ IInferRequest::Ptr MultiDeviceExecutableNetwork::CreateInferRequest() {
_needPerfCounters,
std::static_pointer_cast<MultiDeviceExecutableNetwork>(shared_from_this()),
_callbackExecutor);
asyncRequest.reset(new InferRequestBase<MultiDeviceAsyncInferRequest>(asyncTreadSafeImpl), [](IInferRequest *p) { p->Release(); });
asyncRequest.reset(new InferRequestBase(asyncTreadSafeImpl), [](IInferRequest *p) { p->Release(); });
asyncTreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <cpp/ie_executable_network.hpp>
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
#include <cpp_interfaces/interface/ie_iexecutable_network_internal.hpp>
#include <map>
#include <memory>
#include <string>
Expand All @@ -24,18 +25,17 @@ namespace InferenceEngine {
/**
* @brief Executable network `noexcept` wrapper which accepts IExecutableNetworkInternal derived instance which can throw exceptions
* @ingroup ie_dev_api_exec_network_api
* @tparam T Minimal CPP implementation of IExecutableNetworkInternal (e.g. ExecutableNetworkInternal)
*/
template <class T>
*/
class ExecutableNetworkBase : public IExecutableNetwork {
std::shared_ptr<T> _impl;
protected:
std::shared_ptr<IExecutableNetworkInternal> _impl;

public:
/**
* @brief Constructor with actual underlying implementation.
* @param impl Underlying implementation of type IExecutableNetworkInternal
*/
explicit ExecutableNetworkBase(std::shared_ptr<T> impl) {
explicit ExecutableNetworkBase(std::shared_ptr<IExecutableNetworkInternal> impl) {
if (impl.get() == nullptr) {
THROW_IE_EXCEPTION << "implementation not defined";
}
Expand Down Expand Up @@ -77,7 +77,7 @@ class ExecutableNetworkBase : public IExecutableNetwork {
if (idx >= v.size()) {
return OUT_OF_BOUNDS;
}
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
pState = std::make_shared<VariableStateBase>(v[idx]);
return OK;
} catch (const std::exception& ex) {
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
Expand All @@ -91,11 +91,6 @@ class ExecutableNetworkBase : public IExecutableNetwork {
delete this;
}

/// @private Need for unit tests only - TODO: unit tests should test using public API, non having details
const std::shared_ptr<T> getImpl() const {
return _impl;
}

StatusCode SetConfig(const std::map<std::string, Parameter>& config, ResponseDesc* resp) noexcept override {
TO_STATUS(_impl->SetConfig(config));
}
Expand All @@ -112,8 +107,8 @@ class ExecutableNetworkBase : public IExecutableNetwork {
TO_STATUS(pContext = _impl->GetContext());
}

private:
~ExecutableNetworkBase() = default;
protected:
~ExecutableNetworkBase() override = default;
};

/**
Expand All @@ -127,7 +122,7 @@ template <class T>
inline typename InferenceEngine::ExecutableNetwork make_executable_network(std::shared_ptr<T> impl) {
// to suppress warning about deprecated QueryState
IE_SUPPRESS_DEPRECATED_START
typename ExecutableNetworkBase<T>::Ptr net(new ExecutableNetworkBase<T>(impl), [](IExecutableNetwork* p) {
typename ExecutableNetworkBase::Ptr net(new ExecutableNetworkBase(impl), [](IExecutableNetwork* p) {
p->Release();
});
IE_SUPPRESS_DEPRECATED_END
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "cpp_interfaces/exception2status.hpp"
#include "cpp_interfaces/plugin_itt.hpp"
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
#include "ie_iinfer_request.hpp"
#include "ie_preprocess.hpp"

Expand All @@ -19,18 +20,16 @@ namespace InferenceEngine {
/**
* @brief Inference request `noexcept` wrapper which accepts IAsyncInferRequestInternal derived instance which can throw exceptions
* @ingroup ie_dev_api_async_infer_request_api
* @tparam T Minimal CPP implementation of IAsyncInferRequestInternal (e.g. AsyncInferRequestThreadSafeDefault)
*/
template <class T>
class InferRequestBase : public IInferRequest {
std::shared_ptr<T> _impl;
std::shared_ptr<IAsyncInferRequestInternal> _impl;

public:
/**
* @brief Constructor with actual underlying implementation.
* @param impl Underlying implementation of type IAsyncInferRequestInternal
*/
explicit InferRequestBase(std::shared_ptr<T> impl): _impl(impl) {}
explicit InferRequestBase(std::shared_ptr<IAsyncInferRequestInternal> impl): _impl(impl) {}

StatusCode Infer(ResponseDesc* resp) noexcept override {
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Infer");
Expand Down Expand Up @@ -100,7 +99,7 @@ class InferRequestBase : public IInferRequest {
if (idx >= v.size()) {
return OUT_OF_BOUNDS;
}
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
pState = std::make_shared<VariableStateBase>(v[idx]);
return OK;
} catch (const std::exception& ex) {
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,17 @@ IE_SUPPRESS_DEPRECATED_START

/**
* @brief Default implementation for IVariableState
* @tparam T Minimal CPP implementation of IVariableStateInternal (e.g. VariableStateInternal)
* @ingroup ie_dev_api_variable_state_api
* @ingroup ie_dev_api_variable_state_api
*/
template <class T>
class VariableStateBase : public IVariableState {
std::shared_ptr<T> impl;
std::shared_ptr<IVariableStateInternal> impl;

public:
/**
* @brief Constructor with actual underlying implementation.
* @param impl Underlying implementation of type IVariableStateInternal
*/
explicit VariableStateBase(std::shared_ptr<T> impl): impl(impl) {
explicit VariableStateBase(std::shared_ptr<IVariableStateInternal> impl): impl(impl) {
if (impl == nullptr) {
THROW_IE_EXCEPTION << "VariableStateBase implementation is not defined";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ExecutableNetworkThreadSafeAsyncOnly : public ExecutableNetworkInternal,
auto asyncRequestImpl = this->CreateAsyncInferRequestImpl(_networkInputs, _networkOutputs);
asyncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());

asyncRequest.reset(new InferRequestBase<AsyncInferRequestInternal>(asyncRequestImpl), [](IInferRequest* p) {
asyncRequest.reset(new InferRequestBase(asyncRequestImpl), [](IInferRequest* p) {
p->Release();
});
asyncRequestImpl->SetPointerToPublicInterface(asyncRequest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ExecutableNetworkThreadSafeDefault : public ExecutableNetworkInternal,

auto asyncThreadSafeImpl = std::make_shared<AsyncInferRequestType>(
syncRequestImpl, _taskExecutor, _callbackExecutor);
asyncRequest.reset(new InferRequestBase<AsyncInferRequestType>(asyncThreadSafeImpl),
asyncRequest.reset(new InferRequestBase(asyncThreadSafeImpl),
[](IInferRequest *p) { p->Release(); });
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class ExecutableNetwork : public ie::ExecutableNetworkThreadSafeDefault {
auto taskExecutorGetResult = getNextTaskExecutor();
auto asyncThreadSafeImpl = std::make_shared<MyriadAsyncInferRequest>(
syncRequestImpl, _taskExecutor, _callbackExecutor, taskExecutorGetResult);
asyncRequest.reset(new ie::InferRequestBase<ie::AsyncInferRequestThreadSafeDefault>(
asyncThreadSafeImpl),
asyncRequest.reset(new ie::InferRequestBase(asyncThreadSafeImpl),
[](ie::IInferRequest *p) { p->Release(); });
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ExecutableNetworkThreadSafeAsyncOnlyTests : public ::testing::Test {
virtual void SetUp() {
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafeAsyncOnly>();
exeNetwork = details::shared_from_irelease(
new ExecutableNetworkBase<MockExecutableNetworkThreadSafeAsyncOnly>(mockExeNetwork));
new ExecutableNetworkBase(mockExeNetwork));
InputsDataMap networkInputs;
OutputsDataMap networkOutputs;
mockAsyncInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(networkInputs, networkOutputs);
Expand All @@ -46,7 +46,7 @@ TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, createAsyncInferRequestCallsTh
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
Return(mockAsyncInferRequestInternal));
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestInternal>>(req);
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase>(req);
ASSERT_NE(threadSafeReq, nullptr);
}

Expand Down Expand Up @@ -109,7 +109,7 @@ class ExecutableNetworkThreadSafeTests : public ::testing::Test {
virtual void SetUp() {
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafe>();
exeNetwork = details::shared_from_irelease(
new ExecutableNetworkBase<MockExecutableNetworkThreadSafe>(mockExeNetwork));
new ExecutableNetworkBase(mockExeNetwork));
InputsDataMap networkInputs;
OutputsDataMap networkOutputs;
mockInferRequestInternal = make_shared<MockInferRequestInternal>(networkInputs, networkOutputs);
Expand All @@ -120,7 +120,7 @@ TEST_F(ExecutableNetworkThreadSafeTests, createInferRequestCallsThreadSafeImplAn
IInferRequest::Ptr req;
EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestThreadSafeDefault>>(req);
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase>(req);
ASSERT_NE(threadSafeReq, nullptr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class InferRequestBaseTests : public ::testing::Test {

virtual void SetUp() {
mock_impl.reset(new MockIAsyncInferRequestInternal());
request = details::shared_from_irelease(new InferRequestBase<MockIAsyncInferRequestInternal>(mock_impl));
request = details::shared_from_irelease(new InferRequestBase(mock_impl));
}
};

Expand Down Expand Up @@ -243,7 +243,7 @@ class InferRequestTests : public ::testing::Test {
mockNotEmptyNet.getOutputsInfo(outputsInfo);
mockInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(inputsInfo, outputsInfo);
inferRequest = shared_from_irelease(
new InferRequestBase<MockAsyncInferRequestInternal>(mockInferRequestInternal));
new InferRequestBase(mockInferRequestInternal));
return make_shared<InferRequest>(inferRequest);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, callbackTakesOKIfAsyncRequestWasOK) {
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);

IInferRequest::Ptr asyncRequest;
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
[](IInferRequest *p) { p->Release(); });
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
testRequest->SetPointerToPublicInterface(asyncRequest);

testRequest->SetCompletionCallback([](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
Expand All @@ -215,8 +214,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, callbackIsCalledIfAsyncRequestFailed)
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
IInferRequest::Ptr asyncRequest;
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
[](IInferRequest *p) { p->Release(); });
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
testRequest->SetPointerToPublicInterface(asyncRequest);

bool wasCalled = false;
Expand All @@ -238,8 +236,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailed
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
IInferRequest::Ptr asyncRequest;
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
[](IInferRequest *p) { p->Release(); });
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
testRequest->SetPointerToPublicInterface(asyncRequest);

EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using namespace InferenceEngine::details;

template <class T>
inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr<T> impl) {
typename InferRequestBase<T>::Ptr req(new InferRequestBase<T>(impl), [](IInferRequest* p) {
typename InferRequestBase::Ptr req(new InferRequestBase(impl), [](IInferRequest* p) {
p->Release();
});
return InferenceEngine::InferRequest(req);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class ExecutableNetworkBaseTests : public ::testing::Test {

virtual void SetUp() {
mock_impl.reset(new MockIExecutableNetworkInternal());
exeNetwork = shared_from_irelease(new ExecutableNetworkBase<MockIExecutableNetworkInternal>(mock_impl));
exeNetwork = shared_from_irelease(new ExecutableNetworkBase(mock_impl));
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,17 @@ class MKLDNNTestExecNetwork: public MKLDNNPlugin::MKLDNNExecNetwork {
}
};

class MKLDNNTestEngine: public MKLDNNPlugin::Engine {
public:
MKLDNNPlugin::MKLDNNGraph& getGraph(InferenceEngine::IExecutableNetwork::Ptr execNetwork) {
auto * execNetworkInt =
dynamic_cast<InferenceEngine::ExecutableNetworkBase<InferenceEngine::ExecutableNetworkInternal> *>(execNetwork.get());
if (!execNetworkInt)
THROW_IE_EXCEPTION << "Cannot find loaded network!";

auto * network = reinterpret_cast<MKLDNNTestExecNetwork *>(execNetworkInt->getImpl().get());
if (!network)
THROW_IE_EXCEPTION << "Cannot get mkldnn graph!";
return network->getGraph();
}
struct TestExecutableNetworkBase : public InferenceEngine::ExecutableNetworkBase {
using InferenceEngine::ExecutableNetworkBase::_impl;
~TestExecutableNetworkBase() override = default;
};

static MKLDNNPlugin::MKLDNNGraph& getGraph(InferenceEngine::IExecutableNetwork::Ptr execNetwork) {
return reinterpret_cast<MKLDNNTestExecNetwork*>(
reinterpret_cast<TestExecutableNetworkBase*>(
execNetwork.get())->_impl.get())->getGraph();
}

class MKLDNNGraphLeaksTests: public ::testing::Test {
protected:
void addOutputToEachNode(InferenceEngine::CNNNetwork& network, std::vector<std::string>& new_outputs,
Expand Down Expand Up @@ -257,11 +253,11 @@ TEST_F(MKLDNNGraphLeaksTests, MKLDNN_not_release_outputs_fp32) {

ASSERT_NE(1, network.getOutputsInfo().size());

std::shared_ptr<MKLDNNTestEngine> score_engine(new MKLDNNTestEngine());
std::shared_ptr<MKLDNNPlugin::Engine> score_engine(new MKLDNNPlugin::Engine());
InferenceEngine::ExecutableNetwork exeNetwork1;
ASSERT_NO_THROW(exeNetwork1 = score_engine->LoadNetwork(network, {}));

size_t modified_outputs_size = score_engine->getGraph(exeNetwork1).GetOutputNodes().size();
size_t modified_outputs_size = getGraph(exeNetwork1).GetOutputNodes().size();

InferenceEngine::CNNNetwork network2;
ASSERT_NO_THROW(network2 = core.ReadNetwork(model, weights_ptr));
Expand All @@ -270,10 +266,12 @@ TEST_F(MKLDNNGraphLeaksTests, MKLDNN_not_release_outputs_fp32) {
InferenceEngine::ExecutableNetwork exeNetwork2;
ASSERT_NO_THROW(exeNetwork2 = score_engine->LoadNetwork(network2, {}));

size_t original_outputs_size = score_engine->getGraph(exeNetwork2).GetOutputNodes().size();
size_t original_outputs_size = getGraph(exeNetwork2).GetOutputNodes().size();

ASSERT_NE(modified_outputs_size, original_outputs_size);
ASSERT_EQ(1, original_outputs_size);
} catch (std::exception& e) {
FAIL() << e.what();
} catch (...) {
FAIL();
}
Expand Down

0 comments on commit 945da5f

Please sign in to comment.