Skip to content

Commit

Permalink
Merge pull request BVLC#1703 from longjon/pyreformation
Browse files Browse the repository at this point in the history
Reform the boost::python wrapper, including layers implemented in Python
  • Loading branch information
shelhamer committed Feb 17, 2015
2 parents dfbc2dc + 91289b3 commit 5e64f5a
Show file tree
Hide file tree
Showing 18 changed files with 405 additions and 369 deletions.
14 changes: 10 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ EMPTY_LINT_REPORT := $(BUILD_DIR)/.$(LINT_EXT)
NONEMPTY_LINT_REPORT := $(BUILD_DIR)/$(LINT_EXT)
# PY$(PROJECT)_SRC is the python wrapper for $(PROJECT)
PY$(PROJECT)_SRC := python/$(PROJECT)/_$(PROJECT).cpp
PY$(PROJECT)_HXX_SRC := python/$(PROJECT)/_$(PROJECT).hpp
PY$(PROJECT)_SO := python/$(PROJECT)/_$(PROJECT).so
PY$(PROJECT)_HXX := include/$(PROJECT)/python_layer.hpp
# MAT$(PROJECT)_SRC is the matlab wrapper for $(PROJECT)
MAT$(PROJECT)_SRC := matlab/$(PROJECT)/mat$(PROJECT).cpp
ifneq ($(MATLAB_DIR),)
Expand Down Expand Up @@ -288,6 +288,12 @@ ifeq ($(CPU_ONLY), 1)
COMMON_FLAGS += -DCPU_ONLY
endif

# Python layer support
ifeq ($(WITH_PYTHON_LAYER), 1)
COMMON_FLAGS += -DWITH_PYTHON_LAYER
LIBRARIES += $(PYTHON_LIBRARIES)
endif

# BLAS configuration (default = ATLAS)
BLAS ?= atlas
ifeq ($(BLAS), mkl)
Expand Down Expand Up @@ -421,10 +427,10 @@ py$(PROJECT): py

py: $(PY$(PROJECT)_SO) $(PROTO_GEN_PY)

$(PY$(PROJECT)_SO): $(PY$(PROJECT)_SRC) $(PY$(PROJECT)_HXX_SRC) | $(DYNAMIC_NAME)
$(PY$(PROJECT)_SO): $(PY$(PROJECT)_SRC) $(PY$(PROJECT)_HXX) | $(DYNAMIC_NAME)
@ echo CXX/LD -o $@ $<
$(Q)$(CXX) -shared -o $@ $(PY$(PROJECT)_SRC) \
-o $@ $(LINKFLAGS) $(PYTHON_LDFLAGS) -l$(PROJECT) \
-o $@ $(LINKFLAGS) -l$(PROJECT) $(PYTHON_LDFLAGS) \
-Wl,-rpath,$(ORIGIN)/../../build/lib

mat$(PROJECT): mat
Expand Down Expand Up @@ -533,7 +539,7 @@ $(TOOL_BUILD_DIR)/%: $(TOOL_BUILD_DIR)/%.bin | $(TOOL_BUILD_DIR)

$(TOOL_BINS) $(EXAMPLE_BINS): %.bin : %.o | $(DYNAMIC_NAME)
@ echo CXX/LD -o $@
$(Q)$(CXX) $< -o $@ $(LINKFLAGS) $(LDFLAGS) -l$(PROJECT) \
$(Q)$(CXX) $< -o $@ $(LINKFLAGS) -l$(PROJECT) $(LDFLAGS) \
-Wl,-rpath,$(ORIGIN)/../lib

proto: $(PROTO_GEN_CC) $(PROTO_GEN_HEADER)
Expand Down
3 changes: 3 additions & 0 deletions Makefile.config.example
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ PYTHON_INCLUDE := /usr/include/python2.7 \
PYTHON_LIB := /usr/lib
# PYTHON_LIB := $(ANACONDA_HOME)/lib

# Uncomment to support layers written in Python (will link against Python libs)
# WITH_PYTHON_LAYER := 1

# Whatever else you find you need goes here.
INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include
LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib
Expand Down
11 changes: 6 additions & 5 deletions include/caffe/layer_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Layer;
template <typename Dtype>
class LayerRegistry {
public:
typedef Layer<Dtype>* (*Creator)(const LayerParameter&);
typedef shared_ptr<Layer<Dtype> > (*Creator)(const LayerParameter&);
typedef std::map<string, Creator> CreatorRegistry;

static CreatorRegistry& Registry() {
Expand All @@ -70,7 +70,7 @@ class LayerRegistry {
}

// Get a layer using a LayerParameter.
static Layer<Dtype>* CreateLayer(const LayerParameter& param) {
static shared_ptr<Layer<Dtype> > CreateLayer(const LayerParameter& param) {
LOG(INFO) << "Creating layer " << param.name();
const string& type = param.type();
CreatorRegistry& registry = Registry();
Expand Down Expand Up @@ -103,7 +103,7 @@ template <typename Dtype>
class LayerRegisterer {
public:
LayerRegisterer(const string& type,
Layer<Dtype>* (*creator)(const LayerParameter&)) {
shared_ptr<Layer<Dtype> > (*creator)(const LayerParameter&)) {
// LOG(INFO) << "Registering layer type: " << type;
LayerRegistry<Dtype>::AddCreator(type, creator);
}
Expand All @@ -116,8 +116,9 @@ class LayerRegisterer {

#define REGISTER_LAYER_CLASS(type) \
template <typename Dtype> \
Layer<Dtype>* Creator_##type##Layer(const LayerParameter& param) { \
return new type##Layer<Dtype>(param); \
shared_ptr<Layer<Dtype> > Creator_##type##Layer(const LayerParameter& param) \
{ \
return shared_ptr<Layer<Dtype> >(new type##Layer<Dtype>(param)); \
} \
REGISTER_LAYER_CREATOR(type, Creator_##type##Layer)

Expand Down
68 changes: 68 additions & 0 deletions include/caffe/python_layer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#ifndef CAFFE_PYTHON_LAYER_HPP_
#define CAFFE_PYTHON_LAYER_HPP_

#include <boost/python.hpp>
#include <vector>

#include "caffe/layer.hpp"

namespace bp = boost::python;

namespace caffe {

template <typename Dtype>
class PythonLayer : public Layer<Dtype> {
public:
PythonLayer(PyObject* self, const LayerParameter& param)
: Layer<Dtype>(param), self_(self) { }

virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
try {
bp::call_method<bp::object>(self_, "setup", bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
}

virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
bp::call_method<bp::object>(self_, "reshape", bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
}

virtual inline const char* type() const { return "Python"; }

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
bp::call_method<bp::object>(self_, "forward", bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
}
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
try {
bp::call_method<bp::object>(self_, "backward", top, propagate_down,
bottom);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
}

private:
PyObject* self_;
};

} // namespace caffe

#endif
2 changes: 1 addition & 1 deletion python/caffe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .pycaffe import Net, SGDSolver
from ._caffe import set_mode_cpu, set_mode_gpu, set_device, \
set_phase_train, set_phase_test
set_phase_train, set_phase_test, Layer, get_solver
from .classifier import Classifier
from .detector import Detector
import io
Loading

0 comments on commit 5e64f5a

Please sign in to comment.