Skip to content

Commit c4e416f

Browse files
author
Jonathan Hseu
committed
Merge commit for internal changes
2 parents 7497fca + 2e57e3f commit c4e416f

File tree

249 files changed

+6558
-17741
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

249 files changed

+6558
-17741
lines changed

configure.py

+5
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,10 @@ def create_android_bazelrc_configs():
998998
write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a')
999999

10001000

1001+
def set_grpc_build_flags():
1002+
write_to_bazelrc('build --define grpc_no_ares=true')
1003+
1004+
10011005
def main():
10021006
# Make a copy of os.environ to be clear when functions are getting and setting
10031007
# environment variables.
@@ -1071,6 +1075,7 @@ def main():
10711075
set_mpi_home(environ_cp)
10721076
set_other_mpi_vars(environ_cp)
10731077

1078+
set_grpc_build_flags()
10741079
set_cc_opt_flags(environ_cp)
10751080
set_mkl()
10761081
set_monolithic()

tensorflow/c/python_api.cc

+27
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,33 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
4646
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
4747
TF_Status* status) {
4848
mutex_lock l(graph->mu);
49+
tensorflow::shape_inference::InferenceContext* ic =
50+
graph->refiner.GetContext(&new_src.oper->node);
51+
52+
if (ic->num_outputs() <= new_src.index) {
53+
status->status = tensorflow::errors::OutOfRange(
54+
"Cannot update edge. Output index [", new_src.index,
55+
"] is greater than the number of total outputs [", ic->num_outputs(),
56+
"].");
57+
return;
58+
}
59+
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
60+
61+
tensorflow::shape_inference::InferenceContext* ic_dst =
62+
graph->refiner.GetContext(&dst.oper->node);
63+
if (ic_dst->num_inputs() <= dst.index) {
64+
status->status = tensorflow::errors::OutOfRange(
65+
"Cannot update edge. Input index [", dst.index,
66+
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
67+
"].");
68+
return;
69+
}
70+
if (!ic_dst->MergeInput(dst.index, shape)) {
71+
status->status = tensorflow::errors::InvalidArgument(
72+
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
73+
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
74+
return;
75+
}
4976
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
5077
&dst.oper->node, dst.index);
5178
}

tensorflow/compiler/tests/binary_ops_test.py

+44-6
Original file line numberDiff line numberDiff line change
@@ -366,16 +366,52 @@ def testComplexOps(self):
366366

367367
self._testBinary(
368368
gen_math_ops._real_div,
369-
np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype),
370-
np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype),
369+
np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype),
370+
np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype),
371+
expected=np.array(
372+
[1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2],
373+
dtype=dtype))
374+
375+
# Test inf/nan scenarios.
376+
self._testBinary(
377+
gen_math_ops._real_div,
378+
np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype),
379+
np.array([0, 0, 0, 0, 0, 0], dtype=dtype),
371380
expected=np.array(
372381
[
373-
1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2,
374-
float("inf")
382+
dtype(1 + 1j) / 0,
383+
dtype(1) / 0,
384+
dtype(1j) / 0,
385+
dtype(-1) / 0,
386+
dtype(-1j) / 0,
387+
dtype(1 - 1j) / 0
375388
],
376389
dtype=dtype))
377390

378-
# TODO(b/65408531): support+test pow for cplx
391+
atan2_supported = self.device == "XLA_GPU"
392+
if atan2_supported:
393+
self._testBinary(
394+
math_ops.pow,
395+
dtype(3 + 2j),
396+
dtype(4 - 5j),
397+
expected=np.power(dtype(3 + 2j), dtype(4 - 5j)))
398+
self._testBinary( # empty rhs
399+
math_ops.pow,
400+
np.array([1 + 2j, 2 - 3j], dtype=dtype),
401+
np.zeros(shape=[0, 2], dtype=dtype),
402+
expected=np.zeros(shape=[0, 2], dtype=dtype))
403+
self._testBinary( # to zero power
404+
math_ops.pow,
405+
np.array([1 + 2j, 2 - 3j], dtype=dtype),
406+
np.zeros(shape=[1, 2], dtype=dtype),
407+
expected=np.ones(shape=[1, 2], dtype=dtype))
408+
lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype)
409+
rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype)
410+
scalar = dtype(2 + 2j)
411+
self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs))
412+
self._testBinary(
413+
math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs))
414+
self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar))
379415

380416
lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype)
381417
rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype)
@@ -385,7 +421,9 @@ def testComplexOps(self):
385421
self._testBinary(
386422
gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs))
387423

388-
# TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow)
424+
if atan2_supported:
425+
self._testBinary(
426+
gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2)
389427

390428
self._testBinary(
391429
gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs))

