@@ -296,14 +296,20 @@ def conversion_visitor(unused_path, unused_parent, children):
296
296
def testPositionsMatchArgGiven (self ):
297
297
full_dict = tf_upgrade_v2 .TFAPIChangeSpec ().function_arg_warnings
298
298
method_names = full_dict .keys ()
299
- for method in method_names :
300
- # doesn't test methods on objects
301
- if not method .startswith ("*." ):
302
- args = full_dict [method ].keys ()
303
- method = get_symbol_for_name (tf , method )
304
- arg_spec = tf_inspect .getfullargspec (method )
305
- for (arg , pos ) in args :
306
- self .assertEqual (arg_spec [0 ][pos ], arg )
299
+ for method_name in method_names :
300
+ args = full_dict [method_name ].keys ()
301
+ # special case for optimizer methods
302
+ if method_name .startswith ("*." ):
303
+ method = method_name .replace ("*" , "tf.train.Optimizer" )
304
+ else :
305
+ method = method_name
306
+ method = get_symbol_for_name (tf , method )
307
+ arg_spec = tf_inspect .getfullargspec (method )
308
+ for (arg , pos ) in args :
309
+ # to deal with the self argument on methods on objects
310
+ if method_name .startswith ("*." ):
311
+ pos += 1
312
+ self .assertEqual (arg_spec [0 ][pos ], arg )
307
313
308
314
def testReorderFileNeedsUpdate (self ):
309
315
reordered_function_names = (
0 commit comments