Commit
Add dlpack data import support
xiaoweiw-nv authored and sfeiwong committed Jul 27, 2023
1 parent efcedb6 commit eb19252
Showing 3 changed files with 146 additions and 9 deletions.
1 change: 1 addition & 0 deletions bmf/hml/include/hmp/dataexport/data_export.h
@@ -5,4 +5,5 @@

namespace hmp {
HMP_API DLManagedTensor *to_dlpack(const Tensor &src);
HMP_API Tensor from_dlpack(const DLManagedTensor* src);
} // namespace hmp
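
A minimal round-trip sketch of the paired entry points declared above (the helper name and usage are illustrative, not part of this commit):

#include <hmp/dataexport/data_export.h>

// Export an hmp::Tensor to DLPack, then import it back as a view over the
// same memory; no copy is made in either direction.
void roundtrip_sketch(const hmp::Tensor &src) {
    DLManagedTensor *dlm = hmp::to_dlpack(src);   // producer side
    hmp::Tensor view = hmp::from_dlpack(dlm);     // consumer side, wraps the same buffer
    // The consumer is expected to run the deleter once it is done with the
    // producer's bookkeeping object; `src` keeps the underlying storage alive here.
    if (dlm->deleter) { dlm->deleter(dlm); }
}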
56 changes: 55 additions & 1 deletion bmf/hml/py/py_tensor.cpp
@@ -38,6 +38,59 @@ using namespace hmp;
Tensor tensor_from_numpy(const py::array& arr);
py::array tensor_to_numpy(const Tensor& tensor);

static bool is_device_supported(DLDeviceType devType)
{
switch (devType)
{
case kDLCUDAHost:
case kDLCUDA:
case kDLCUDAManaged:
case kDLCPU:
return true;
default:
return false;
}
}

Tensor tensor_from_dlpack(const py::object& o){
py::object tmp = py::reinterpret_borrow<py::object>(o);
Tensor ten;
if (hasattr(tmp, "__dlpack__"))
{
// Quickly check if we support the device
if (hasattr(tmp, "__dlpack_device__"))
{
py::tuple dlpackDevice = tmp.attr("__dlpack_device__")().cast<py::tuple>();
auto devType = static_cast<DLDeviceType>(dlpackDevice[0].cast<int>());
if (!is_device_supported(devType))
{
HMP_REQUIRE(false, "Only CPU and CUDA memory buffers can be wrapped");
}
}

py::capsule cap = tmp.attr("__dlpack__")(1).cast<py::capsule>();
PyObject* pycap = cap.ptr();

if (auto* tensor = static_cast<DLManagedTensor*>(cap.get_pointer()))
{
// m_dlTensor = DLPackTensor{std::move(*tensor)};
ten = from_dlpack(tensor);
// Signal that the producer doesn't have to call the tensor's deleter;
// we (the consumer) will do it instead.
HMP_REQUIRE(PyCapsule_SetName(pycap, "used_dltensor") == 0, "Failed to rename dltensor capsule");
}
else
{
HMP_REQUIRE(false, "No dlpack tensor found");
}
}
else {
HMP_REQUIRE(false, "dlpack not supported in the src tensor");
}
return ten;
}

Device parse_device(const py::object &obj, const Device &ref)
{
Device device(ref);
@@ -101,6 +154,8 @@ void tensorBind(py::module &m)
}
return arr_list;
})
.def("from_dlpack", (Tensor(*)(const py::object&))&tensor_from_dlpack,
py::arg("tensor"))
#ifdef HMP_ENABLE_TORCH
.def("from_torch", [](const at::Tensor &t){
return hmp::torch::from_tensor(t);
@@ -229,7 +284,6 @@ void tensorBind(py::module &m)
.def("data_ptr", [](const Tensor &self){
return reinterpret_cast<uint64_t>(self.unsafe_data());
})
// .def("__dlpack__", &Tensor::to_dlpack, py::arg("stream")=1)
.def("__dlpack__", [](const Tensor &self, const int stream){
DLManagedTensor* dlMTensor = to_dlpack(self);
py::capsule cap(dlMTensor, "dltensor", [](PyObject *ptr)
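
The __dlpack__ binding is truncated in this view. For reference only, a producer-side capsule destructor that follows the DLPack naming convention could look like the hypothetical sketch below; it is not necessarily the exact code in this commit:

// Hypothetical: only run the deleter while the capsule is still named
// "dltensor"; a consumer that took ownership renames it to "used_dltensor".
py::capsule cap(dlMTensor, "dltensor", [](PyObject *ptr) {
    if (PyCapsule_IsValid(ptr, "dltensor")) {
        auto *p = static_cast<DLManagedTensor *>(
            PyCapsule_GetPointer(ptr, "dltensor"));
        if (p && p->deleter) p->deleter(p);
    }
});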
98 changes: 90 additions & 8 deletions bmf/hml/src/dataexport/data_export.cpp
@@ -1,17 +1,16 @@
#include <hmp/dataexport/data_export.h>

#include <iostream>
#include <string>

namespace hmp{
static DLDataType getDLDataType(const Tensor& t) {
static DLDataType get_dl_dtype(const Tensor& t) {
DLDataType dtype;
dtype.lanes = 1;
dtype.bits = sizeof_scalar_type(t.scalar_type()) * 8;
switch (t.scalar_type()) {
case ScalarType::UInt8:
case ScalarType::UInt16:
// case ScalarType::UInt32:
// case ScalarType::UInt64:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case ScalarType::Int8:
@@ -32,7 +31,66 @@ static DLDataType getDLDataType(const Tensor& t) {
return dtype;
}

static DLDevice getDLDevice(const Tensor& tensor, const int64_t& device_id) {
ScalarType to_scalar_type(const DLDataType& dtype) {
ScalarType stype;
HMP_REQUIRE(dtype.lanes == 1, "hmp does not support lanes != 1");
switch (dtype.code) {
case DLDataTypeCode::kDLUInt:
switch (dtype.bits) {
case 8:
stype = ScalarType::UInt8;
break;
case 16:
stype = ScalarType::UInt16;
break;
default:
HMP_REQUIRE(
false, "Unsupported kUInt bits " + std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLInt:
switch (dtype.bits) {
case 8:
stype = ScalarType::Int8;
break;
case 16:
stype = ScalarType::Int16;
break;
case 32:
stype = ScalarType::Int32;
break;
case 64:
stype = ScalarType::Int64;
break;
default:
HMP_REQUIRE(
false, "Unsupported kInt bits " + std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat:
switch (dtype.bits) {
case 16:
stype = ScalarType::Half;
break;
case 32:
stype = ScalarType::Float32;
break;
case 64:
stype = ScalarType::Float64;
break;
default:
HMP_REQUIRE(
false, "Unsupported kFloat bits " + std::to_string(dtype.bits));
}
break;
default:
HMP_REQUIRE(
false, "Unsupported code " + std::to_string(dtype.code));
}
return stype;
}

static DLDevice get_dl_device(const Tensor& tensor, const int64_t& device_id) {
DLDevice ctx;
ctx.device_id = device_id;
switch (tensor.device().type()) {
@@ -48,14 +106,26 @@ static DLDevice getDLDevice(const Tensor& tensor, const int64_t& device_id) {
return ctx;
}

static Device get_hmp_device(const DLDevice& ctx) {
switch (ctx.device_type) {
case DLDeviceType::kDLCPU:
return Device(DeviceType::CPU);
#ifdef HMP_ENABLE_CUDA
case DLDeviceType::kDLCUDA:
return Device(DeviceType::CUDA, ctx.device_id);
#endif
default:
HMP_REQUIRE(
false, "Unsupported device_type: " + std::to_string(ctx.device_type));
}
}

struct HmpDLMTensor {
// HmpDLMTensor() { std::cout << "Construct HmpDLMTensor\n"; }
Tensor handle;
DLManagedTensor tensor;
};

void deleter(DLManagedTensor* arg) {
// std::cout << "Destruct HmpDLMTensor\n";
delete static_cast<HmpDLMTensor*>(arg->manager_ctx);
}

@@ -70,9 +140,9 @@ DLManagedTensor* to_dlpack(const Tensor& src) {
if (src.is_cuda()) {
device_id = src.device_index();
}
hmpDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
hmpDLMTensor->tensor.dl_tensor.device = get_dl_device(src, device_id);
hmpDLMTensor->tensor.dl_tensor.ndim = src.dim();
hmpDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
hmpDLMTensor->tensor.dl_tensor.dtype = get_dl_dtype(src);
hmpDLMTensor->tensor.dl_tensor.shape =
const_cast<int64_t*>(src.shape().data());
hmpDLMTensor->tensor.dl_tensor.strides =
@@ -81,4 +151,16 @@ DLManagedTensor* to_dlpack(const Tensor& src) {
return &(hmpDLMTensor->tensor);
}

Tensor from_dlpack(const DLManagedTensor* src) {
Device device = get_hmp_device(src->dl_tensor.device);
ScalarType stype = to_scalar_type(src->dl_tensor.dtype);
SizeArray shape{src->dl_tensor.shape, src->dl_tensor.shape + src->dl_tensor.ndim};
if (!src->dl_tensor.strides) {
return from_buffer({src->dl_tensor.data, device}, stype, shape);
}
SizeArray strides{src->dl_tensor.strides, src->dl_tensor.strides + src->dl_tensor.ndim};
return from_buffer({src->dl_tensor.data, device}, stype, shape, strides);
}
} // namespace hmp
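
For illustration, from_dlpack can also be fed a hand-built descriptor. A minimal sketch for a contiguous CPU float32 buffer (the helper and its setup are assumptions, not part of this commit):

#include <dlpack/dlpack.h>
#include <hmp/dataexport/data_export.h>
#include <vector>

// Wrap an existing contiguous CPU float32 buffer and import it. from_dlpack
// wraps the data pointer rather than copying it, so the buffer must outlive
// the returned Tensor; the shape and strides themselves are copied on import.
hmp::Tensor import_cpu_float32(float *data, std::vector<int64_t> &shape) {
    DLManagedTensor dlm{};
    dlm.dl_tensor.data = data;
    dlm.dl_tensor.device.device_type = kDLCPU;
    dlm.dl_tensor.device.device_id = 0;
    dlm.dl_tensor.ndim = static_cast<int32_t>(shape.size());
    dlm.dl_tensor.dtype.code = kDLFloat;
    dlm.dl_tensor.dtype.bits = 32;
    dlm.dl_tensor.dtype.lanes = 1;
    dlm.dl_tensor.shape = shape.data();
    dlm.dl_tensor.strides = nullptr;   // nullptr is treated as dense/contiguous
    dlm.dl_tensor.byte_offset = 0;
    return hmp::from_dlpack(&dlm);
}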
