Skip to content

Commit e5c6881

Browse files
mrry and tensorflower-gardener
authored and committed
[FunctionLibraryRuntime] Optimize single-component "multi-device" function dispatch.
This change enables the (single-device) `FunctionLibraryRuntimeImpl` to dispatch a multi-device function directly, if (i) it has a single component, and (ii) that component is local to the `FunctionLibraryRuntimeImpl` instance. This avoids the (microsecond-scale) overhead of preparing the inputs to, and remapping the arguments from, a multi-device function, which is important for clients (like tf.data) that invoke many fine-grained functions.

PiperOrigin-RevId: 307090208
Change-Id: I21820aa5b84360c2595b49c2e22d7a2b037819c4
1 parent c301f1e commit e5c6881

7 files changed

+96
-20
lines changed

tensorflow/core/common_runtime/function.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -1203,7 +1203,8 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
12031203
};
12041204
}
12051205

1206-
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
1206+
LocalHandle local_handle = parent_->GetHandleOnDevice(
1207+
device_name_, handle, /*include_multi_device=*/true);
12071208
if (local_handle == kInvalidLocalHandle) {
12081209
parent_->Run(run_opts, handle, frame, done);
12091210
return;

tensorflow/core/common_runtime/function_test.cc

+13
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,19 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) {
394394
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
395395
}
396396

397+
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo_MultiDeviceBacked) {
398+
Init({test::function::XTimesTwo()});
399+
auto x = test::AsTensor<float>({1, 2, 3, 4});
400+
Tensor y;
401+
402+
FunctionLibraryRuntime::InstantiateOptions options;
403+
options.is_multi_device_function = true;
404+
405+
TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
406+
{x}, {&y}));
407+
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
408+
}
409+
397410
TEST_F(FunctionLibraryRuntimeTest, XTimesN) {
398411
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
399412
test::function::XTimes16()});

tensorflow/core/common_runtime/partitioning_utils.cc

+19-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ limitations under the License.
1414
==============================================================================*/
1515
#include "tensorflow/core/common_runtime/partitioning_utils.h"
1616

