Skip to content

Commit

Permalink
Move/rename the plugin factory test file; delete duplicate test file;…
Browse files Browse the repository at this point in the history
… fix minor formatting issues.
  • Loading branch information
aaroey committed May 3, 2018
1 parent a2d35bd commit 03de4a4
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 105 deletions.
10 changes: 7 additions & 3 deletions tensorflow/contrib/tensorrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ tf_cc_test(
],
)

# Library for the plugin factory
# Library for the plugin factory
tf_cuda_library(
name = "trt_plugins",
srcs = [
Expand All @@ -304,9 +304,13 @@ tf_cuda_library(
)

tf_cuda_cc_test(
name = "trt_plugins_test",
name = "trt_plugin_factory_test",
size = "small",
srcs = ["plugin/trt_plugins_test.cc"],
srcs = ["plugin/trt_plugin_factory_test.cc"],
tags = [
"manual",
"notap",
],
deps = [
":trt_plugins",
"//tensorflow/core:test",
Expand Down
6 changes: 5 additions & 1 deletion tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ tf_kernel_library(
],
gpu_srcs = [
"inc_op_kernel.h",
"inc_op_kernel.cu.cc"
"inc_op_kernel.cu.cc",
],
deps = [
"//tensorflow/contrib/tensorrt:trt_plugins",
Expand Down Expand Up @@ -120,4 +120,8 @@ tf_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:tf_optimizer",
],
tags = [
"manual",
"notap",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@
from __future__ import division
from __future__ import print_function

# normally we should do import tensorflow as tf and then
# tf.placeholder, tf.constant, tf.nn.conv2d etc but
# it looks like internal builds don't like it so
# importing every module individually

from tensorflow.contrib import tensorrt
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
Expand Down
6 changes: 5 additions & 1 deletion tensorflow/contrib/tensorrt/plugin/trt_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <iostream>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
Expand All @@ -35,9 +36,11 @@ namespace tensorrt {
// PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT
class PluginTensorRT : public nvinfer1::IPlugin {
public:
PluginTensorRT() {};
PluginTensorRT() {}
PluginTensorRT(const void* serialized_data, size_t length);

virtual const string& GetPluginName() const = 0;

virtual bool Finalize() = 0;

virtual bool SetAttribute(const string& key, const void* ptr,
Expand All @@ -53,6 +56,7 @@ class PluginTensorRT : public nvinfer1::IPlugin {
const size_t size);

virtual size_t getSerializationSize() override;

virtual void serialize(void* buffer) override;

protected:
Expand Down
9 changes: 5 additions & 4 deletions tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ limitations under the License.
#include <memory>
#include <mutex>
#include <unordered_map>
#include "trt_plugin.h"
#include "trt_plugin_utils.h"

#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
Expand Down Expand Up @@ -54,12 +55,12 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {

protected:
std::unordered_map<string,
std::pair<PluginDeserializeFunc, PluginConstructFunc> >
std::pair<PluginDeserializeFunc, PluginConstructFunc>>
plugin_registry_;

// TODO(jie): Owned plugin should be associated with different sessions;
// should really hand ownership of plugins to resource management;
std::vector<std::unique_ptr<PluginTensorRT> > owned_plugins_;
std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_;
std::mutex instance_m_;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ StubPlugin* CreateStubPluginDeserialize(const void* serialized_data,
return new StubPlugin(serialized_data, length);
}

class PluginTest : public ::testing::Test {
class TrtPluginFactoryTest : public ::testing::Test {
public:
bool RegisterStubPlugin() {
if (PluginFactoryTensorRT::GetInstance()->IsPlugin(
Expand All @@ -94,7 +94,7 @@ class PluginTest : public ::testing::Test {
}
};

TEST_F(PluginTest, Registration) {
TEST_F(TrtPluginFactoryTest, Registration) {
EXPECT_FALSE(
PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName));
EXPECT_TRUE(RegisterStubPlugin());
Expand All @@ -103,7 +103,7 @@ TEST_F(PluginTest, Registration) {
PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName));
}

TEST_F(PluginTest, CreationDeletion) {
TEST_F(TrtPluginFactoryTest, CreationDeletion) {
EXPECT_TRUE(RegisterStubPlugin());
ASSERT_TRUE(
PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName));
Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS

#include <functional>

#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
#include "tensorflow/core/platform/types.h"

Expand Down
88 changes: 0 additions & 88 deletions tensorflow/contrib/tensorrt/plugin_test.py

This file was deleted.

0 comments on commit 03de4a4

Please sign in to comment.