Skip to content

Commit

Permalink
Try a stack-based DFS for eval (ml-explore#980)
Browse files Browse the repository at this point in the history
* rebase

* nit

* fix eval in vmap
  • Loading branch information
awni authored Apr 11, 2024
1 parent 061cf9a commit 8580d99
Showing 1 changed file with 34 additions and 40 deletions.
74 changes: 34 additions & 40 deletions mlx/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <numeric>
#include <set>
#include <sstream>
#include <stack>
#include <unordered_map>
#include <unordered_set>

Expand All @@ -17,9 +18,6 @@

namespace mlx::core {

// Maximum allowed graph depth for eval. The recursive traversal in eval
// compares its depth counter against this limit and throws
// std::runtime_error once the counter exceeds it, asking the caller to
// evaluate the graph more frequently.
constexpr uint32_t max_graph_depth = 100'000;

/* This class is only meant to be used in eval
* for synchronizing with the main thread. */
class Synchronizer : public Primitive {
Expand All @@ -44,8 +42,6 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
if (global_synchronizer.valid()) {
global_synchronizer.wait();
}

std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
std::unordered_map<std::uintptr_t, std::shared_future<void>> deps;
Expand All @@ -62,47 +58,45 @@ std::shared_future<void> async_eval(std::vector<array> outputs) {
auto synchronizer = array(
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));

size_t depth_counter = 0;
recurse = [&](const array& a) {
if (depth_counter > max_graph_depth) {
throw std::runtime_error(
"[eval] Graph depth exceeded maximum allowed limit."
" Try evaluating the graph more frequently.");
}

auto id = a.id();
if (cache.find(id) != cache.end()) {
return;
}
{
std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
dfs.emplace(synchronizer, 0);
while (!dfs.empty()) {
auto& [a_ref, idx] = dfs.top();
auto& a = a_ref.get();
if (idx < a.inputs().size()) {
// Add an input, and continue
auto& in = a.inputs()[idx++];
if (!in.is_evaled()) {
if (!in.has_primitive()) {
throw std::invalid_argument(
"[eval] Attempting to eval an array without a primitive.");
}

// If the input is being computed on a different stream, we need to
// manage the dependency.
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.output(0).id(), std::shared_future<void>{}});
}
}

// Recurse into each input in order (depth-first).
depth_counter++;
for (auto& in : a.inputs()) {
recurse(in);
if (!in.is_evaled()) {
// If the input is being computed on a different stream, we need to
// manage the dependency.
if (a.primitive().stream() != in.primitive().stream()) {
deps.insert({in.output(0).id(), std::shared_future<void>{}});
if (cache.find(in.id()) == cache.end()) {
dfs.emplace(in, 0);
cache.insert(in.id());
for (auto& s : in.siblings()) {
cache.insert(s.id());
}
}
continue;
}
}
depth_counter--;

cache.insert(id);
for (auto& s : a.siblings()) {
cache.insert(s.id());
}
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
if (!a.has_primitive()) {
throw std::invalid_argument(
"[eval] Attempting to eval an array without a primitive.");
// All inputs have been processed; now process this array.
if (!a.is_evaled() || (!a.is_tracer() && a.has_primitive())) {
tape.push(a);
}
tape.push(a);
dfs.pop();
}
};

recurse(synchronizer);
}
deps.insert({synchronizer.id(), std::shared_future<void>{}});

std::vector<std::shared_ptr<std::promise<void>>> ps;
Expand Down

0 comments on commit 8580d99

Please sign in to comment.