Skip to content

Commit 90b5434

Browse files
tamaranormantensorflower-gardener
authored andcommitted
Fixed positional test for object test cases and a lint error in the upgrade script
PiperOrigin-RevId: 230198038
1 parent f7dd78d commit 90b5434

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

tensorflow/tools/compatibility/tf_upgrade_v2.py

-4
Original file line numberDiff line numberDiff line change
@@ -896,14 +896,10 @@ def __init__(self):
896896
make_initializable_iterator_deprecation,
897897
"*.make_one_shot_iterator":
898898
make_one_shot_iterator_deprecation,
899-
"tf.assert_greater":
900-
assert_return_type_comment,
901899
"tf.assert_equal":
902900
assert_return_type_comment,
903901
"tf.assert_none_equal":
904902
assert_return_type_comment,
905-
"tf.assert_less":
906-
assert_return_type_comment,
907903
"tf.assert_negative":
908904
assert_return_type_comment,
909905
"tf.assert_positive":

tensorflow/tools/compatibility/tf_upgrade_v2_test.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -296,14 +296,20 @@ def conversion_visitor(unused_path, unused_parent, children):
296296
def testPositionsMatchArgGiven(self):
297297
full_dict = tf_upgrade_v2.TFAPIChangeSpec().function_arg_warnings
298298
method_names = full_dict.keys()
299-
for method in method_names:
300-
# doesn't test methods on objects
301-
if not method.startswith("*."):
302-
args = full_dict[method].keys()
303-
method = get_symbol_for_name(tf, method)
304-
arg_spec = tf_inspect.getfullargspec(method)
305-
for (arg, pos) in args:
306-
self.assertEqual(arg_spec[0][pos], arg)
299+
for method_name in method_names:
300+
args = full_dict[method_name].keys()
301+
# special case for optimizer methods
302+
if method_name.startswith("*."):
303+
method = method_name.replace("*", "tf.train.Optimizer")
304+
else:
305+
method = method_name
306+
method = get_symbol_for_name(tf, method)
307+
arg_spec = tf_inspect.getfullargspec(method)
308+
for (arg, pos) in args:
309+
# to deal with the self argument on methods on objects
310+
if method_name.startswith("*."):
311+
pos += 1
312+
self.assertEqual(arg_spec[0][pos], arg)
307313

308314
def testReorderFileNeedsUpdate(self):
309315
reordered_function_names = (

0 commit comments

Comments
 (0)