tensorflow/compiler/tests/reduce_ops_test.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -67,25 +67,37 @@ def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs,
6767
np.arange(-10, -4).reshape(2, 3),
6868
np.arange(-4, 2).reshape(2, 3),
6969
]
70-
NONEMPTY_FLOAT_DATA = [
71-
np.arange(1, 7).reshape(2, 3),
72-
np.arange(-10, -4).reshape(2, 3),
73-
np.arange(-4, 2).reshape(2, 3),
70+
COMPLEX_DATA = [
71+
np.zeros(shape=(2, 0)).astype(np.complex64),
72+
np.zeros(shape=(0, 30)).astype(np.complex64),
73+
np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3),
74+
np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3),
75+
np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3),
7476
]
77+
NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0]
78+
NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0]
7579
BOOL_DATA = [
7680
np.array([], dtype=np.bool).reshape(2, 0),
7781
np.array([], dtype=np.bool).reshape(0, 3),
7882
np.array([[False, True, False], [True, True, False]]),
7983
]
8084

81-
def testReduceSum(self):
85+
def testReduceSumF32(self):
8286
self._testReduction(math_ops.reduce_sum, np.sum, np.float32,
8387
self.FLOAT_DATA)
8488

85-
def testReduceProd(self):
89+
def testReduceSumC64(self):
90+
self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
91+
self.COMPLEX_DATA)
92+
93+
def testReduceProdF32(self):
8694
self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
8795
self.FLOAT_DATA)
8896

97+
def testReduceProdC64(self):
98+
self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
99+
self.COMPLEX_DATA)
100+
89101
def testReduceMin(self):
90102

91103
def reference_min(inp, axis):
@@ -108,12 +120,16 @@ def reference_max(inp, axis):
108120
self._testReduction(math_ops.reduce_max, reference_max, np.float32,
109121
self.FLOAT_DATA)
110122

111-
def testReduceMean(self):
123+
def testReduceMeanF32(self):
112124
# TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
113125
# reducing across zero inputs.
114126
self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
115127
self.NONEMPTY_FLOAT_DATA)
116128

129+
def testReduceMeanC64(self):
130+
self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
131+
self.NONEMPTY_COMPLEX_DATA)
132+
117133
def testReduceAll(self):
118134
self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA)
119135

tensorflow/compiler/tests/unary_ops_test.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,22 @@ def testFloatOps(self):
330330

331331
def testComplexOps(self):
332332
for dtype in self.complex_types:
333-
# TODO(b/65408531): math_ops.acosh (needs pow)
334-
# TODO(b/65408531): math_ops.asinh (needs pow)
335333

336334
# TODO(b/65408531): Wider support for log (needs atan2).
337335
atan2_supported = self.device == "XLA_GPU"
338336
if atan2_supported:
337+
self._assertOpOutputMatchesExpected(
338+
math_ops.acosh,
339+
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
340+
expected=np.arccosh(
341+
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
342+
343+
self._assertOpOutputMatchesExpected(
344+
math_ops.asinh,
345+
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
346+
expected=np.arcsinh(
347+
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
348+
339349
self._assertOpOutputMatchesExpected(
340350
math_ops.atanh,
341351
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
@@ -392,19 +402,26 @@ def testComplexOps(self):
392402
expected=np.log1p(
393403
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)))
394404

395-
# TODO(b/34703906): math_ops.rsqrt (needs pow)
405+
val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)
406+
self._assertOpOutputMatchesExpected(
407+
math_ops.rsqrt, val, expected=1 / np.sqrt(val))
396408

397-
# TODO(b/34703906): math_ops.sigmoid (needs tanh)
409+
self._assertOpOutputMatchesExpected(
410+
math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val)))
398411

399-
# TODO(b/34703906): math_ops.sqrt (needs pow)
412+
self._assertOpOutputMatchesExpected(
413+
math_ops.sqrt, val, expected=np.sqrt(val))
414+
415+
self._assertOpOutputMatchesExpected(
416+
math_ops.tanh,
417+
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
418+
expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
400419

401420
self._assertOpOutputMatchesExpected(
402421
math_ops.tan,
403422
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
404423
expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
405424

406-
# TODO(b/34703906): math_ops.tanh (as itself)
407-
408425
ctypes = {np.complex64: np.float32}
409426
self._assertOpOutputMatchesExpected(
410427
math_ops.abs,

tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc

+14-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ std::vector<tensorflow::Flag>* flag_objects;
3131
std::once_flag flags_init;
3232

3333
void SetDebugOptionsDefaults(DebugOptions* flags) {
34-
flags->set_xla_hlo_graph_path("/tmp/");
3534
flags->set_xla_enable_fast_math(true);
3635
flags->set_xla_llvm_enable_alias_scope_metadata(true);
3736
flags->set_xla_llvm_enable_noalias_metadata(true);
@@ -117,9 +116,22 @@ void AllocateFlags() {
117116
bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef),
118117
flag_values->xla_hlo_dump_as_graphdef(),
119118
"Dump HLO graphs as TensorFlow GraphDefs."),
119+
tensorflow::Flag(
120+
"xla_hlo_graph_sharding_color",
121+
bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color),
122+
flag_values->xla_hlo_graph_sharding_color(),
123+
"Assign colors based on sharding assignments when generating the "
124+
"HLO graphs."),
125+
tensorflow::Flag(
126+
"xla_hlo_tfgraph_device_scopes",
127+
bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes),
128+
flag_values->xla_hlo_tfgraph_device_scopes(),
129+
"When generating TensorFlow HLO graphs, if the HLO instructions "
130+
"are assigned to a specific device, prefix the name scope with "
131+
"\"devX\" with X being the device ordinal."),
120132
tensorflow::Flag(
121133
"xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(),
122-
"HLO modules matching this regex will be dumped to LOG(INFO). "),
134+
"HLO modules matching this regex will be dumped to LOG(INFO)."),
123135
tensorflow::Flag(
124136
"xla_generate_hlo_text_to",
125137
flag_values->mutable_xla_generate_hlo_text_to(),

tensorflow/compiler/xla/literal_util_test.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
114114
auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
115115
ASSERT_EQ("0.5", bf16_lit->ToString());
116116

117-
// 3.14 will be rounded to 3.125 in bfloat16 format (Round to nearest even).
117+
// 3.14 will be truncated to 3.125 in bfloat16 format.
118118
auto bf16_lit_truncated =
119119
Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
120-
ASSERT_EQ("3.140625", bf16_lit_truncated->ToString());
120+
ASSERT_EQ("3.125", bf16_lit_truncated->ToString());
121121

122122
auto bf16_lit_truncated2 =
123123
Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));

