add BytepsPushPullXlaOp
debugging again
cleaned up

PushPull passes its input directly through to its output. The XLA custom
op gets the input tensor pointers, passes them to BytePS to perform an
in-place pushpull, then returns right away.

SyncAllTensors takes the input tensors as input and passes them through
as output. Its CustomCall waits on all tensors, then returns.

The current code doesn't allocate extra memory or perform extra memcpy.
The gradients passed in from TensorFlow are retained, operated on, and
returned after SyncAllTensors finishes.

The result is correct.
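
A minimal sketch of the start-then-sync pattern this message describes, not the BytePS implementation. `StartInPlace`, `SyncAll`, `RunStartSyncDemo`, and the scaling step are hypothetical stand-ins: the real op hands tensor pointers to BytePS, and a later XLA CustomCall does the waiting. The sketch assumes `std::async` as the background worker and only illustrates that the caller's buffer is modified in place, with no extra allocation or copy, and is safe to read once the sync step returns.

```cpp
#include <cassert>
#include <future>
#include <vector>

// Hypothetical stand-in for an in-place push-pull: scales the caller's
// buffer on a background thread and returns immediately.
// The caller must keep `buf` alive until the returned future is waited on.
std::future<void> StartInPlace(std::vector<float>& buf, float scale) {
  return std::async(std::launch::async, [&buf, scale] {
    for (float& x : buf) x *= scale;  // operate directly on caller memory
  });
}

// Hypothetical stand-in for SyncAllTensors: blocks until every pending
// operation has finished, then clears the list.
void SyncAll(std::vector<std::future<void>>& pending) {
  for (auto& f : pending) f.get();
  pending.clear();
}

// Starts two in-place operations, syncs once, and checks that the buffers
// were updated without being reallocated.
bool RunStartSyncDemo() {
  std::vector<float> g0 = {1.0f, 2.0f};
  std::vector<float> g1 = {3.0f};
  const float* g0_before = g0.data();

  std::vector<std::future<void>> pending;
  pending.push_back(StartInPlace(g0, 2.0f));
  pending.push_back(StartInPlace(g1, 2.0f));
  SyncAll(pending);

  return g0.data() == g0_before && g0[0] == 2.0f && g0[1] == 4.0f &&
         g1[0] == 6.0f;
}
```

Calling `RunStartSyncDemo()` from a test or `main` exercises both steps; nothing may read `g0` or `g1` between the start calls and `SyncAll`.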

Signed-off-by: Yulu Jia <[email protected]>
pleasantrabbit committed Sep 29, 2020
1 parent d2cf5f8 commit ce3ebe0
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions byteps/tensorflow/ops.cc
@@ -374,7 +374,7 @@ class BytepsPushPullXlaOp : public ::tensorflow::XlaOpKernel {
       context->SetOutput(
           1, xla::CustomCall(context->builder(),
                              /*call_target_name=*/"StartTaskWrapper",
-                             {input_tensor}, output_shapes, ss.str()));
+                             {input_tensor}, output_tensor_shape, ss.str()));
 
       context->op_kernel_context()->set_output(0,
           context->op_kernel_context()->input(0));
@@ -450,12 +450,28 @@ void StartTaskBlockingXla(::tensorflow::OpKernelContext* context,
 
   std::string name_key(node_name);
   std::replace(name_key.begin(), name_key.end(), '/', '_');
-  _name_to_done_args[name_key]->is_done = false;
+
+  std::shared_ptr<Xla_done_cb_args> new_args(new Xla_done_cb_args);
+  new_args->is_done = false;
+  new_args->bps_out_buf = const_cast<void *>(byteps_output->data());
+  new_args->bps_in_buf = const_cast<void *>(byteps_input->data());
+  new_args->bps_buf_size = size;
+
+  std::unique_lock<std::mutex> my_lk(_name_to_done_args_mtx);
+  auto it = _name_to_done_args.find(name_key);
+  ASSERTF(it == _name_to_done_args.end(), "duplicate tensor_name");
+  _name_to_done_args[name_key] = new_args;
+  my_lk.unlock();
+
   auto enqueue_result =
       EnqueueTensor(byteps_context, byteps_input, byteps_output, ready_event,
                     device, -byteps_context.declared_key, 0,
                     [name_key](const common::Status& status) {
+                      std::unique_lock<std::mutex> my_lk(_name_to_done_args_mtx);
+                      auto it = _name_to_done_args.find(name_key);
+                      ASSERTF(it != _name_to_done_args.end(), "YOU SHOULD NOT SEE ME");
                       auto args = _name_to_done_args[name_key];
+                      my_lk.unlock();
                       {
                         std::unique_lock<std::mutex> lk(args->mtx);
                         args->is_done = true;
@@ -464,7 +480,9 @@ void StartTaskBlockingXla(::tensorflow::OpKernelContext* context,
                     },
                     queue_list);
   {
+    std::unique_lock<std::mutex> my_big_lk(_name_to_done_args_mtx);
     auto args = _name_to_done_args[name_key];
+    my_big_lk.unlock();
     std::unique_lock<std::mutex> lk(args->mtx);
     args->cv.wait(lk, [args]{
       std::this_thread::yield();
@@ -475,9 +493,6 @@ void StartTaskBlockingXla(::tensorflow::OpKernelContext* context,
 
 void StartTaskBlockingWrapper(CUstream stream, void** buffers,
                               const char* opaque, size_t opaque_len) {
-  void *a = buffers[0];
-  void *b = buffers[1];
-
   std::stringstream ss(opaque);
   std::string tmp_name;
   ::tensorflow::OpKernelContext* context = nullptr;
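
The bookkeeping this diff adds to StartTaskBlockingXla can be sketched in isolation as follows. This is an illustrative reduction under stated assumptions, not the code in ops.cc: `DoneArgs`, `DoneRegistry`, `MakeKey`, `MarkDone`, `WaitDone`, and `RunDemo` are hypothetical names, a plain `std::thread` stands in for the BytePS completion callback, and the buffer fields of `Xla_done_cb_args` are omitted. It shows the same three steps: register a fresh entry under a sanitized name while holding the table mutex and reject duplicates, look the entry up under the table mutex from the callback, and wait on the per-entry condition variable after the table mutex has been released.

```cpp
#include <algorithm>
#include <cassert>
#include <condition_variable>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>

// Hypothetical stand-in for the commit's Xla_done_cb_args (buffers omitted).
struct DoneArgs {
  std::mutex mtx;
  std::condition_variable cv;
  bool is_done = false;
};

// Replaces '/' with '_' so a node name can serve as a flat key.
std::string MakeKey(std::string node_name) {
  std::replace(node_name.begin(), node_name.end(), '/', '_');
  return node_name;
}

class DoneRegistry {
 public:
  // Registers a fresh entry; returns nullptr if the key already exists.
  std::shared_ptr<DoneArgs> Register(const std::string& key) {
    auto args = std::make_shared<DoneArgs>();
    std::lock_guard<std::mutex> lk(mtx_);
    return table_.emplace(key, args).second ? args : nullptr;
  }

  // Looks up an entry under the table lock; returns nullptr if missing.
  std::shared_ptr<DoneArgs> Find(const std::string& key) {
    std::lock_guard<std::mutex> lk(mtx_);
    auto it = table_.find(key);
    return it == table_.end() ? nullptr : it->second;
  }

 private:
  std::mutex mtx_;
  std::map<std::string, std::shared_ptr<DoneArgs>> table_;
};

// Completion side: the table lock is released before the entry lock is taken.
void MarkDone(DoneRegistry& reg, const std::string& key) {
  auto args = reg.Find(key);
  assert(args != nullptr);
  {
    std::lock_guard<std::mutex> lk(args->mtx);
    args->is_done = true;
  }
  args->cv.notify_one();
}

// Waiting side: blocks until the completion side has set is_done.
void WaitDone(DoneRegistry& reg, const std::string& key) {
  auto args = reg.Find(key);
  assert(args != nullptr);
  std::unique_lock<std::mutex> lk(args->mtx);
  args->cv.wait(lk, [&args] { return args->is_done; });
}

// Registers one key, rejects a duplicate, then waits for a worker thread.
bool RunDemo() {
  DoneRegistry reg;
  const std::string key = MakeKey("gradients/dense/kernel");
  if (!reg.Register(key)) return false;  // first registration must succeed
  if (reg.Register(key)) return false;   // duplicate must be rejected
  std::thread worker([&reg, &key] { MarkDone(reg, key); });
  WaitDone(reg, key);
  worker.join();
  return reg.Find(key)->is_done;
}
```

Because the waiter checks `is_done` in the wait predicate, it does not matter whether the callback fires before or after the wait begins. Holding a `std::shared_ptr` to the entry also keeps it alive after the table lock is dropped, which is why the lookup and the wait can use separate locks.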