
merge baidu/develop
QiJune committed Jul 12, 2017
2 parents 4d336d9 + 0a32008 commit ca23d86
Showing 14 changed files with 294 additions and 264 deletions.
1 change: 0 additions & 1 deletion paddle/CMakeLists.txt
@@ -15,7 +15,6 @@ if(Boost_FOUND)
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind)
endif()

48 changes: 0 additions & 48 deletions paddle/framework/dim.h
@@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
return ((0 <= idx.head) && (idx.head < size.head));
}

/**
* \brief Check if a size and a stride create a Fortran order contiguous
* block of memory.
*/
template <int i>
HOST bool contiguous(const Dim<i>& size, const Dim<i>& stride, int mul = 1) {
if (product(size) == 0) return true;
int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
return (get<0>(stride) == contiguous_stride &&
contiguous(size.tail, stride.tail, mul * get<0>(size)));
}

///\cond HIDDEN
// Base case of contiguous: checks that the last stride equals the
// prefix product of the preceding n-1 extents.
template <>
inline bool contiguous(const Dim<1>& size, const Dim<1>& stride, int mul) {
if (get<0>(size) == 0) return true;
int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
return get<0>(stride) == contiguous_stride;
}
///\endcond

/**
* \brief Compute exclusive prefix-multiply of a Dim.
*/
@@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
}
///\endcond

/**
* \brief Calculate strides of a contiguous array of the given size
*
* Sets the stride for any dimension with an extent of 1 to 0.
* \param size Dim object containing the size of the array.
* \param base The base stride to use.
* \return Dim object the same size as \p size with the strides.
*/
template <int i>
HOSTDEVICE Dim<i> contiguous_strides(const Dim<i>& size, int base = 1) {
int stride = size.head == 1 ? 0 : base;
return Dim<i>(stride, contiguous_strides(size.tail, base * size.head));
}

///\cond HIDDEN

// Base case of contiguous_strides
template <>
HOSTDEVICE inline Dim<1> contiguous_strides(const Dim<1>& size, int base) {
int stride = size.head == 1 ? 0 : base;
return Dim<1>(stride);
}

///\endcond

/**
* Add two dimensions together
*/
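For reference, a minimal standalone sketch (plain C++ with std::array standing in for Dim — not Paddle code) of the Fortran-order stride rule the removed contiguous_strides() implemented: the innermost stride is 1, each later stride is the running product of the preceding extents, and any extent-1 dimension gets stride 0. The asserted values mirror the cases deleted from dim_test.cu below.

#include <array>
#include <cassert>

// Compute Fortran-order (column-major) contiguous strides for `size`,
// giving extent-1 dimensions a stride of 0, as the removed helper did.
template <size_t N>
std::array<int, N> fortran_strides(const std::array<int, N>& size,
                                   int base = 1) {
  std::array<int, N> stride{};
  for (size_t i = 0; i < N; ++i) {
    stride[i] = (size[i] == 1) ? 0 : base;  // extent-1 dims get stride 0
    base *= size[i];                        // running product of extents
  }
  return stride;
}

int main() {
  auto a = fortran_strides<3>({2, 3, 4});    // -> (1, 2, 6)
  assert(a[0] == 1 && a[1] == 2 && a[2] == 6);
  auto b = fortran_strides<3>({10, 1, 10});  // -> (1, 0, 10)
  assert(b[0] == 1 && b[1] == 0 && b[2] == 10);
  return 0;
}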
28 changes: 0 additions & 28 deletions paddle/framework/dim_test.cu
@@ -58,24 +58,6 @@ TEST(Dim, Equality) {
EXPECT_EQ(paddle::framework::get<1>(c), 3);
EXPECT_EQ(paddle::framework::get<2>(c), 12);

// contiguous_strides
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 1, 10));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 0);
EXPECT_EQ(paddle::framework::get<2>(c), 10);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 10, 1));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 10);
EXPECT_EQ(paddle::framework::get<2>(c), 0);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(1, 10, 10));
EXPECT_EQ(paddle::framework::get<0>(c), 0);
EXPECT_EQ(paddle::framework::get<1>(c), 1);
EXPECT_EQ(paddle::framework::get<2>(c), 10);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(2, 3, 4));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 2);
EXPECT_EQ(paddle::framework::get<2>(c), 6);

// generate from an index
auto size = paddle::framework::make_dim(4, 5, 2);
c = paddle::framework::Dim<3>(14, size);
@@ -101,16 +83,6 @@ TEST(Dim, Bool) {
EXPECT_TRUE(a == a);
EXPECT_FALSE(a == b);
EXPECT_TRUE(a == c);

// contiguous check
int x = 4, y = 5, z = 2;
paddle::framework::Dim<3> sizef(x, y, z);
paddle::framework::Dim<3> stridea(1, x, x*y);
paddle::framework::Dim<3> strideb(2, 2*x, 2*x*y);
paddle::framework::Dim<3> stridec(1, x, 2*x*y);
EXPECT_TRUE(paddle::framework::contiguous(sizef, stridea));
EXPECT_FALSE(paddle::framework::contiguous(sizef, strideb));
EXPECT_FALSE(paddle::framework::contiguous(sizef, stridec));
}

TEST(Dim, Print) {
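The deleted checks in TEST(Dim, Bool) exercised the companion predicate. A standalone sketch (plain C++, not Paddle code) of that contiguity test, reproducing the three size/stride cases above:

#include <array>
#include <cassert>

// True when each stride equals the running product of the preceding
// extents (stride 0 is accepted for extent-1 dimensions), i.e. the
// layout is dense Fortran-order, as the removed contiguous() checked.
template <size_t N>
bool is_contiguous(const std::array<int, N>& size,
                   const std::array<int, N>& stride) {
  int expected = 1;
  for (size_t i = 0; i < N; ++i) {
    if (stride[i] != ((size[i] == 1) ? 0 : expected)) return false;
    expected *= size[i];
  }
  return true;
}

int main() {
  // size (4, 5, 2), as in the deleted test: x = 4, y = 5, z = 2.
  assert(is_contiguous<3>({4, 5, 2}, {1, 4, 20}));   // dense: (1, x, x*y)
  assert(!is_contiguous<3>({4, 5, 2}, {2, 8, 40}));  // every stride doubled
  assert(!is_contiguous<3>({4, 5, 2}, {1, 4, 40}));  // padded outer stride
  return 0;
}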
33 changes: 16 additions & 17 deletions paddle/framework/op_registry_test.cc
@@ -1,17 +1,15 @@
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
#include "paddle/framework/operator.h"
#include "paddle/operators/demo_op.h"

using namespace paddle::framework;

namespace paddle {
namespace framework {
class CosineOp : public OperatorWithKernel {
class CosineOp : public OperatorBase {
public:
void Run(const OpRunContext* context) const override {
printf("%s\n", DebugString().c_str());
}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
};

class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)

class MyTestOp : public OperatorWithKernel {
class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}

public:
void Run(const OpRunContext* ctx) const override {
printf("%s\n", DebugString().c_str());
printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
}
};

class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext();
op->Run(scope, &dev_ctx);
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale");
ASSERT_EQ(scale_get, scale);
}
@@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext();
op->Run(scope, &dev_ctx);
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
}

@@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(4);
paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto dev_ctx = DeviceContext();
paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>();
op->Run(scope, &dev_ctx);
op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
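Taken together, the test updates all make the same mechanical change: Run now receives a concrete platform::DeviceContext by const reference instead of a pointer to the old framework-local DeviceContext stub. A condensed sketch of the new calling pattern — the cos_sim type and scale attribute follow the CreateOp test in this file, while the aa/bb input/output names are illustrative; per the DefaultValue test, scale falls back to 1.0 when omitted:

#include "paddle/framework/op_registry.h"

void RunCosSimExample() {
  paddle::framework::OpDesc op_desc;
  op_desc.set_type("cos_sim");
  op_desc.add_inputs("aa");
  op_desc.add_outputs("bb");
  auto attr = op_desc.mutable_attrs()->Add();  // optional: defaults to 1.0
  attr->set_name("scale");
  attr->set_type(paddle::framework::AttrType::FLOAT);
  attr->set_f(3.3);

  paddle::framework::OperatorBase* op =
      paddle::framework::OpRegistry::CreateOp(op_desc);

  auto scope = std::make_shared<paddle::framework::Scope>();
  paddle::platform::CPUDeviceContext dev_ctx;  // concrete device context
  op->Run(scope, dev_ctx);                     // by reference, not pointer
}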
8 changes: 0 additions & 8 deletions paddle/framework/operator.cc
@@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
return ss.str();
}

const Variable* OpRunContext::Input(int index) const {
return scope_->GetVariable(op_->inputs_[index]);
}

Variable* OpRunContext::Output(int index) const {
return scope_->GetVariable(op_->outputs_[index]);
}

} // namespace framework
} // namespace paddle
117 changes: 80 additions & 37 deletions paddle/framework/operator.h
@@ -14,44 +14,22 @@ limitations under the License. */

#pragma once

#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/utils/Error.h"

namespace paddle {
namespace framework {

class OperatorBase;

class DeviceContext {};

/**
* OpRunContext is the only parameter of Operator's Run function.
* Run obtains input/output variables, state such as momentum, and
* device resources such as the CUDA stream and cuBLAS handle from
* OpRunContext. The user should construct it before running the Operator.
*/
class OpRunContext {
public:
OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
const DeviceContext* device_context)
: op_(op), scope_(scope), device_context_(device_context) {}

const Variable* Input(int index) const;
Variable* Output(int index) const;

public:
const OperatorBase* op_;
const std::shared_ptr<Scope> scope_;
const DeviceContext* device_context_;
};

/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
@@ -77,7 +55,10 @@ class OperatorBase {

/// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const = 0;
const platform::DeviceContext& dev_ctx) const = 0;

protected:
std::string Type() const { return desc_.type(); }

public:
OpDesc desc_;
@@ -86,22 +67,84 @@
AttributeMap attrs_;
};

class OpKernel {
public:
/**
* KernelContext is the only parameter of the kernel's Compute function.
* Compute obtains input/output variables, state such as momentum, and
* device resources such as the CUDA stream and cuBLAS handle from
* KernelContext. The user should construct it before running the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}

const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}

Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}

const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};

virtual void Compute(const KernelContext& context) const = 0;

virtual ~OpKernel() {}
};

class OperatorWithKernel : public OperatorBase {
public:
virtual ~OperatorWithKernel() {}
struct OpKernelKey {
platform::Place place_;

virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
OpKernelKey() = default;
OpKernelKey(const platform::DeviceContext& dev_ctx) {
place_ = dev_ctx.GetPlace();
}

bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
};

struct OpKernelHash {
std::hash<bool> hash_;
size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_));
}
};

using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;

void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const {
OpRunContext op_ctx(this, scope, dev_ctx);
Run(&op_ctx);
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
}

/// When implementing an Op, you should implement this function.
/// This function should be moved to OpKernel later.
virtual void Run(const OpRunContext* context) const = 0;
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
AllOpKernels() {
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
};
};

} // namespace framework
} // namespace paddle

#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__
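The macro above completes the new kernel path: each op type owns an OpKernelMap keyed by place (the hash only distinguishes CPU from GPU, which suffices because operator== compares the full place), and REGISTER_OP_KERNEL plants one kernel per (type, place) pair at static-initialization time. A hedged sketch of defining and registering a CPU kernel — CosKernel and the cos_sim key are illustrative, and the operator itself would still be registered separately through REGISTER_OP:

#include "paddle/framework/operator.h"

namespace demo {

// A do-nothing CPU kernel; a real kernel would read inputs and write
// outputs through the KernelContext (context.Input(0), context.Output(0))
// and pick up device state from context.device_context_.
class CosKernel : public paddle::framework::OpKernel {
 public:
  void Compute(const KernelContext& context) const override {}
};

}  // namespace demo

// Maps ("cos_sim", CPUPlace) -> demo::CosKernel in AllOpKernels(), so
// OperatorWithKernel::Run can locate it for a CPUDeviceContext.
REGISTER_OP_KERNEL(cos_sim, ::paddle::platform::CPUPlace, demo::CosKernel);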