@@ -673,5 +673,54 @@ def _part_axis_1(**unused_kwargs):
673
673
self .assertEqual (n1_0 .get_shape (), (2 , 1 , 2 ))
674
674
self .assertEqual (n1_1 .get_shape (), (2 , 1 , 2 ))
675
675
676
+
677
+ class VariableScopeWithCustomGetterTest (tf .test .TestCase ):
678
+
679
+ def testNonCallableGetterFails (self ):
680
+ with self .assertRaisesRegexp (ValueError , r"custom_getter .* not callable:" ):
681
+ with tf .variable_scope ("scope0" , custom_getter = 3 ):
682
+ tf .get_variable ("name0" )
683
+ with self .assertRaisesRegexp (ValueError , r"custom_getter .* not callable:" ):
684
+ tf .get_variable ("name0" , custom_getter = 3 )
685
+
686
+ def testNoSideEffectsWithIdentityCustomGetter (self ):
687
+ called = [0 ]
688
+ def custom_getter (getter , * args , ** kwargs ):
689
+ called [0 ] += 1
690
+ return getter (* args , ** kwargs )
691
+ with tf .variable_scope ("scope" , custom_getter = custom_getter ) as scope :
692
+ v = tf .get_variable ("v" , [1 ])
693
+ with tf .variable_scope (scope , reuse = True ):
694
+ v2 = tf .get_variable ("v" , [1 ])
695
+ with tf .variable_scope ("new_scope" ) as new_scope :
696
+ v3 = tf .get_variable ("v3" , [1 ])
697
+ with tf .variable_scope (new_scope , reuse = True , custom_getter = custom_getter ):
698
+ v4 = tf .get_variable ("v3" , [1 ])
699
+
700
+ self .assertEqual (v , v2 )
701
+ self .assertEqual (v3 , v4 )
702
+ self .assertEqual (3 , called [0 ]) # skipped one in the first new_scope
703
+
704
+ def testGetterThatCreatesTwoVariablesAndSumsThem (self ):
705
+ def custom_getter (getter , name , * args , ** kwargs ):
706
+ g_0 = getter ("%s/0" % name , * args , ** kwargs )
707
+ g_1 = getter ("%s/1" % name , * args , ** kwargs )
708
+ with tf .name_scope ("custom_getter" ):
709
+ return g_0 + g_1
710
+
711
+ with tf .variable_scope ("scope" , custom_getter = custom_getter ):
712
+ v = tf .get_variable ("v" , [1 , 2 , 3 ])
713
+
714
+ self .assertEqual ([1 , 2 , 3 ], v .get_shape ())
715
+ true_vars = tf .trainable_variables ()
716
+ self .assertEqual (2 , len (true_vars ))
717
+ self .assertEqual ("scope/v/0:0" , true_vars [0 ].name )
718
+ self .assertEqual ("scope/v/1:0" , true_vars [1 ].name )
719
+ self .assertEqual ("custom_getter/add:0" , v .name )
720
+ with self .test_session () as sess :
721
+ tf .initialize_all_variables ().run ()
722
+ np_vars , np_v = sess .run ([true_vars , v ])
723
+ self .assertAllClose (np_v , sum (np_vars ))
724
+
676
725
if __name__ == "__main__" :
677
726
tf .test .main ()
0 commit comments