17+
#include <algorithm>
18+
1719
#include "tensorflow/core/framework/function.h"
1820
#include "tensorflow/core/framework/types.h"
1921
#include "tensorflow/core/graph/graph.h"
@@ -82,20 +84,32 @@ Status UpdateArgAndRetvalMetadata(
8284
// Find the Arg and Retval nodes, along with their corresponding indices
8385
// in the original function.
8486
for (Node* node : subgraph->op_nodes()) {
85-
string node_type = node->type_string();
8687
if (node->IsArg()) {
8788
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
8889
int index = static_cast<int>(attr_value->i());
89-
arg_indices->push_back(index);
90-
arg_nodes.push_back(std::make_pair(node, index));
90+
arg_nodes.emplace_back(node, index);
9191
} else if (node->IsRetval()) {
9292
TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
9393
int index = static_cast<int>(attr_value->i());
94-
ret_indices->push_back(index);
95-
ret_nodes.push_back(std::make_pair(node, index));
94+
ret_nodes.emplace_back(node, index);
9695
}
9796
}
9897

98+
// Sort the nodes by index so that the order is stable.
99+
//
100+
// In particular, this enables calling a single-partition function with
101+
// the same signature as the original unpartitioned function.
102+
auto comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
103+
return a.second < b.second;
104+
};
105+
std::sort(arg_nodes.begin(), arg_nodes.end(), comparator);
106+
std::sort(ret_nodes.begin(), ret_nodes.end(), comparator);
107+
108+
arg_indices->reserve(arg_nodes.size());
109+
for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
110+
ret_indices->reserve(ret_nodes.size());
111+
for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second);
112+
99113
for (int i = 0; i < arg_nodes.size(); ++i) {
100114
Node* arg = arg_nodes[i].first;
101115
arg->AddAttr("index", i);

tensorflow/core/common_runtime/partitioning_utils.h

+6-5
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,18 @@ Status PartitionFunctionGraph(
4141
//
4242
// More specifically, this function
4343
// (1) rewrites the indices of the `Arg` and `Retval` nodes placed
44-
// on a particular device. When a function is partitioned each
45-
// partition, `subgraph`, get a subset of the arguments and
44+
// on a particular device. When a function is partitioned, each
45+
// partition `subgraph` gets a subset of the arguments and
4646
// return values. The `index` attributes of these _Arg and _Retval
4747
// nodes reflect the indices of these parameters in the original
4848
// function. To convert `subgraph` to a function, we need to replace
4949
// there original indices with 0, 1, 2, ... .
5050
//
5151
// The argument and return value order in the partitioned function is
52-
// determined by the node iteration order in `subgraph`. This order
53-
// is also used in UpdateArgAndRetvalMetadata. This is fine because the
54-
// node iteration order is deterministic - it follows the node ids.
52+
// determined by the argument and return value order in the original
53+
// function. This stability is important because it enables us to treat
54+
// a single-partition function as having the same signature as the
55+
// subgraph.
5556
// (2) records the subsets of `Arg` and `Retval` nodes assigned to the
5657
// device in `*_indices`, and
5758
// (3) records which `Arg` and `Retval` nodes live in host memory in

tensorflow/core/common_runtime/partitioning_utils_test.cc

+35-6
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License.
3030
#include "tensorflow/core/lib/core/errors.h"
3131
#include "tensorflow/core/lib/core/status.h"
3232
#include "tensorflow/core/lib/core/status_test_util.h"
33+
#include "tensorflow/core/lib/gtl/array_slice.h"
3334
#include "tensorflow/core/platform/test.h"
3435
#include "tensorflow/core/public/session_options.h"
3536

@@ -91,12 +92,18 @@ class PartitioningUtilsTest : public ::testing::Test {
9192
// Fills subgraph with an identify function arg->identity->ret
9293
// where each node has type `dtype` and arg/ret nodes have
9394
// indices `arg_index` and `ret_index`.
94-
void SubGraph(Graph* subgraph, DataType dtype, int arg_index, int ret_index) {
95+
void SubGraph(Graph* subgraph, DataType dtype,
96+
gtl::ArraySlice<int> arg_indices,
97+
gtl::ArraySlice<int> ret_indices) {
9598
Scope s = Scope::NewRootScope();
9699
Scope s1 = s.WithDevice("/job:a/replica:0/task:0/device:CPU:0");
97-
auto x = ops::_Arg(s1.WithOpName("x"), dtype, arg_index);
98-
auto id_x = ops::Identity(s1.WithOpName("id_x"), x);
99-
auto dx_retval = ops::_Retval(s1.WithOpName("retval1"), id_x, ret_index);
100+
CHECK_EQ(arg_indices.size(), ret_indices.size());
101+
for (size_t i = 0; i < arg_indices.size(); ++i) {
102+
auto x = ops::_Arg(s1.WithOpName("x"), dtype, arg_indices[i]);
103+
auto id_x = ops::Identity(s1.WithOpName("id_x"), x);
104+
auto dx_retval =
105+
ops::_Retval(s1.WithOpName("retval1"), id_x, ret_indices[i]);
106+
}
100107
TF_ASSERT_OK(s.ToGraph(subgraph));
101108
Placer placer(subgraph, "", &device_set_, device0_);
102109
TF_ASSERT_OK(placer.Run());
@@ -175,8 +182,8 @@ void CheckIndex(const Node& node, int expected_index) {
175182
}
176183

177184
TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
178-
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
179-
SubGraph(graph.get(), DT_FLOAT, 3, 5);
185+
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
186+
SubGraph(graph.get(), DT_FLOAT, {3}, {5});
180187

181188
std::vector<int> arg_indices;
182189
std::vector<int> ret_indices;
@@ -202,5 +209,27 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
202209
CheckIndex(*nodes["retval1"], 0);
203210
}
204211

212+
TEST_F(PartitioningUtilsTest, UpdateArgsAndRets_Order) {
213+
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
214+
SubGraph(graph.get(), DT_FLOAT, {9, 7, 5, 3, 1}, {2, 4, 6, 8, 10});
215+
216+
std::vector<int> arg_indices;
217+
std::vector<int> ret_indices;
218+
std::vector<AllocatorAttributes> arg_alloc_attrs;
219+
std::vector<AllocatorAttributes> ret_alloc_attrs;
220+
221+
string device_type = "CPU";
222+
223+
Status status = UpdateArgAndRetvalMetadata(
224+
graph.get(), device_type, &arg_indices, &ret_indices, &arg_alloc_attrs,
225+
&ret_alloc_attrs);
226+
ASSERT_TRUE(status.ok()) << status.ToString();
227+
228+
CheckIndices({1, 3, 5, 7, 9}, arg_indices);
229+
CheckIndices({2, 4, 6, 8, 10}, ret_indices);
230+
CheckAlloc({false, false, false, false, false}, arg_alloc_attrs);
231+
CheckAlloc({false, false, false, false, false}, ret_alloc_attrs);
232+
}
233+
205234
} // anonymous namespace
206235
} // namespace tensorflow

tensorflow/core/common_runtime/process_function_library_runtime.cc

+15-2
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,25 @@ bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
282282

283283
FunctionLibraryRuntime::LocalHandle
284284
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
285-
const string& device_name, FunctionLibraryRuntime::Handle handle) const {
285+
const string& device_name, FunctionLibraryRuntime::Handle handle,
286+
bool include_multi_device) const {
286287
tf_shared_lock l(mu_);
287288

288289
auto miter = mdevice_data_.find(handle);
289290
if (miter != mdevice_data_.end()) {
290-
return kInvalidLocalHandle;
291+
if (!include_multi_device) return kInvalidLocalHandle;
292+
293+
const MultiDeviceFunctionData& data = *miter->second;
294+
if (data.glue_.size() != 1) return kInvalidLocalHandle;
295+
296+
const auto& pair = *data.glue_.begin();
297+
const string& func_device_name = pair.first;
298+
const ComponentFunctionData& component_data = pair.second;
299+
if (func_device_name != device_name) return kInvalidLocalHandle;
300+
301+
// Replace the given handle with the handle for the single component
302+
// function.
303+
handle = component_data.handle_;
291304
}
292305

293306
auto iter = function_data_.find(handle);

tensorflow/core/common_runtime/process_function_library_runtime.h

+6-1
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,13 @@ class ProcessFunctionLibraryRuntime {
137137
// index of instantiation of that function. If the function was not
138138
// instantiated on `device_name` or the function is multi-device,
139139
// returns kInvalidLocalHandle.
140+
//
141+
// If `include_multi_device` is true and `handle` is a multi-device function
142+
// with a single component that is placed on `device_name`, then this method
143+
// will return the local handle for that component.
140144
FunctionLibraryRuntime::LocalHandle GetHandleOnDevice(
141-
const string& device_name, FunctionLibraryRuntime::Handle handle) const;
145+
const string& device_name, FunctionLibraryRuntime::Handle handle,
146+
bool include_multi_device = false) const;
142147

143148
// Fills `output_devices` with the devices on which the results will
144149
// be produced. If some output is produced on CPU, the corresponding Device*

0 commit comments

Comments
 (0)