
Commit 05c1dfa

Merge pull request tensorflow#10531 from llhe/rdma_fix
Improve RDMA rendezvous speed
2 parents df6a235 + 3ca692e commit 05c1dfa

File tree

4 files changed: +124 -16 lines changed

tensorflow/contrib/verbs/rdma.cc (+45 -10)
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
 #include "tensorflow/core/distributed_runtime/session_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
@@ -683,7 +684,6 @@ void RdmaTensorBuffer::SendNextItem() {
         << " error message: " << status.error_message();
     size_t buffer_size = RdmaMessage::kMessageTotalBytes;
     size_t tensor_bytes = 0;
-    TensorProto proto;
     // Figures out which device the tensor is hosted on.
     Device* src_dev = nullptr;
     Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
@@ -703,21 +703,47 @@ void RdmaTensorBuffer::SendNextItem() {
     CHECK(s.ok()) << "dst device not found";
     AllocatorAttributes dst_alloc_attr;
     dst_alloc_attr.set_on_host(true);
+
+    bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
     // string tensor needs to be serialized
+    Tensor copy;
+    StringPiece copy_buf;
+    TensorProto proto;
     if (src_dev->tensorflow_gpu_device_info() &&
         (!send_args.alloc_attrs.on_host())) {
       CHECK(send_args.device_context)
-          << "send dev name: " << src_dev->name()
-          << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
-      // "val" is on a GPU. Uses GPUUtil to fill the proto.
-      s = VerbsUtil::SetProtoFromGPUSync(
-          in, src_dev, send_args.device_context, &proto, is_dead);
-      CHECK(s.ok()) << "set proto from gpu sync";
+          << "send dev name: " << src_dev->name()
+          << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+
+      if (can_memcpy) {
+        AllocatorAttributes host_alloc_attrs;
+        host_alloc_attrs.set_gpu_compatible(true);
+        host_alloc_attrs.set_on_host(true);
+        Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+        copy = Tensor(alloc, in.dtype(), in.shape());
+        s = VerbsUtil::CopyGPUTensorToCPUSync(
+            src_dev, send_args.device_context, &in, &copy);
+        CHECK(s.ok()) << "copy tensor from gpu sync";
+        copy_buf = copy.tensor_data();
+      } else {
+        // "val" is on a GPU. Uses GPUUtil to fill the proto.
+        s = VerbsUtil::SetProtoFromGPUSync(
+            in, src_dev, send_args.device_context, &proto, is_dead);
+        CHECK(s.ok()) << "set proto from gpu sync";
+      }
     } else {
       // tensor is in CPU memory.
-      in.AsProtoTensorContent(&proto);
+      if (can_memcpy) {
+        copy_buf = in.tensor_data();
+      } else {
+        in.AsProtoTensorContent(&proto);
+      }
+    }
+    if (can_memcpy) {
+      tensor_bytes = in.TotalBytes();
+    } else {
+      tensor_bytes = proto.ByteSize();
     }
-    tensor_bytes = proto.ByteSize();
     // maybe some margin for string tensor?
     buffer_size += tensor_bytes;
     // prepare message
@@ -771,7 +797,16 @@ void RdmaTensorBuffer::SendNextItem() {
           static_cast<void*>(static_cast<char*>(buffer_) +
                              RdmaMessage::kTensorBufferStartIndex);
      CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
-      proto.SerializeToArray(output, tensor_bytes);
+      if (can_memcpy) {
+        CHECK(copy_buf.size() == tensor_bytes)
+            << "unexpected tensor size: "
+            << copy_buf.size()
+            << " != "
+            << tensor_bytes;
+        memcpy(output, copy_buf.data(), tensor_bytes);
+      } else {
+        proto.SerializeToArray(output, tensor_bytes);
+      }
     } else {
       buffer_size = RdmaMessage::kMessageTotalBytes;
     }
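
The crux of the speed-up is visible in the hunks above: for dtypes where DataTypeCanUseMemcpy() is true, the sender now copies the tensor's raw bytes straight into the registered RDMA buffer (staging GPU tensors through a CUDA-pinned host copy first), and only falls back to TensorProto serialization for variable-length types such as strings. Below is a minimal standalone sketch of that split; FakeTensor, CanUseMemcpy, and FillBuffer are hypothetical stand-ins for tensorflow::Tensor, DataTypeCanUseMemcpy, and the SendNextItem() logic, not TensorFlow code.

// Sketch of the send-side split: raw memcpy for fixed-size dtypes,
// serialized bytes only as the fallback path.
#include <cassert>
#include <cstring>
#include <string>
#include <vector>

enum class DType { kFloat, kString };

// True for plain-old-data dtypes whose bytes can be copied verbatim.
bool CanUseMemcpy(DType dt) { return dt != DType::kString; }

struct FakeTensor {
  DType dtype;
  std::vector<float> data;   // backing store for the kFloat case
  std::string serialized;    // pretend proto bytes for the kString case
  size_t TotalBytes() const { return data.size() * sizeof(float); }
};

// Fill an RDMA-style output buffer: memcpy on the fast path, "serialized"
// bytes on the slow path, mirroring the patched SendNextItem().
size_t FillBuffer(const FakeTensor& t, char* out, size_t cap) {
  if (CanUseMemcpy(t.dtype)) {
    const size_t n = t.TotalBytes();
    assert(n <= cap);
    std::memcpy(out, t.data.data(), n);  // no encode/decode round trip
    return n;
  }
  const size_t n = t.serialized.size();
  assert(n <= cap);
  std::memcpy(out, t.serialized.data(), n);
  return n;
}

int main() {
  FakeTensor t{DType::kFloat, {1.f, 2.f, 3.f}, ""};
  char buf[64];
  return FillBuffer(t, buf, sizeof(buf)) == 3 * sizeof(float) ? 0 : 1;
}

The fast path skips the proto encode/decode round trip entirely: the bytes that land in the RDMA buffer are exactly the tensor's backing store.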

tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc (+35 -6)
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -99,12 +100,40 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
     if (!rm.is_dead_) {
       void* input = static_cast<char*>(rb->buffer_) +
                     RdmaMessage::kTensorBufferStartIndex;
-      TensorProto proto;
-      CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
-            rb->size_);
-      CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
-          << "fail to parse proto from array";
-      s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+      bool can_memcpy = DataTypeCanUseMemcpy(rm.data_type_);
+      if (can_memcpy) {
+        if (dst_dev->tensorflow_gpu_device_info() &&
+            (!recv_args.alloc_attrs.on_host())) {
+          CHECK(recv_args.device_context)
+              << "send dev name: " << src_dev->name()
+              << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+          Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+          Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
+          memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
+
+          Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs);
+          Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_);
+          s = VerbsUtil::CopyCPUTensorToGPUSync(&copy, recv_args.device_context,
+                                                dst_dev, &gpu_copy);
+          CHECK(s.ok()) << "copy tensor to gpu sync";
+          val = std::move(gpu_copy);
+        } else {
+          AllocatorAttributes host_alloc_attrs;
+          host_alloc_attrs.set_gpu_compatible(true);
+          host_alloc_attrs.set_on_host(true);
+          Allocator* alloc = dst_dev->GetAllocator(host_alloc_attrs);
+          Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
+          memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
+          val = std::move(copy);
+        }
+      } else {
+        TensorProto proto;
+        CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
+              rb->size_);
+        CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
+            << "fail to parse proto from array";
+        s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+      }
     }
 
     rc->RemoveRecvCallback(key_with_step_id);
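
The receive side applies the same gate in reverse: for memcpy-able dtypes the tensor is rebuilt directly from the wire bytes via a host tensor (CUDA-pinned when the destination is a GPU, so the follow-up host-to-device copy can use DMA), and ParseProtoUnlimited/MakeTensorFromProto survive only on the fallback path. A tiny standalone sketch of the raw-bytes reconstruction, with RebuildFloats as a hypothetical stand-in for the memcpy into DMAHelper::base(&copy):

#include <cstring>
#include <vector>

// Rebuild typed storage from the bytes sitting in the RDMA buffer; one copy,
// no proto parse, mirroring memcpy(DMAHelper::base(&copy), input, ...).
std::vector<float> RebuildFloats(const void* wire_bytes, size_t num_bytes) {
  std::vector<float> out(num_bytes / sizeof(float));
  std::memcpy(out.data(), wire_bytes, num_bytes);
  return out;
}

int main() {
  const float src[3] = {1.f, 2.f, 3.f};
  std::vector<float> t = RebuildFloats(src, sizeof(src));
  return t[2] == 3.f ? 0 : 1;
}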

tensorflow/contrib/verbs/verbs_util.cc (+34 -0)
@@ -20,6 +20,40 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/str_util.h"
 namespace tensorflow {
 
+// static sync wrapper:
+Status VerbsUtil::CopyGPUTensorToCPUSync(Device* gpu_device,
+                                         const DeviceContext* device_context,
+                                         const Tensor* gpu_tensor,
+                                         Tensor* cpu_tensor) {
+  Notification n;
+  Status status;
+  GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context,
+                              gpu_tensor, cpu_tensor,
+                              [&n, &status](const Status& s) {
+                                status = s;
+                                n.Notify();
+                              });
+  n.WaitForNotification();
+  return status;
+}
+
+// static sync wrapper:
+Status VerbsUtil::CopyCPUTensorToGPUSync(const Tensor* cpu_tensor,
+                                         const DeviceContext* device_context,
+                                         Device* gpu_device,
+                                         Tensor* gpu_tensor) {
+  Notification n;
+  Status status;
+  GPUUtil::CopyCPUTensorToGPU(cpu_tensor, device_context,
+                              gpu_device, gpu_tensor,
+                              [&n, &status](const Status& s) {
+                                status = s;
+                                n.Notify();
+                              });
+  n.WaitForNotification();
+  return status;
+}
+
 // static sync wrapper:
 Status VerbsUtil::SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
                                       const DeviceContext* device_context,
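
Both new wrappers use TensorFlow's Notification to turn GPUUtil's callback-style async copies into blocking calls: the caller parks until the done-callback fires, then returns the status it captured. A standalone sketch of that pattern, with MiniNotification and AsyncCopy as hypothetical stand-ins for tensorflow::Notification and the GPUUtil copy APIs:

#include <condition_variable>
#include <functional>
#include <mutex>
#include <string>
#include <thread>

// A tiny stand-in for tensorflow::Notification: one-shot wakeup.
class MiniNotification {
 public:
  void Notify() {
    std::lock_guard<std::mutex> l(mu_);
    notified_ = true;
    cv_.notify_all();
  }
  void WaitForNotification() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return notified_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool notified_ = false;
};

// An async API in the shape of GPUUtil::CopyGPUTensorToCPU: it completes on
// another thread and reports its status through a done-callback.
void AsyncCopy(std::function<void(const std::string&)> done) {
  std::thread([done] { done("OK"); }).detach();
}

// Sync wrapper in the same shape as VerbsUtil::CopyGPUTensorToCPUSync.
std::string SyncCopy() {
  MiniNotification n;
  std::string status;
  AsyncCopy([&n, &status](const std::string& s) {
    status = s;
    n.Notify();
  });
  n.WaitForNotification();
  return status;
}

int main() { return SyncCopy() == "OK" ? 0 : 1; }

The trade-off is that the calling thread blocks for the duration of the copy; that keeps the rendezvous code simple at the cost of one parked thread per in-flight transfer.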

tensorflow/contrib/verbs/verbs_util.h (+10 -0)
@@ -28,6 +28,16 @@ class TensorProto;
 
 class VerbsUtil {
  public:
+  // synchronous wrapper of CopyGPUTensorToCPU
+  static Status CopyGPUTensorToCPUSync(Device* gpu_device,
+                                       const DeviceContext* device_context,
+                                       const Tensor* gpu_tensor,
+                                       Tensor* cpu_tensor);
+  // synchronous wrapper of CopyCPUTensorToGPU
+  static Status CopyCPUTensorToGPUSync(const Tensor* cpu_tensor,
+                                       const DeviceContext* device_context,
+                                       Device* gpu_device,
+                                       Tensor* gpu_tensor);
   // synchronous wrapper of SetProtoFromGPU
   static Status SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
                                     const DeviceContext* device_context,