Allow workers to send SHUTDOWN request to the coordinator (horovod#201)
* Allow workers to send SHUTDOWN request to the coordinator

* Bump version to 0.12.1
alsrgv authored Mar 13, 2018
1 parent c2828dc commit 75c1f3f
Showing 10 changed files with 190 additions and 97 deletions.
2 changes: 1 addition & 1 deletion horovod/__init__.py
@@ -1 +1 @@
__version__ = '0.12.0'
__version__ = '0.12.1'
4 changes: 4 additions & 0 deletions horovod/common/common.cc
@@ -41,6 +41,10 @@ Status Status::PreconditionError(std::string message) {
return Status(StatusType::PRECONDITION_ERROR, message);
}

Status Status::Aborted(std::string message) {
return Status(StatusType::ABORTED, message);
}

bool Status::ok() const {
return type_ == StatusType::OK;
}
3 changes: 2 additions & 1 deletion horovod/common/common.h
@@ -27,14 +27,15 @@ namespace common {
// List of supported frameworks.
enum Framework { TENSORFLOW };

enum StatusType { OK, UNKNOWN_ERROR, PRECONDITION_ERROR };
enum StatusType { OK, UNKNOWN_ERROR, PRECONDITION_ERROR, ABORTED };

class Status {
public:
Status();
static Status OK();
static Status UnknownError(std::string message);
static Status PreconditionError(std::string message);
static Status Aborted(std::string message);
bool ok() const;
StatusType type() const;
const std::string& reason() const;
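The new ABORTED status type lets callers tell a deliberate shutdown apart from a genuine failure. Below is a minimal sketch of that distinction, using only the Status API declared above; the ReportStatus helper is hypothetical and not part of this commit.

#include <iostream>
#include "horovod/common/common.h"

using namespace horovod::common;

// Hypothetical helper: print a status, treating ABORTED as a shutdown signal
// rather than an error.
void ReportStatus(const Status& status) {
  if (status.ok()) {
    std::cout << "operation completed" << std::endl;
  } else if (status.type() == StatusType::ABORTED) {
    std::cout << "shut down: " << status.reason() << std::endl;
  } else {
    std::cout << "failed: " << status.reason() << std::endl;
  }
}

int main() {
  ReportStatus(Status::OK());
  ReportStatus(Status::Aborted("Horovod has been shut down."));
  return 0;
}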
10 changes: 10 additions & 0 deletions horovod/common/mpi_message.cc
@@ -170,6 +170,14 @@ void MPIRequestList::set_requests(const std::vector<MPIRequest>& value) {
requests_ = value;
}

bool MPIRequestList::shutdown() const {
return shutdown_;
}

void MPIRequestList::set_shutdown(bool value) {
shutdown_ = value;
}

void MPIRequestList::add_requests(MPIRequest value) {
requests_.push_back(value);
}
@@ -183,6 +191,7 @@ void MPIRequestList::ParseFromString(MPIRequestList& request_list,
MPIRequest_ParseFromWire(request, *it);
request_list.add_requests(std::move(request));
}
request_list.set_shutdown(obj->shutdown());
}

void MPIRequestList::SerializeToString(MPIRequestList& request_list,
@@ -197,6 +206,7 @@ void MPIRequestList::SerializeToString(MPIRequestList& request_list,
requests.push_back(req_obj);
}
request_list_builder.add_requests(builder.CreateVector(requests));
request_list_builder.add_shutdown(request_list.shutdown());
auto obj = request_list_builder.Finish();
builder.Finish(obj);

3 changes: 3 additions & 0 deletions horovod/common/mpi_message.h
@@ -89,6 +89,8 @@ class MPIRequestList {
const std::vector<MPIRequest>& requests() const;
void set_requests(const std::vector<MPIRequest>& value);
void add_requests(MPIRequest value);
bool shutdown() const;
void set_shutdown(bool value);

static void ParseFromString(MPIRequestList& request_list,
const std::string& input);
@@ -97,6 +99,7 @@

private:
std::vector<MPIRequest> requests_;
bool shutdown_ = false;
};

// An MPIResponse is a message sent from the coordinator (rank zero) to a rank
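A short sketch of how the new flag travels: the worker serializes a request list with the shutdown bit set, and the coordinator parses it back out. This mirrors the pattern used in operations.cc below, but as a standalone, hedged example rather than code from the commit.

#include <cassert>
#include <string>
#include "horovod/common/mpi_message.h"

using namespace horovod::common;

int main() {
  // Worker side: an otherwise empty request list that only carries the
  // shutdown flag.
  MPIRequestList shutdown_request;
  shutdown_request.set_shutdown(true);

  std::string encoded;
  MPIRequestList::SerializeToString(shutdown_request, encoded);

  // Coordinator side: decode the bytes received over MPI and check the flag.
  MPIRequestList received;
  MPIRequestList::ParseFromString(received, encoded);
  assert(received.requests().empty());
  assert(received.shutdown());
  return 0;
}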
138 changes: 93 additions & 45 deletions horovod/common/operations.cc
@@ -208,6 +208,14 @@ static HorovodGlobalState horovod_global;
// Stall-check warning time
#define STALL_WARNING_TIME std::chrono::seconds(60)

static const Status NOT_INITIALIZED_ERROR = Status::PreconditionError(
"Horovod has not been initialized; use hvd.init().");

static const Status SHUT_DOWN_ERROR = Status::Aborted(
"Horovod has been shut down. This has been caused by an exception on one "
"of the rank or an attempt to allreduce, allgather or broadcast a tensor "
"after one of the ranks has finished execution.");

// Store the MPIRequest for a name, and return whether the total count of
// MPIRequests for that tensor is now equal to the MPI size (and thus we are
// ready to reduce the tensor).
@@ -1237,15 +1245,17 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
}
}
} else {
std::string encoded_message;
MPIRequestList message_list;
while (!message_queue.empty()) {
message_list.add_requests(message_queue.front());
message_queue.pop();
if (!message_queue.empty()) {
std::string encoded_message;
MPIRequestList message_list;
while (!message_queue.empty()) {
message_list.add_requests(message_queue.front());
message_queue.pop();
}
MPIRequestList::SerializeToString(message_list, encoded_message);
MPI_Send(encoded_message.c_str(), (int)encoded_message.length() + 1,
MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
}
MPIRequestList::SerializeToString(message_list, encoded_message);
MPI_Send(encoded_message.c_str(), (int)encoded_message.length() + 1,
MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
}

// Rank zero has put all its own tensors in the tensor count table.
@@ -1292,6 +1302,10 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
ready_to_reduce.push_back(received_name);
}
}
if (received_message_list.shutdown()) {
// Received SHUTDOWN request from one of the workers.
state.shut_down = true;
}
}

// At this point, rank zero should have a fully updated tensor count
@@ -1374,6 +1388,16 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
state.last_stall_check = std::chrono::steady_clock::now();
}
} else {
if (state.shut_down) {
// Send a SHUTDOWN request to the coordinator.
std::string encoded_message;
MPIRequestList shutdown_request;
shutdown_request.set_shutdown(true);
MPIRequestList::SerializeToString(shutdown_request, encoded_message);
MPI_Send(encoded_message.c_str(), (int)encoded_message.length() + 1,
MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
}

// Notify the coordinator that this node is done sending messages.
// A DONE message is encoded as a zero-length message.
MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
@@ -1424,6 +1448,25 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
// ncclCommDestroy(it->second);
// }
//#endif

// Notify all outstanding operations that Horovod has been shut down
// and clear up the tensor table and message queue.
std::vector<StatusCallback> callbacks;
{
std::lock_guard<std::mutex> guard(state.mutex);
for (auto it = state.tensor_table.begin(); it != state.tensor_table.end();
it++) {
callbacks.emplace_back(it->second.callback);
}
state.tensor_table.clear();
while (!state.message_queue.empty()) {
state.message_queue.pop();
}
}
for (auto it = callbacks.begin(); it != callbacks.end(); it++) {
(*it)(SHUT_DOWN_ERROR);
}

MPI_Finalize();
}

@@ -1446,8 +1489,7 @@ void InitializeHorovodOnce() {

Status CheckInitialized() {
if (!horovod_global.initialization_done) {
return Status::PreconditionError(
"Horovod has not been initialized; use hvd.init().");
return NOT_INITIALIZED_ERROR;
}
return Status::OK();
}
@@ -1494,17 +1536,14 @@ int horovod_mpi_threads_supported() {

// MPI must be initialized and the background thread must be running before
// this function is called.
void EnqueueTensorAllreduce(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback) {
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);

Status EnqueueTensorAllreduce(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback) {
MPIRequest message;
message.set_request_rank(rank);
message.set_request_rank(horovod_global.rank);
message.set_tensor_name(name);
message.set_tensor_type(tensor->dtype());
message.set_device(device);
@@ -1523,22 +1562,24 @@ void EnqueueTensorAllreduce(std::shared_ptr<OpContext> context,
e.callback = callback;

std::lock_guard<std::mutex> guard(horovod_global.mutex);
horovod_global.tensor_table.emplace(name, std::move(e));
horovod_global.message_queue.push(message);
if (!horovod_global.shut_down) {
horovod_global.tensor_table.emplace(name, std::move(e));
horovod_global.message_queue.push(message);
return Status::OK();
} else {
return SHUT_DOWN_ERROR;
}
}

// MPI must be initialized and the background thread must be running before
// this function is called.
void EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback) {
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);

Status EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback) {
MPIRequest message;
message.set_request_rank(rank);
message.set_request_rank(horovod_global.rank);
message.set_tensor_name(name);
message.set_tensor_type(tensor->dtype());
message.set_device(device);
@@ -1556,23 +1597,25 @@ void EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
e.callback = callback;

std::lock_guard<std::mutex> guard(horovod_global.mutex);
horovod_global.tensor_table.emplace(name, std::move(e));
horovod_global.message_queue.push(message);
if (!horovod_global.shut_down) {
horovod_global.tensor_table.emplace(name, std::move(e));
horovod_global.message_queue.push(message);
return Status::OK();
} else {
return SHUT_DOWN_ERROR;
}
}

// MPI must be initialized and the background thread must be running before
// this function is called.
void EnqueueTensorBroadcast(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output, int root_rank,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback) {
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);

Status EnqueueTensorBroadcast(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output, int root_rank,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback) {
MPIRequest message;
message.set_request_rank(rank);
message.set_request_rank(horovod_global.rank);
message.set_tensor_name(name);
message.set_tensor_type(tensor->dtype());
message.set_root_rank(root_rank);
@@ -1593,8 +1636,13 @@ void EnqueueTensorBroadcast(std::shared_ptr<OpContext> context,
e.callback = callback;

std::lock_guard<std::mutex> guard(horovod_global.mutex);
horovod_global.tensor_table.emplace(name, std::move(e));
horovod_global.message_queue.push(message);
if (!horovod_global.shut_down) {
horovod_global.tensor_table.emplace(name, std::move(e));
horovod_global.message_queue.push(message);
return Status::OK();
} else {
return SHUT_DOWN_ERROR;
}
}

} // namespace common
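With the Enqueue* functions now returning a Status, the framework-side callers (the op wrappers in the files not shown in this excerpt) have to check the result and fail the operation themselves when the request is rejected. A hypothetical caller sketch, assuming the context, tensor, and ready_event objects come from framework-specific code:

#include <memory>
#include <string>
#include "horovod/common/operations.h"

using namespace horovod::common;

// Hypothetical wrapper: enqueue an allreduce and surface a rejected request
// (e.g. SHUT_DOWN_ERROR after another rank has finished) to the framework
// immediately, since the background thread will never invoke the callback
// for a request that was never enqueued.
void AllreduceOrFail(std::shared_ptr<OpContext> context,
                     std::shared_ptr<Tensor> tensor,
                     std::shared_ptr<Tensor> output,
                     std::shared_ptr<ReadyEvent> ready_event,
                     const std::string& name, int device,
                     StatusCallback done) {
  Status enqueued = EnqueueTensorAllreduce(context, tensor, output,
                                           ready_event, name, device, done);
  if (!enqueued.ok()) {
    done(enqueued);
  }
}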
38 changes: 19 additions & 19 deletions horovod/common/operations.h
@@ -60,25 +60,25 @@ int horovod_local_size();
int horovod_mpi_threads_supported();
}

void EnqueueTensorAllreduce(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback);

void EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback);

void EnqueueTensorBroadcast(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output, int root_rank,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback);
Status EnqueueTensorAllreduce(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback);

Status EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback);

Status EnqueueTensorBroadcast(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output, int root_rank,
std::shared_ptr<ReadyEvent> ready_event,
const std::string name, const int device,
StatusCallback callback);

} // namespace common
} // namespace horovod
3 changes: 3 additions & 0 deletions horovod/common/wire/mpi_message.fbs
@@ -58,6 +58,9 @@ table MPIRequest {
}
table MPIRequestList {
requests:[MPIRequest];

// Flag indicating whether worker is requesting a shutdown.
shutdown:bool;
}

// An MPIResponse is a message sent from the coordinator (rank zero) to a rank
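One property worth noting, assuming standard FlatBuffers semantics: a scalar field appended to a table is optional on the wire and reads back as its default (false here), and the C++ wrapper above likewise initializes shutdown_ to false. A request list that never touches the flag is therefore not mistaken for a shutdown request, as this small hedged check illustrates.

#include <cassert>
#include <string>
#include "horovod/common/mpi_message.h"

using namespace horovod::common;

int main() {
  // A request list that never sets the new field...
  MPIRequestList plain_list;
  std::string encoded;
  MPIRequestList::SerializeToString(plain_list, encoded);

  // ...round-trips with shutdown() still false, so the coordinator does not
  // treat it as a shutdown request.
  MPIRequestList decoded;
  MPIRequestList::ParseFromString(decoded, encoded);
  assert(!decoded.shutdown());
  return 0;
}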