Fix NetworkExecutionStatePool available state tracking. (pytorch#4239)
Summary:
Changed getNextNetworkExecutionState to track available states with a deque instead of a round-robin counter, so a state is only handed out after it has been returned to the pool (see the sketch below the file summary).
Documentation:
Pull Request resolved: pytorch#4239

Test Plan:
ninja check
tracelogfb verified that this fixes the segfault.

Reviewed By: tracelogfb

Differential Revision: D20209608

Pulled By: gcatron

fbshipit-source-id: 6b91893f5d8b7bb11e763525fdd94b7172eb05a5
gcatron authored and facebook-github-bot committed Mar 3, 2020
1 parent caaeac5 commit 0c5dfb0
Showing 3 changed files with 31 additions and 9 deletions.
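
Why the round-robin counter could hand out a busy state, and why a deque of available states cannot: the following is a minimal, self-contained C++ sketch, not Glow code. The State struct and the acquire/release lambdas are hypothetical stand-ins for NetworkExecutionState and the pool methods; they only illustrate that a bare fetch_add index wraps back onto a state that is still in flight once concurrent requests outnumber pool slots, while a deque of available states can only hand out a state after it has been explicitly returned.

#include <atomic>
#include <cassert>
#include <deque>
#include <mutex>
#include <vector>

struct State {
  bool inUse = false;
};

int main() {
  // Old scheme: a bare round-robin counter over the pool. With 2 states and a
  // 3rd concurrent request, the counter wraps back onto a state that is still
  // in flight, so two requests end up sharing (and clobbering) the same state.
  std::vector<State> pool(2);
  std::atomic<unsigned> current{0};
  auto acquireRoundRobin = [&]() -> State * {
    return &pool[current.fetch_add(1) % pool.size()];
  };
  State *a = acquireRoundRobin();
  a->inUse = true;
  State *b = acquireRoundRobin();
  b->inUse = true;
  State *c = acquireRoundRobin(); // wraps around to pool[0]...
  assert(c == a && c->inUse);     // ...which is still busy.

  // New scheme: a mutex-guarded deque of *available* states. A state is handed
  // out only after it has been explicitly returned, so it can never be given
  // to two requests at once.
  std::deque<State *> available = {&pool[0], &pool[1]};
  std::mutex lock;
  auto acquire = [&]() -> State * {
    std::lock_guard<std::mutex> guard(lock);
    if (available.empty())
      return nullptr; // pool exhausted: caller must wait, no wrap-around
    State *next = available.front();
    available.pop_front();
    return next;
  };
  auto release = [&](State *state) {
    std::lock_guard<std::mutex> guard(lock);
    available.push_back(state);
  };
  State *x = acquire();
  State *y = acquire();
  assert(x && y && acquire() == nullptr);
  release(x);
  release(y);
  return 0;
}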
21 changes: 14 additions & 7 deletions include/glow/Runtime/Executor/NetworkExecutionState.h
@@ -32,6 +32,8 @@ class NetworkExecutionState final {
   /// Constructor.
   explicit NetworkExecutionState(const DAGNode *root);
 
+  const DAGNode *getRoot() { return root_; }
+
   /// Destructor.
   ~NetworkExecutionState();
 
@@ -138,18 +140,23 @@ class NetworkExecutionStatePool {
 class NetworkExecutionStatePool {
 public:
   NetworkExecutionState *getNextNetworkExecutionState() {
-    // Note: this assumes pool size is maxActiveRequests, otherwise it's
-    // possible to wrap around on a state currently in use.
-    unsigned current = currentState_.fetch_add(1);
-    return states_[current % states_.size()].get();
+    std::lock_guard<std::mutex> lock(stateLock_);
+    auto nextState = availableStates_.front();
+    availableStates_.pop_front();
+    return nextState;
   }
-  void addNewState(std::unique_ptr<NetworkExecutionState> state) {
-    states_.push_back(std::move(state));
+
+  void addNewState(std::unique_ptr<NetworkExecutionState> state);
+
+  void returnNetworkExecutionState(NetworkExecutionState *state) {
+    std::lock_guard<std::mutex> lock(stateLock_);
+    availableStates_.push_back(state);
   }
 
 private:
   std::vector<std::unique_ptr<NetworkExecutionState>> states_;
-  std::atomic<unsigned> currentState_;
+  std::deque<NetworkExecutionState *> availableStates_;
+  std::mutex stateLock_;
 };
 
 } // namespace runtime
8 changes: 8 additions & 0 deletions lib/Runtime/Executor/NetworkExecutionState.cpp
@@ -19,6 +19,14 @@
 using namespace glow;
 using namespace glow::runtime;
 
+void NetworkExecutionStatePool::addNewState(
+    std::unique_ptr<NetworkExecutionState> state) {
+
+  std::lock_guard<std::mutex> lock(stateLock_);
+  availableStates_.push_back(state.get());
+  states_.push_back(std::move(state));
+}
+
 NetworkExecutionState::NetworkExecutionState(const DAGNode *root)
     : inflightNodes_(0), module_(root->module), root_(root) {}
 
11 changes: 9 additions & 2 deletions lib/Runtime/Executor/ThreadPoolExecutor.cpp
@@ -215,8 +215,15 @@ void ThreadPoolExecutor::handleDeviceManagerResult(
     // will transfer. //executionState->transferOutputs();
     ResultCBTy cb = executionState->getCallback();
     DCHECK(cb != nullptr);
-    cb(executionState->getRunId(), executionState->getErrorContainer().get(),
-       executionState->getUniqueResultContextPtr());
+
+    // Get what we need from the executionState and return it to the pool.
+    auto runId = executionState->getRunId();
+    auto err = executionState->getErrorContainer().get();
+    auto resultCtx = executionState->getUniqueResultContextPtr();
+    states_[executionState->getRoot()]->returnNetworkExecutionState(
+        executionState);
+
+    cb(runId, std::move(err), std::move(resultCtx));
   }
 
   // Decrement the inflight barrier for the executor keeping track of all
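
A note on the ordering in the ThreadPoolExecutor change above: runId, the error, and the result context are pulled out of the executionState before it is returned to the pool, because another request may reacquire and overwrite the state as soon as returnNetworkExecutionState runs; the callback is invoked last so that work triggered from inside it can immediately reuse the freed state. Below is a minimal sketch of that pattern. The PooledState type, completeRequest function, and returnState callable are hypothetical stand-ins, not the actual Glow classes.

#include <functional>
#include <iostream>
#include <string>
#include <utility>

struct PooledState {
  int runId = 0;
  std::string result;
};

using Callback = std::function<void(int, std::string)>;

// returnState stands in for NetworkExecutionStatePool::returnNetworkExecutionState.
void completeRequest(PooledState *state, const Callback &cb,
                     const std::function<void(PooledState *)> &returnState) {
  // 1. Copy/move out everything the callback needs while we still own the state.
  int runId = state->runId;
  std::string result = std::move(state->result);

  // 2. Return the state; from this point on another thread may reuse it, so it
  //    must not be touched again.
  returnState(state);

  // 3. Invoke the callback last, so a follow-up request made from inside the
  //    callback can immediately acquire the state we just returned.
  cb(runId, std::move(result));
}

int main() {
  PooledState state;
  state.runId = 42;
  state.result = "done";
  completeRequest(
      &state,
      [](int id, std::string r) { std::cout << id << ": " << r << "\n"; },
      [](PooledState *s) { s->result.clear(); /* back in the pool */ });
  return 0;
}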
