Skip to content

Commit 904f76e

Browse files
ebrevdotensorflower-gardener
authored andcommitted
Add custom_getter to variable scope.
Change: 128486774
1 parent c579101 commit 904f76e

File tree

2 files changed

+199
-52
lines changed

2 files changed

+199
-52
lines changed

tensorflow/python/kernel_tests/variable_scope_test.py

+49
Original file line numberDiff line numberDiff line change
@@ -673,5 +673,54 @@ def _part_axis_1(**unused_kwargs):
673673
self.assertEqual(n1_0.get_shape(), (2, 1, 2))
674674
self.assertEqual(n1_1.get_shape(), (2, 1, 2))
675675

676+
677+
class VariableScopeWithCustomGetterTest(tf.test.TestCase):
678+
679+
def testNonCallableGetterFails(self):
680+
with self.assertRaisesRegexp(ValueError, r"custom_getter .* not callable:"):
681+
with tf.variable_scope("scope0", custom_getter=3):
682+
tf.get_variable("name0")
683+
with self.assertRaisesRegexp(ValueError, r"custom_getter .* not callable:"):
684+
tf.get_variable("name0", custom_getter=3)
685+
686+
def testNoSideEffectsWithIdentityCustomGetter(self):
687+
called = [0]
688+
def custom_getter(getter, *args, **kwargs):
689+
called[0] += 1
690+
return getter(*args, **kwargs)
691+
with tf.variable_scope("scope", custom_getter=custom_getter) as scope:
692+
v = tf.get_variable("v", [1])
693+
with tf.variable_scope(scope, reuse=True):
694+
v2 = tf.get_variable("v", [1])
695+
with tf.variable_scope("new_scope") as new_scope:
696+
v3 = tf.get_variable("v3", [1])
697+
with tf.variable_scope(new_scope, reuse=True, custom_getter=custom_getter):
698+
v4 = tf.get_variable("v3", [1])
699+
700+
self.assertEqual(v, v2)
701+
self.assertEqual(v3, v4)
702+
self.assertEqual(3, called[0]) # skipped one in the first new_scope
703+
704+
def testGetterThatCreatesTwoVariablesAndSumsThem(self):
705+
def custom_getter(getter, name, *args, **kwargs):
706+
g_0 = getter("%s/0" % name, *args, **kwargs)
707+
g_1 = getter("%s/1" % name, *args, **kwargs)
708+
with tf.name_scope("custom_getter"):
709+
return g_0 + g_1
710+
711+
with tf.variable_scope("scope", custom_getter=custom_getter):
712+
v = tf.get_variable("v", [1, 2, 3])
713+
714+
self.assertEqual([1, 2, 3], v.get_shape())
715+
true_vars = tf.trainable_variables()
716+
self.assertEqual(2, len(true_vars))
717+
self.assertEqual("scope/v/0:0", true_vars[0].name)
718+
self.assertEqual("scope/v/1:0", true_vars[1].name)
719+
self.assertEqual("custom_getter/add:0", v.name)
720+
with self.test_session() as sess:
721+
tf.initialize_all_variables().run()
722+
np_vars, np_v = sess.run([true_vars, v])
723+
self.assertAllClose(np_v, sum(np_vars))
724+
676725
if __name__ == "__main__":
677726
tf.test.main()

0 commit comments

Comments
 (0)