@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
 #include "tensorflow/core/distributed_runtime/session_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
@@ -683,7 +684,6 @@ void RdmaTensorBuffer::SendNextItem() {
                        << " error message: " << status.error_message();
     size_t buffer_size = RdmaMessage::kMessageTotalBytes;
     size_t tensor_bytes = 0;
-    TensorProto proto;
     // Figures out which device the tensor is hosted on.
     Device* src_dev = nullptr;
     Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
@@ -703,21 +703,47 @@ void RdmaTensorBuffer::SendNextItem() {
     CHECK(s.ok()) << "dst device not found";
     AllocatorAttributes dst_alloc_attr;
     dst_alloc_attr.set_on_host(true);
+
+    bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
     // string tensor needs to be serialized
+    Tensor copy;
+    StringPiece copy_buf;
+    TensorProto proto;
     if (src_dev->tensorflow_gpu_device_info() &&
         (!send_args.alloc_attrs.on_host())) {
       CHECK(send_args.device_context)
-        << "send dev name: " << src_dev->name()
-        << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
-      // "val" is on a GPU. Uses GPUUtil to fill the proto.
-      s = VerbsUtil::SetProtoFromGPUSync(
-          in, src_dev, send_args.device_context, &proto, is_dead);
-      CHECK(s.ok()) << "set proto from gpu sync";
+          << "send dev name: " << src_dev->name()
+          << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+
+      if (can_memcpy) {
+        AllocatorAttributes host_alloc_attrs;
+        host_alloc_attrs.set_gpu_compatible(true);
+        host_alloc_attrs.set_on_host(true);
+        Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+        copy = Tensor(alloc, in.dtype(), in.shape());
+        s = VerbsUtil::CopyGPUTensorToCPUSync(
+            src_dev, send_args.device_context, &in, &copy);
+        CHECK(s.ok()) << "copy tensor from gpu sync";
+        copy_buf = copy.tensor_data();
+      } else {
+        // "val" is on a GPU. Uses GPUUtil to fill the proto.
+        s = VerbsUtil::SetProtoFromGPUSync(
+            in, src_dev, send_args.device_context, &proto, is_dead);
+        CHECK(s.ok()) << "set proto from gpu sync";
+      }
     } else {
       // tensor is in CPU memory.
-      in.AsProtoTensorContent(&proto);
+      if (can_memcpy) {
+        copy_buf = in.tensor_data();
+      } else {
+        in.AsProtoTensorContent(&proto);
+      }
+    }
+    if (can_memcpy) {
+      tensor_bytes = in.TotalBytes();
+    } else {
+      tensor_bytes = proto.ByteSize();
     }
-    tensor_bytes = proto.ByteSize();
     // maybe some margin for string tensor?
     buffer_size += tensor_bytes;
     // prepare message
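This hunk is the core of the change: tensors whose dtype passes DataTypeCanUseMemcpy() skip the TensorProto round trip entirely and expose their flat payload via tensor_data(), while string tensors (and anything else without a fixed-width contiguous representation) still go through the proto. A self-contained sketch of that dispatch, outside TensorFlow, where the hypothetical ToyTensor, is_memcpyable, Serialize(), and FillSendBuffer() stand in for the tensor, DataTypeCanUseMemcpy(), proto encoding, and the buffer-fill logic above:

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Stand-in for a tensor: is_memcpyable mirrors DataTypeCanUseMemcpy(),
// bytes mirrors tensor_data(), Serialize() mirrors TensorProto encoding.
// All names here are illustrative, not TensorFlow API.
struct ToyTensor {
  bool is_memcpyable;
  std::vector<std::uint8_t> bytes;
  std::string Serialize() const {
    return std::string(bytes.begin(), bytes.end());
  }
};

// Mirrors the patch's split: raw memcpy for fixed-width payloads, serialized
// fallback otherwise. Returns bytes written, or 0 if the buffer is too small.
std::size_t FillSendBuffer(const ToyTensor& in, std::uint8_t* out,
                           std::size_t capacity) {
  if (in.is_memcpyable) {
    if (in.bytes.size() > capacity) return 0;
    std::memcpy(out, in.bytes.data(), in.bytes.size());  // fast path
    return in.bytes.size();
  }
  const std::string wire = in.Serialize();  // slow path, e.g. string tensors
  if (wire.size() > capacity) return 0;
  std::memcpy(out, wire.data(), wire.size());
  return wire.size();
}

The win is that large dense tensors avoid a proto encode on the sender and a decode on the receiver; the cost is a second code path whose can_memcpy predicate both ends must evaluate identically, since raw bytes and proto bytes are not distinguishable from the wire alone.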
@@ -771,7 +797,16 @@ void RdmaTensorBuffer::SendNextItem() {
           static_cast<void*>(static_cast<char*>(buffer_) +
                              RdmaMessage::kTensorBufferStartIndex);
       CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
-      proto.SerializeToArray(output, tensor_bytes);
+      if (can_memcpy) {
+        CHECK(copy_buf.size() == tensor_bytes)
+            << "unexpected tensor size: "
+            << copy_buf.size()
+            << " != "
+            << tensor_bytes;
+        memcpy(output, copy_buf.data(), tensor_bytes);
+      } else {
+        proto.SerializeToArray(output, tensor_bytes);
+      }
     } else {
       buffer_size = RdmaMessage::kMessageTotalBytes;
     }
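One detail worth noting in this last hunk: tensor_bytes is now computed two different ways (in.TotalBytes() versus proto.ByteSize()), so the added CHECK that copy_buf.size() matches tensor_bytes is what catches any mismatch between the size used to size the message and the bytes actually staged, before they are memcpy'd past kTensorBufferStartIndex.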