Skip to content

Commit 69638e5

Browse files
author
Maciek Chociej
authored
Merge pull request tensorflow#7676 from yaroslavvb/variable_shape_fix
Add .shape property to Variable object
2 parents c2fc604 + 8ce4fd7 commit 69638e5

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

tensorflow/python/kernel_tests/variables_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,13 @@ def testInitialization(self):
4646
self.assertEqual("Variable:0", var0.name)
4747
self.assertEqual([], var0.get_shape())
4848
self.assertEqual([], var0.get_shape())
49+
self.assertEqual([], var0.shape)
4950

5051
var1 = variables.Variable(1.1)
5152
self.assertEqual("Variable_1:0", var1.name)
5253
self.assertEqual([], var1.get_shape())
5354
self.assertEqual([], var1.get_shape())
55+
self.assertEqual([], var1.shape)
5456

5557
with self.assertRaisesOpError("Attempting to use uninitialized value"):
5658
var0.eval()
@@ -69,11 +71,13 @@ def testInitializationOrder(self):
6971
self.assertEqual("rnd:0", rnd.name)
7072
self.assertEqual([3, 6], rnd.get_shape())
7173
self.assertEqual([3, 6], rnd.get_shape())
74+
self.assertEqual([3, 6], rnd.shape)
7275

7376
dep = variables.Variable(rnd.initialized_value(), name="dep")
7477
self.assertEqual("dep:0", dep.name)
7578
self.assertEqual([3, 6], dep.get_shape())
7679
self.assertEqual([3, 6], dep.get_shape())
80+
self.assertEqual([3, 6], dep.shape)
7781

7882
# Currently have to set the shape manually for Add.
7983
added_val = rnd.initialized_value() + dep.initialized_value() + 2.0
@@ -83,6 +87,7 @@ def testInitializationOrder(self):
8387
self.assertEqual("depdep:0", depdep.name)
8488
self.assertEqual([3, 6], depdep.get_shape())
8589
self.assertEqual([3, 6], depdep.get_shape())
90+
self.assertEqual([3, 6], depdep.shape)
8691

8792
variables.global_variables_initializer().run()
8893

@@ -375,13 +380,15 @@ def testInitializerFunction(self):
375380

376381
v1 = variables.Variable(initializer, dtype=dtypes.float32)
377382
self.assertEqual(shape, v1.get_shape())
383+
self.assertEqual(shape, v1.shape)
378384
self.assertAllClose(value, v1.initial_value.eval())
379385
with self.assertRaises(errors_impl.FailedPreconditionError):
380386
v1.eval()
381387

382388
v2 = variables.Variable(
383389
math_ops.negative(v1.initialized_value()), dtype=dtypes.float32)
384390
self.assertEqual(v1.get_shape(), v2.get_shape())
391+
self.assertEqual(v1.shape, v2.shape)
385392
self.assertAllClose(np.negative(value), v2.initial_value.eval())
386393

387394
# Once v2.initial_value.eval() has been called, v1 has effectively been
@@ -532,6 +539,7 @@ def testPartitionedVariable(self):
532539
self.assertEqual(2, num_partitions)
533540
self.assertEqual([v0, v1], iterated_partitions)
534541
self.assertEqual([2], concatenated.get_shape())
542+
self.assertEqual([2], concatenated.shape)
535543

536544
def testPartitionedVariableFailures(self):
537545
with ops.Graph().as_default():

tensorflow/python/ops/variables.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,14 +723,19 @@ def graph(self):
723723
"""The `Graph` of this variable."""
724724
return self._variable.graph
725725

726-
def get_shape(self):
726+
@property
727+
def shape(self):
727728
"""The `TensorShape` of this variable.
728729
729730
Returns:
730731
A `TensorShape`.
731732
"""
732733
return self._variable.get_shape()
733734

735+
def get_shape(self):
736+
"""Alias of Variable.shape."""
737+
return self.shape
738+
734739
def to_proto(self, export_scope=None):
735740
"""Converts a `Variable` to a `VariableDef` protocol buffer.
736741

0 commit comments

Comments
 (0)