tensorflow/compiler/xla/service/BUILD

+25
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,7 @@ cc_library(
630630

631631
cc_library(
632632
name = "llvm_compiler",
633+
srcs = ["llvm_compiler.cc"],
633634
hdrs = ["llvm_compiler.h"],
634635
deps = [
635636
":compiler",
@@ -1358,6 +1359,7 @@ cc_library(
13581359
deps = [
13591360
":hlo",
13601361
":hlo_cost_analysis",
1362+
":hlo_profile_printer",
13611363
":human_readable_profile_builder",
13621364
"//tensorflow/compiler/xla:types",
13631365
"//tensorflow/compiler/xla:util",
@@ -1366,6 +1368,18 @@ cc_library(
13661368
],
13671369
)
13681370

1371+
tf_cc_test(
1372+
name = "hlo_execution_profile_test",
1373+
srcs = ["hlo_execution_profile_test.cc"],
1374+
deps = [
1375+
":cpu_plugin",
1376+
":hlo_cost_analysis",
1377+
":hlo_execution_profile",
1378+
"//tensorflow/compiler/xla/tests:hlo_test_base",
1379+
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
1380+
],
1381+
)
1382+
13691383
tf_cc_test(
13701384
name = "hlo_computation_test",
13711385
srcs = ["hlo_computation_test.cc"],
@@ -1983,6 +1997,7 @@ cc_library(
19831997
":hlo",
19841998
"//tensorflow/compiler/xla:literal_util",
19851999
"//tensorflow/compiler/xla:shape_util",
2000+
"//tensorflow/compiler/xla:xla_proto",
19862001
"//tensorflow/core:framework",
19872002
"//tensorflow/core:lib",
19882003
"//tensorflow/core:protos_all_cc",
@@ -2156,6 +2171,16 @@ cc_library(
21562171
],
21572172
)
21582173

2174+
cc_library(
2175+
name = "hlo_profile_printer",
2176+
srcs = ["hlo_profile_printer.cc"],
2177+
hdrs = ["hlo_profile_printer.h"],
2178+
deps = [
2179+
":human_readable_profile_builder",
2180+
"//tensorflow/compiler/xla:types",
2181+
],
2182+
)
2183+
21592184
# -----------------------------------------------------------------------------
21602185

21612186
filegroup(

tensorflow/compiler/xla/service/algebraic_simplifier.cc

+15-3
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ limitations under the License.
4646
namespace xla {
4747
namespace {
4848

49-
using tensorflow::gtl::nullopt;
50-
using tensorflow::gtl::optional;
51-
5249
// Returns whether operand is a literal with the given value.
5350
bool IsLiteralWithValue(const HloInstruction* operand, int8 value) {
5451
return operand->opcode() == HloOpcode::kConstant &&
@@ -135,7 +132,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
135132

136133
Status HandleConvert(HloInstruction* convert) override;
137134

135+
Status HandleComplex(HloInstruction* complex) override;
136+
138137
Status HandleReal(HloInstruction* real) override;
138+
139139
Status HandleImag(HloInstruction* imag) override;
140140

141141
Status HandleConvolution(HloInstruction* convolution) override;
@@ -947,6 +947,18 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
947947
return Status::OK();
948948
}
949949

950+
// Complex(Real(c), Imag(c)) -> c
951+
Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
952+
auto real = complex->mutable_operand(0);
953+
auto imag = complex->mutable_operand(1);
954+
if (real->opcode() == HloOpcode::kReal &&
955+
imag->opcode() == HloOpcode::kImag &&
956+
real->operand(0) == imag->operand(0)) {
957+
return ReplaceInstruction(complex, real->mutable_operand(0));
958+
}
959+
return Status::OK();
960+
}
961+
950962
// Real(Complex(r, i)) -> r
951963
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
952964
auto operand = real->mutable_operand(0);

0 commit comments

Comments
 (0)