Skip to content

Commit

Permalink
Merge pull request vahidk#2 from vrv/patch-1
Browse files Browse the repository at this point in the history
Small clarifications on the shape section
  • Loading branch information
vahidk authored Aug 12, 2017
2 parents 93a23b1 + cd6220e commit b48fcc4
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,19 @@ This is just tip of the iceberg for what Tensorflow can do. Many problems such a

## Understanding static and dynamic shapes
<a name="shapes"></a>
Tensors in Tensorflow have a static shape attribute which is determined during graph construction. The static shape may be underspecified. For example we might define a tensor of shape [None, 128]:
Tensors in Tensorflow have a static shape attribute which is determined during graph construction. The static shape may be underspecified. For example we might define a float32 tensor of shape [None, 128]:
```python
import tensorflow as tf

a = tf.placeholder([None, 128])
a = tf.placeholder(tf.float32, [None, 128])
```
This means that the first dimension can be of any size and will be determined dynamically during Session.run(). Tensorflow has a rather ugly API for exposing the static shape:
This means that the first dimension can be of any size and will be determined dynamically during Session.run(). You
can query the symbolic shape of a Tensor as follows:

```python
static_shape = a.get_shape().as_list() # returns [None, 128]
static_shape = a.shape # returns TensorShape([Dimension(None), Dimension(128)])
static_shape = a.shape.as_list() # returns [None, 128]
```
(This used to be a.shape but someone decided it's too convenient.)

To get the dynamic shape of the tensor you can call tf.shape op, which returns a tensor representing the shape of the given tensor:
```python
Expand All @@ -130,14 +132,18 @@ The static shape of a tensor can be set with Tensor.set_shape() method:
a.set_shape([32, 128])
```
Use this function only if you know what you are doing, in practice it's safer to do dynamic reshaping with tf.reshape() op:

```python
a = tf.reshape(a, [32, 128])
```

If you feed 'a' with values that don't match the shape, you will get an InvalidArgumentError indicating that the
number of values fed doesn't match the expected shape.

It can be convenient to have a function that returns the static shape when available and dynamic shape when it's not. The following utility function does just that:
```python
def get_shape(tensor):
static_shape = tensor.get_shape().as_list()
static_shape = tensor.shape.as_list()
dynamic_shape = tf.unstack(tf.shape(tensor))
dims = [s[1] if s[0] is None else s[0]
for s in zip(static_shape, dynamic_shape)]
Expand All @@ -146,7 +152,7 @@ def get_shape(tensor):

Now imagine we want to convert a Tensor of rank 3 to a tensor of rank 2 by collapsing the second and third dimensions into one. We can use our get_shape() function to do that:
```python
b = placeholder([None, 10, 32])
b = placeholder(tf.float32, [None, 10, 32])
shape = get_shape(tensor)
b = tf.reshape(b, [shape[0], shape[1] * shape[2]])
```
Expand All @@ -173,7 +179,7 @@ def reshape(tensor, dims_list):

Then collapsing the second dimension becomes very easy:
```python
b = placeholder([None, 10, 32])
b = placeholder(tf.float32, [None, 10, 32])
b = tf.reshape(b, [0, [1, 2]])
```

Expand Down

0 comments on commit b48fcc4

Please sign in to comment.