|
30 | 30 | from tensorflow.compiler.tests import xla_test
|
31 | 31 | from tensorflow.python.framework import dtypes
|
32 | 32 | from tensorflow.python.framework import ops
|
| 33 | +from tensorflow.python.framework import test_util |
33 | 34 | from tensorflow.python.ops import array_ops
|
34 | 35 | from tensorflow.python.ops import gen_image_ops
|
35 | 36 | from tensorflow.python.ops import image_ops
|
@@ -774,6 +775,7 @@ def testNonAlignCorners3x2To6x4Batch2(self):
|
774 | 775 |
|
775 | 776 | class NonMaxSuppressionTest(xla_test.XLATestCase):
|
776 | 777 |
|
| 778 | + @test_util.disable_mlir_bridge("%1") |
777 | 779 | def testNMS128From1024(self):
|
778 | 780 | num_boxes = 1024
|
779 | 781 | boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
|
@@ -808,6 +810,7 @@ def testNMS128From1024(self):
|
808 | 810 |
|
809 | 811 | self.assertEqual(indices_tf.size, max_output_size)
|
810 | 812 |
|
| 813 | + @test_util.disable_mlir_bridge("%1") |
811 | 814 | def testNMS3From6Boxes(self):
|
812 | 815 | # Three boxes are selected based on IOU.
|
813 | 816 | boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
|
@@ -849,6 +852,7 @@ def testNMS3From6Boxes(self):
|
849 | 852 | self.assertEqual(num_valid, 3)
|
850 | 853 | self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
|
851 | 854 |
|
| 855 | + @test_util.disable_mlir_bridge("%1") |
852 | 856 | def testNMS3Then2WithScoreThresh(self):
|
853 | 857 | # Three boxes are selected based on IOU.
|
854 | 858 | # One is filtered out by score threshold.
|
@@ -891,6 +895,7 @@ def testNMS3Then2WithScoreThresh(self):
|
891 | 895 | self.assertEqual(num_valid, 2)
|
892 | 896 | self.assertAllClose(indices_tf[:num_valid], [3, 0])
|
893 | 897 |
|
| 898 | + @test_util.disable_mlir_bridge("%1") |
894 | 899 | def testNMS3Then1WithScoreMaxThresh(self):
|
895 | 900 | # Three boxes are selected based on IOU.
|
896 | 901 | # One is filtered out by score threshold.
|
@@ -934,6 +939,7 @@ def testNMS3Then1WithScoreMaxThresh(self):
|
934 | 939 | self.assertEqual(num_valid, 1)
|
935 | 940 | self.assertAllClose(indices_tf[:num_valid], [3])
|
936 | 941 |
|
| 942 | + @test_util.disable_mlir_bridge("%1") |
937 | 943 | def testSelectFromContinuousOverLap(self):
|
938 | 944 | # Tests that a suppressed box does not itself suppress other boxes.
|
939 | 945 |
|
@@ -978,6 +984,7 @@ def testSelectFromContinuousOverLap(self):
|
978 | 984 |
|
979 | 985 | class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
980 | 986 |
|
| 987 | + @test_util.disable_mlir_bridge("%1") |
981 | 988 | def testBatchedNMSFrom6(self):
|
982 | 989 | boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
983 | 990 | [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
@@ -1015,6 +1022,7 @@ def testBatchedNMSFrom6(self):
|
1015 | 1022 | indices_output)
|
1016 | 1023 | self.assertAllEqual([5, 4], num_valid_output)
|
1017 | 1024 |
|
| 1025 | + @test_util.disable_mlir_bridge("%1") |
1018 | 1026 | def testBatchedNMSFrom6Max3(self):
|
1019 | 1027 | boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
1020 | 1028 | [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
@@ -1048,6 +1056,7 @@ def testBatchedNMSFrom6Max3(self):
|
1048 | 1056 | self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
|
1049 | 1057 | self.assertAllEqual([3, 3], num_valid_output)
|
1050 | 1058 |
|
| 1059 | + @test_util.disable_mlir_bridge("%1") |
1051 | 1060 | def testBatchedNMSSingleFrom6Max3(self):
|
1052 | 1061 | boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
1053 | 1062 | [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
@@ -1078,6 +1087,7 @@ def testBatchedNMSSingleFrom6Max3(self):
|
1078 | 1087 | self.assertAllEqual([0, 1, 2], indices_output)
|
1079 | 1088 | self.assertAllEqual(3, num_valid_output)
|
1080 | 1089 |
|
| 1090 | + @test_util.disable_mlir_bridge("%1") |
1081 | 1091 | def testBatchedNMSSingleFrom6NoPad(self):
|
1082 | 1092 | boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
1083 | 1093 | [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
@@ -1107,6 +1117,7 @@ def testBatchedNMSSingleFrom6NoPad(self):
|
1107 | 1117 | self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
|
1108 | 1118 | self.assertAllEqual(5, num_valid_output)
|
1109 | 1119 |
|
| 1120 | + @test_util.disable_mlir_bridge("%1") |
1110 | 1121 | def testBatchedNMSBatchDimsFrom6Max3(self):
|
1111 | 1122 | boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
1112 | 1123 | [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
@@ -1140,6 +1151,7 @@ def testBatchedNMSBatchDimsFrom6Max3(self):
|
1140 | 1151 | self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
|
1141 | 1152 | self.assertAllEqual([[3, 3]], num_valid_output)
|
1142 | 1153 |
|
| 1154 | + @test_util.disable_mlir_bridge("%1") |
1143 | 1155 | def testBatchedNMSScoreThresholdFrom6Max3(self):
|
1144 | 1156 | boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
1145 | 1157 | [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
@@ -1175,6 +1187,7 @@ def testBatchedNMSScoreThresholdFrom6Max3(self):
|
1175 | 1187 | self.assertAllEqual([3, 2], num_valid_output)
|
1176 | 1188 | self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
|
1177 | 1189 |
|
| 1190 | + @test_util.disable_mlir_bridge("%1") |
1178 | 1191 | def testBatchedNMSUnsortedInputFrom6(self):
|
1179 | 1192 | boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
|
1180 | 1193 | [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
|
@@ -1211,6 +1224,7 @@ def testBatchedNMSUnsortedInputFrom6(self):
|
1211 | 1224 | indices_output)
|
1212 | 1225 | self.assertAllEqual([5, 4], num_valid_output)
|
1213 | 1226 |
|
| 1227 | + @test_util.disable_mlir_bridge("%1") |
1214 | 1228 | def testBatchedNMSNoncanonicalizedInputFrom6(self):
|
1215 | 1229 | boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
|
1216 | 1230 | [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
|
@@ -1248,6 +1262,7 @@ def testBatchedNMSNoncanonicalizedInputFrom6(self):
|
1248 | 1262 | indices_output)
|
1249 | 1263 | self.assertAllEqual([5, 4], num_valid_output)
|
1250 | 1264 |
|
| 1265 | + @test_util.disable_mlir_bridge("%1") |
1251 | 1266 | def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self):
|
1252 | 1267 | boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
1253 | 1268 | [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
@@ -1283,6 +1298,7 @@ def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self):
|
1283 | 1298 | self.assertAllEqual([3, 2], num_valid_output)
|
1284 | 1299 | self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
|
1285 | 1300 |
|
| 1301 | + @test_util.disable_mlir_bridge("%1") |
1286 | 1302 | def testBatchedNMSFrom6DynamicInput(self):
|
1287 | 1303 | boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
1288 | 1304 | [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
|
0 commit comments