@@ -366,16 +366,52 @@ def testComplexOps(self):
366
366
367
367
self ._testBinary (
368
368
gen_math_ops ._real_div ,
369
- np .array ([3 , 3j , - 1.5j , - 8 , 2 + 3j , 2 + 4j , 44 + 3j ], dtype = dtype ),
370
- np .array ([2 , - 2 , 7j , - 4j , 4 - 6j , 1 + 2j , 0 ], dtype = dtype ),
369
+ np .array ([3 , 3j , - 1.5j , - 8 , 2 + 3j , 2 + 4j ], dtype = dtype ),
370
+ np .array ([2 , - 2 , 7j , - 4j , 4 - 6j , 1 + 2j ], dtype = dtype ),
371
+ expected = np .array (
372
+ [1.5 , - 1.5j , - 0.2142857 , - 2j , (2 + 3j ) / (4 - 6j ), 2 ],
373
+ dtype = dtype ))
374
+
375
+ # Test inf/nan scenarios.
376
+ self ._testBinary (
377
+ gen_math_ops ._real_div ,
378
+ np .array ([4 + 3j , 4 , 3j , - 4 , - 4j , 2 - 3j ], dtype = dtype ),
379
+ np .array ([0 , 0 , 0 , 0 , 0 , 0 ], dtype = dtype ),
371
380
expected = np .array (
372
381
[
373
- 1.5 , - 1.5j , - 0.2142857 , - 2j , (2 + 3j ) / (4 - 6j ), 2 ,
374
- float ("inf" )
382
+ dtype (1 + 1j ) / 0 ,
383
+ dtype (1 ) / 0 ,
384
+ dtype (1j ) / 0 ,
385
+ dtype (- 1 ) / 0 ,
386
+ dtype (- 1j ) / 0 ,
387
+ dtype (1 - 1j ) / 0
375
388
],
376
389
dtype = dtype ))
377
390
378
- # TODO(b/65408531): support+test pow for cplx
391
+ atan2_supported = self .device == "XLA_GPU"
392
+ if atan2_supported :
393
+ self ._testBinary (
394
+ math_ops .pow ,
395
+ dtype (3 + 2j ),
396
+ dtype (4 - 5j ),
397
+ expected = np .power (dtype (3 + 2j ), dtype (4 - 5j )))
398
+ self ._testBinary ( # empty rhs
399
+ math_ops .pow ,
400
+ np .array ([1 + 2j , 2 - 3j ], dtype = dtype ),
401
+ np .zeros (shape = [0 , 2 ], dtype = dtype ),
402
+ expected = np .zeros (shape = [0 , 2 ], dtype = dtype ))
403
+ self ._testBinary ( # to zero power
404
+ math_ops .pow ,
405
+ np .array ([1 + 2j , 2 - 3j ], dtype = dtype ),
406
+ np .zeros (shape = [1 , 2 ], dtype = dtype ),
407
+ expected = np .ones (shape = [1 , 2 ], dtype = dtype ))
408
+ lhs = np .array ([1 - 2j , 4 + 3j , 2 - 3j , 3 , 2j , 1 , 4 ], dtype = dtype )
409
+ rhs = np .array ([2 , 3j , 3 + 4j , 2 + 3j , 3 - 2j , 2 , 3 + 3j ], dtype = dtype )
410
+ scalar = dtype (2 + 2j )
411
+ self ._testBinary (math_ops .pow , lhs , rhs , expected = np .power (lhs , rhs ))
412
+ self ._testBinary (
413
+ math_ops .pow , scalar , rhs , expected = np .power (scalar , rhs ))
414
+ self ._testBinary (math_ops .pow , lhs , scalar , np .power (lhs , scalar ))
379
415
380
416
lhs = np .array ([4 + 2j , - 3 - 1j , 2j , 1 ], dtype = dtype )
381
417
rhs = np .array ([5 , - 6j , 7 - 3j , - 8j ], dtype = dtype )
@@ -385,7 +421,9 @@ def testComplexOps(self):
385
421
self ._testBinary (
386
422
gen_math_ops ._sigmoid_grad , lhs , rhs , expected = rhs * lhs * (1 - lhs ))
387
423
388
- # TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow)
424
+ if atan2_supported :
425
+ self ._testBinary (
426
+ gen_math_ops ._rsqrt_grad , lhs , rhs , expected = lhs ** 3 * rhs / - 2 )
389
427
390
428
self ._testBinary (
391
429
gen_math_ops ._sqrt_grad , lhs , rhs , expected = rhs / (2 * lhs ))
0 commit comments