Skip to content

Commit

Permalink
Replace get_shape with shape everywhere. Remove trailing white spaces.
Browse files Browse the repository at this point in the history
  • Loading branch information
vahidk committed Aug 12, 2017
1 parent b48fcc4 commit 590ba71
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 22 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ import tensorflow as tf

def get_shape(tensor):
"""Returns static shape if available and dynamic shape otherwise."""
static_shape = tensor.get_shape().as_list()
static_shape = tensor.shape.as_list()
dynamic_shape = tf.unstack(tf.shape(tensor))
dims = [s[1] if s[0] is None else s[0]
for s in zip(static_shape, dynamic_shape)]
Expand All @@ -777,7 +777,7 @@ def batch_gather(tensor, indices):
shape = get_shape(tensor)
flat_first = tf.reshape(tensor, [shape[0] * shape[1]] + shape[2:])
indices = tf.convert_to_tensor(indices)
offset_shape = [shape[0]] + [1] * (indices.get_shape().ndims - 1)
offset_shape = [shape[0]] + [1] * (indices.shape.ndims - 1)
offset = tf.reshape(tf.range(shape[0]) * shape[1], offset_shape)
output = tf.gather(flat_first, indices + offset)
return output
Expand All @@ -799,7 +799,7 @@ def rnn_beam_search(update_fn, initial_state, sequence_length, beam_width,
ids: Output indices.
logprobs: Output log probabilities probabilities.
"""
batch_size = initial_state.get_shape().as_list()[0]
batch_size = initial_state.shape.as_list()[0]

state = tf.tile(tf.expand_dims(initial_state, axis=1), [1, beam_width, 1])

Expand All @@ -820,7 +820,7 @@ def rnn_beam_search(update_fn, initial_state, sequence_length, beam_width,
tf.expand_dims(sel_sum_logprobs, axis=2) +
(logits * tf.expand_dims(mask, axis=2)))

num_classes = logits.get_shape().as_list()[-1]
num_classes = logits.shape.as_list()[-1]

sel_sum_logprobs, indices = tf.nn.top_k(
tf.reshape(sum_logprobs, [batch_size, num_classes * beam_width]),
Expand Down
20 changes: 10 additions & 10 deletions code/framework/common/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def get_shape(tensor):
"""Returns static shape if available and dynamic shape otherwise."""
static_shape = tensor.get_shape().as_list()
static_shape = tensor.shape.as_list()
dynamic_shape = tf.unstack(tf.shape(tensor))
dims = [s[1] if s[0] is None else s[0]
for s in zip(static_shape, dynamic_shape)]
Expand All @@ -30,12 +30,12 @@ def reshape(tensor, dims_list):
return tensor


def dense_layers(tensor,
def dense_layers(tensor,
sizes,
activation=tf.nn.relu,
linear_top_layer=False,
drop_rate=0.0,
name=None,
name=None,
**kwargs):
"""Builds a stack of fully connected layers with optional dropout."""
with tf.variable_scope(name, default_name='dense_layers'):
Expand All @@ -52,19 +52,19 @@ def dense_layers(tensor,
return tensor


def conv_layers(tensor,
filters,
kernels,
pools,
padding="same",
def conv_layers(tensor,
filters,
kernels,
pools,
padding="same",
activation=tf.nn.relu,
drop_rate=0.0,
**kwargs):
for fs, ks, ps in zip(filters, kernels, pools):
tensor = tf.layers.dropout(tensor, drop_rate)
tensor = tf.layers.conv2d(
tensor,
filters=fs,
tensor,
filters=fs,
kernel_size=ks,
padding=padding,
activation=activation,
Expand Down
10 changes: 5 additions & 5 deletions code/framework/dataset/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def _download_data():
if not os.path.exists(LOCAL_DIR):
os.makedirs(LOCAL_DIR)
for name in [
TRAIN_IMAGE_URL,
TRAIN_LABEL_URL,
TEST_IMAGE_URL,
TRAIN_IMAGE_URL,
TRAIN_LABEL_URL,
TEST_IMAGE_URL,
TEST_LABEL_URL]:
if not os.path.exists(LOCAL_DIR + name):
urllib.request.urlretrieve(REMOTE_URL + name, LOCAL_DIR + name)
urllib.request.urlretrieve(REMOTE_URL + name, LOCAL_DIR + name)


def _image_iterator(split):
Expand All @@ -69,7 +69,7 @@ def _image_iterator(split):
tf.estimator.ModeKeys.EVAL: TEST_IMAGE_URL
}[split]
label_urls = {
tf.estimator.ModeKeys.TRAIN: TRAIN_LABEL_URL,
tf.estimator.ModeKeys.TRAIN: TRAIN_LABEL_URL,
tf.estimator.ModeKeys.EVAL: TEST_LABEL_URL
}[split]

Expand Down
6 changes: 3 additions & 3 deletions code/framework/model/convnet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ def model_fn(features, labels, mode, params):
drop_rate = params.drop_rate if mode == tf.estimator.ModeKeys.TRAIN else 0.0

features = ops.conv_layers(
images,
filters=[32, 64, 128],
images,
filters=[32, 64, 128],
kernels=[3, 3, 3],
pools=[2, 2, 2])

features = tf.contrib.layers.flatten(features)

logits = ops.dense_layers(
features, [512, params.num_classes],
features, [512, params.num_classes],
drop_rate=drop_rate,
linear_top_layer=True)

Expand Down

0 comments on commit 590ba71

Please sign in to comment.