Commit 45e048a: Merge commit for internal changes

Vijay Vasudevan committed Dec 16, 2015
2 parents: deff101 + ee4f440
Showing 7 changed files with 177 additions and 60 deletions.
39 changes: 36 additions & 3 deletions tensorflow/core/graph/graph_partition.cc
@@ -324,27 +324,60 @@ Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) {
src_info.iter_level = 0;

string frame_name;
std::deque<const Node*> ready;
std::deque<Node*> ready;
ready.push_back(src_node);
while (!ready.empty()) {
const Node* curr_node = ready.front();
Node* curr_node = ready.front();
ready.pop_front();
const ControlFlowInfo& curr_info = (*info)[curr_node->id()];
const Node* frame = curr_info.frame;
const Node* parent = curr_info.parent_frame;
frame_name = curr_info.frame_name;
int iter_level = curr_info.iter_level;

// Force colocation for control flow nodes. This may reduce the number
// of devices involved in a loop.
// TODO(yuanbyu): In this case, we don't respect the requested device in
// the GraphDef for these nodes. Ideally, the placer would enforce the
// colocation to render this unnecessary.
if (IsExit(curr_node)) {
// Exit to the parent frame.
const ControlFlowInfo& parent_info = (*info)[parent->id()];
frame = parent_info.frame;
parent = parent_info.parent_frame;
frame_name = parent_info.frame_name;
iter_level = parent_info.iter_level;
for (const Edge* in_edge : curr_node->in_edges()) {
if (!in_edge->IsControlEdge()) {
// Colocate with upstream node.
curr_node->set_assigned_device_name(
in_edge->src()->assigned_device_name());
break;
}
}
} else {
if ((IsEnter(curr_node) && !IsRefType(curr_node->input_type(0))) ||
IsNextIteration(curr_node)) {
const Edge* data_edge = nullptr;
for (const Edge* out_edge : curr_node->out_edges()) {
if (!out_edge->IsControlEdge()) {
if (data_edge) {
data_edge = nullptr;
break;
}
data_edge = out_edge;
}
}
// Colocate if there is only one downstream data node.
if (data_edge) {
curr_node->set_assigned_device_name(
data_edge->dst()->assigned_device_name());
}
}
}

for (const Edge* out_edge : curr_node->out_edges()) {
const Node* out = out_edge->dst();
Node* out = out_edge->dst();
int out_id = out->id();
ControlFlowInfo* out_info = &(*info)[out_id];
const Node* out_parent = out_info->parent_frame;
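The device-forcing logic above can be read as a small standalone rule: an Exit node inherits the device of its upstream data input, while a (non-ref) Enter or NextIteration node inherits the device of its single downstream data consumer, provided there is exactly one. The Python sketch below only illustrates that rule; the Node class and its fields are hypothetical stand-ins, not the tensorflow::Node API.

# Illustrative sketch of the colocation rule above; Node and its fields are
# hypothetical stand-ins for tensorflow::Node, not real TensorFlow APIs.
class Node(object):
    def __init__(self, op, device=""):
        self.op = op              # e.g. "Exit", "Enter", "NextIteration"
        self.device = device      # assigned device name
        self.in_data = []         # upstream data (non-control) nodes
        self.out_data = []        # downstream data (non-control) nodes

def colocate_control_flow(node):
    """Mirror of the device-forcing decision in BuildControlFlowInfo."""
    if node.op == "Exit":
        # Exit follows the device of its upstream data input.
        if node.in_data:
            node.device = node.in_data[0].device
    elif node.op in ("Enter", "NextIteration"):
        # The real code also skips ref-typed Enter nodes; colocate only when
        # there is exactly one downstream data consumer.
        if len(node.out_data) == 1:
            node.device = node.out_data[0].device

# Example: an Exit node inherits /gpu:0 from the loop body it exits.
body = Node("Add", device="/gpu:0")
exit_node = Node("Exit")
exit_node.in_data.append(body)
colocate_control_flow(exit_node)
assert exit_node.device == "/gpu:0"
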
24 changes: 23 additions & 1 deletion tensorflow/core/graph/graph_partition_test.cc
@@ -84,7 +84,6 @@ void Partition(const GraphDef& graph_def,
void CheckLoopConstruction(const GraphDef& graph_def) {
std::unordered_map<string, GraphDef> partitions;
Partition(graph_def, &partitions);
GraphConstructorOptions opts;
for (const auto& kv : partitions) {
const GraphDef& gdef = kv.second;
bool has_control_enter = false;
@@ -334,5 +333,28 @@ TEST_F(GraphPartitionTest, CrossDeviceLoop) {
CheckLoopConstruction(ToGraphDef());
}

TEST_F(GraphPartitionTest, CrossDeviceLoop1) {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
Node* a1 = BoolInput(in_.opts().WithName("A1"));
Node* a2 = Enter(a1, "foo", in_.opts().WithName("B2"));
Node* a3 = Merge({a2, {"B5", 0, DT_BOOL}}, in_.opts().WithName("A3"));
LoopCond(a3, in_.opts().WithName("A4"));
Node* b1 = Identity(a3, in_.opts().WithName("B1"));
NextIteration(b1, in_.opts().WithName("B5"));

std::unordered_map<string, GraphDef> partitions;
Partition(ToGraphDef(), &partitions);
for (const auto& kv : partitions) {
const GraphDef& gdef = kv.second;
for (const NodeDef& ndef : gdef.node()) {
if (ndef.name() == "A3") {
// A3, B2, and B5 are on the same device.
EXPECT_EQ(ndef.input(0), "B2");
EXPECT_EQ(ndef.input(1), "B5");
}
}
}
}

} // namespace
} // namespace tensorflow
11 changes: 11 additions & 0 deletions tensorflow/python/framework/tensor_util.py
@@ -555,5 +555,16 @@ def ConstantValue(tensor):
return None
cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
return pre_cast.astype(cast_dtype.as_numpy_dtype)
elif tensor.op.type == "Concat":
dim = ConstantValue(tensor.op.inputs[0])
if dim is None:
return None
values = []
for x in tensor.op.inputs[1:]:
value = ConstantValue(x)
if value is None:
return None
values.append(value)
return np.concatenate(values, axis=dim)
else:
return None
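
The new Concat branch folds a concat of statically known tensors: it resolves the concat dimension and every input through ConstantValue, bails out with None if any of them is unknown, and otherwise concatenates the resolved arrays along that dimension. A rough standalone equivalent in plain NumPy, with constant_value_of as a hypothetical stand-in for the per-input ConstantValue call, might look like this:

import numpy as np

def concat_constant_value(dim, inputs, constant_value_of):
    """Return the folded value of Concat(dim, inputs), or None if unknown."""
    axis = constant_value_of(dim)
    if axis is None:
        return None
    values = []
    for x in inputs:
        value = constant_value_of(x)
        if value is None:
            # Any non-constant input makes the whole concat non-constant.
            return None
        values.append(value)
    return np.concatenate(values, axis=axis)

# Example: three constant slices fold back into the original array.
a = np.arange(12).reshape(3, 4)
parts = [a[0:1], a[1:2], a[2:3]]
folded = concat_constant_value(
    0, parts, lambda v: v if isinstance(v, (int, np.ndarray)) else None)
assert (folded == a).all()
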
21 changes: 21 additions & 0 deletions tensorflow/python/framework/tensor_util_test.py
@@ -412,5 +412,26 @@ def testCast(self):
c_val = tensor_util.ConstantValue(tf_val)
self.assertAllClose(np_val.astype(np.float64), c_val)

def testConcat(self):
np_val = np.random.rand(3, 4, 7).astype(np.float32)
tf_val = array_ops.concat(
0, [np_val[0:1, :, :], np_val[1:2, :, :], np_val[2:3, :, :]])
c_val = tensor_util.ConstantValue(tf_val)
self.assertAllClose(np_val, c_val)

tf_val = array_ops.concat(
array_ops.placeholder(dtypes.int32),
[np_val[0, :, :], np_val[1, :, :], np_val[2, :, :]])
c_val = tensor_util.ConstantValue(tf_val)
self.assertIs(None, c_val)

tf_val = array_ops.concat(
1,
[np_val[0, :, :], array_ops.placeholder(dtypes.float32),
np_val[2, :, :]])
c_val = tensor_util.ConstantValue(tf_val)
self.assertIs(None, c_val)


if __name__ == "__main__":
googletest.main()
89 changes: 54 additions & 35 deletions tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -83,7 +83,7 @@ def testRefEnter(self):
with self.test_session():
v = tf.Variable(7)

enter_v = control_flow_ops._Enter(v, "foo_1")
enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
nine = tf.constant(9)
enter_nine = control_flow_ops.enter(nine, "foo_1")
op = tf.assign(enter_v, enter_nine)
@@ -269,58 +269,65 @@ def testLoop_false(self):
enter_false = control_flow_ops.enter(false, "foo_1", False)
enter_n = control_flow_ops.enter(n, "foo_1", False)

merge_n = control_flow_ops.merge([enter_n], name="merge_n")[0]
merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
switch_n = control_flow_ops.switch(merge_n, enter_false)
exit_n = control_flow_ops.exit(switch_n[0])
next_n = control_flow_ops.next_iteration(switch_n[0])
merge_n.op._update_input(1, next_n)

result = exit_n.eval()
self.assertAllEqual(10, result)

def testLoop_false_1(self):
def testLoop_1(self):
with self.test_session():
false = tf.convert_to_tensor(False)
zero = tf.constant(0)
one = tf.constant(1)
n = tf.constant(10)

enter_false = control_flow_ops.enter(false, "foo_1", False)
enter_n = control_flow_ops.enter(n, "foo_1", False)
enter_i = control_flow_ops.enter(zero, "foo", False)
enter_one = control_flow_ops.enter(one, "foo", True)
enter_n = control_flow_ops.enter(n, "foo", True)

merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
switch_n = control_flow_ops.switch(merge_n, enter_false)
exit_n = control_flow_ops.exit(switch_n[0])
next_n = control_flow_ops.next_iteration(switch_n[0])
merge_n.op._update_input(1, next_n)
with tf.device("/gpu:0"):
merge_i = control_flow_ops.merge([enter_i, enter_i])[0]

result = exit_n.eval()
less_op = tf.less(merge_i, enter_n)
cond_op = control_flow_ops.loop_cond(less_op)
switch_i = control_flow_ops.switch(merge_i, cond_op)

add_i = tf.add(switch_i[1], enter_one)

next_i = control_flow_ops.next_iteration(add_i)
merge_i.op._update_input(1, next_i)

exit_i = control_flow_ops.exit(switch_i[0])
result = exit_i.eval()
self.assertAllEqual(10, result)

def testLoop_1(self):
def testLoop_2(self):
with self.test_session():
zero = tf.convert_to_tensor(0)
one = tf.convert_to_tensor(1)
zero = tf.constant(0)
one = tf.constant(1)
n = tf.constant(10)

enter_zero = control_flow_ops.enter(zero, "foo_1", False)
enter_one = control_flow_ops.enter(one, "foo_1", False)
enter_n = control_flow_ops.enter(n, "foo_1", False)
merge_zero = control_flow_ops.merge([enter_zero, enter_zero],
name="merge_zero")[0]
merge_one = control_flow_ops.merge([enter_one, enter_one],
name="merge_one")[0]
merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
less_op = tf.less(merge_n, merge_n)
enter_i = control_flow_ops.enter(zero, "foo", False)
enter_one = control_flow_ops.enter(one, "foo", True)
enter_n = control_flow_ops.enter(n, "foo", True)

merge_i = control_flow_ops.merge([enter_i, enter_i])[0]

less_op = tf.less(merge_i, enter_n)
cond_op = control_flow_ops.loop_cond(less_op)
switch_zero = control_flow_ops.switch(merge_zero, cond_op)
switch_one = control_flow_ops.switch(merge_one, cond_op)
switch_n = control_flow_ops.switch(merge_n, cond_op)
next_zero = control_flow_ops.next_iteration(switch_zero[1])
next_one = control_flow_ops.next_iteration(switch_one[1])
next_n = control_flow_ops.next_iteration(switch_n[1])
merge_zero.op._update_input(1, next_zero)
merge_one.op._update_input(1, next_one)
merge_n.op._update_input(1, next_n)
exit_n = control_flow_ops.exit(switch_n[0])
switch_i = control_flow_ops.switch(merge_i, cond_op)

result = exit_n.eval()
add_i = tf.add(switch_i[1], enter_one)

with tf.device("/gpu:0"):
next_i = control_flow_ops.next_iteration(add_i)
merge_i.op._update_input(1, next_i)

exit_i = control_flow_ops.exit(switch_i[0])
result = exit_i.eval()
self.assertAllEqual(10, result)

def testCondIndexedSlices(self):
@@ -445,6 +452,17 @@ def testCond_6(self):
result = r.eval()
self.assertAllEqual(np.array([7]), result)

def testCond_7(self):
with self.test_session() as sess:
x = tf.constant(10)
y = tf.constant(200)
pred = tf.less(1, 2)
fn1 = lambda: [tf.add(x, 1), tf.add(x, 2)]
fn2 = lambda: [y, y]
r = control_flow_ops.cond(pred, fn1, fn2)

self.assertAllEqual([11, 12], sess.run(r))

def testCondGrad_1(self):
with self.test_session():
x = tf.constant(10.0, name="x")
@@ -1365,5 +1383,6 @@ def testAcceptTensorsAsControlInputs(self):

self.assertEquals(1, var.eval())


if __name__ == "__main__":
tf.test.main()
14 changes: 10 additions & 4 deletions tensorflow/python/ops/control_flow_grad.py
@@ -33,17 +33,23 @@ def _SwitchGrad(op, *grad):
if isinstance(ctxt, WhileContext):
merge_op = ctxt.switch_map.get(op)
if merge_op:
# This is the second time this Switch is visited. Update the second
# input to the Merge.
merge_op._update_input(1, next_iteration(grad[1]))
return None, None
else:
merge_op = merge(grad, name="b_switch")[0]
# This is the first time this Switch is visited. grad[1] is empty
# at this point. Use grad[0] for both inputs to merge, but update
# the second input of merge when we see this Switch the second time.
merge_op = merge([grad[0], grad[0]], name="b_switch")[0]
ctxt.switch_map[op] = merge_op.op
return merge_op, None
elif isinstance(ctxt, CondContext):
good_grad = grad[ctxt.branch]
zero_grad = grad[1 - ctxt.branch]
zero_grad = switch(zero_grad, ctxt.pred, name="grad_0")[1 - ctxt.branch]
return merge([good_grad, zero_grad], name="switch_grad")[0], None
dtype = good_grad.dtype
zero_grad = switch(zero_grad, ctxt.pred, dtype=dtype)[1 - ctxt.branch]
return merge([good_grad, zero_grad], name="cond_grad")[0], None
else:
false_grad = switch(grad[0], op.inputs[1])[0]
true_grad = switch(grad[1], op.inputs[1])[1]
@@ -66,7 +72,7 @@ def _MergeGrad(op, grad, _):
grad_ctxt = ctxt.grad_context
return switch(grad, grad_ctxt.pivot)
elif isinstance(ctxt, CondContext):
return switch(grad, ctxt.pred, name="merge_grad")
return switch(grad, ctxt.pred, name="cond_grad")
else:
num_inputs = len(op.inputs)
cond = [math_ops.equal(op.outputs[1], i) for i in xrange(num_inputs)]
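
The updated _SwitchGrad handles a Switch inside a while loop in two visits: on the first visit grad[1] does not exist yet, so a Merge is built with grad[0] wired to both inputs and recorded in ctxt.switch_map; on the second visit the recorded Merge has its second input patched to NextIteration(grad[1]). A minimal sketch of that bookkeeping, using plain Python objects (FakeMerge and the dict-based switch_map are illustrative stand-ins, not TensorFlow ops), is:

# Toy model of the two-visit bookkeeping in _SwitchGrad for while loops.
class FakeMerge(object):
    def __init__(self, inputs):
        self.inputs = list(inputs)

def switch_grad_in_while(op, grad, switch_map):
    merge_op = switch_map.get(op)
    if merge_op is not None:
        # Second visit: grad[1] now carries the loop-body gradient; route it
        # back into the Merge via a NextIteration edge.
        merge_op.inputs[1] = ("NextIteration", grad[1])
        return None
    # First visit: grad[1] is empty, so use grad[0] for both Merge inputs and
    # remember the Merge so it can be patched on the second visit.
    merge_op = FakeMerge([grad[0], grad[0]])
    switch_map[op] = merge_op
    return merge_op

# Example: two passes over the same Switch op.
switch_map = {}
m = switch_grad_in_while("Switch_1", ("g_exit", None), switch_map)
switch_grad_in_while("Switch_1", (None, "g_body"), switch_map)
assert m.inputs == ["g_exit", ("NextIteration", "g_body")]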