Skip to content

Commit

Permalink
Rename --version or --resnet_version (tensorflow#4165)
Browse files Browse the repository at this point in the history
* rename --version flag and fix tests to correctly specify version rather than verbosity

* rename version to resnet_version throughout

* fix bugs

* delint

* missed layer_test

* fix indent
  • Loading branch information
Taylor Robie authored May 4, 2018
1 parent eb0c0df commit 5be3c06
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 98 deletions.
10 changes: 5 additions & 5 deletions official/resnet/cifar10_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class Cifar10Model(resnet_model.Model):
"""Model class with appropriate defaults for CIFAR-10 data."""

def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION,
resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for CIFAR-10 data.
Expand All @@ -150,8 +150,8 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
resnet_version: Integer representing which version of the ResNet network
to use. See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations.
Raises:
Expand All @@ -174,7 +174,7 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
block_sizes=[num_blocks] * 3,
block_strides=[1, 2, 2],
final_size=64,
version=version,
resnet_version=resnet_version,
data_format=data_format,
dtype=dtype
)
Expand Down Expand Up @@ -211,7 +211,7 @@ def loss_filter_fn(_):
learning_rate_fn=learning_rate_fn,
momentum=0.9,
data_format=params['data_format'],
version=params['version'],
resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn,
dtype=params['dtype']
Expand Down
33 changes: 17 additions & 16 deletions official/resnet/cifar10_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_dataset_input_fn(self):
for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)

def cifar10_model_fn_helper(self, mode, version, dtype):
def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
input_fn = cifar10_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator()
Expand All @@ -87,7 +87,7 @@ def cifar10_model_fn_helper(self, mode, version, dtype):
'resnet_size': 32,
'data_format': 'channels_last',
'batch_size': _BATCH_SIZE,
'version': version,
'resnet_version': resnet_version,
'loss_scale': 128 if dtype == tf.float16 else 1,
})

Expand All @@ -111,56 +111,57 @@ def cifar10_model_fn_helper(self, mode, version, dtype):
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)

def test_cifar10_model_fn_train_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=1,
dtype=tf.float32)

def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=2,
dtype=tf.float32)

def test_cifar10_model_fn_eval_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=1,
dtype=tf.float32)

def test_cifar10_model_fn_eval_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=2,
dtype=tf.float32)

def test_cifar10_model_fn_predict_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
dtype=tf.float32)
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT,
resnet_version=1, dtype=tf.float32)

def test_cifar10_model_fn_predict_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
dtype=tf.float32)
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT,
resnet_version=2, dtype=tf.float32)

def _test_cifar10model_shape(self, version):
def _test_cifar10model_shape(self, resnet_version):
batch_size = 135
num_classes = 246

model = cifar10_main.Cifar10Model(32, data_format='channels_last',
num_classes=num_classes, version=version)
num_classes=num_classes,
resnet_version=resnet_version)
fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
output = model(fake_input, training=True)

self.assertAllEqual(output.shape, (batch_size, num_classes))

def test_cifar10model_shape_v1(self):
self._test_cifar10model_shape(version=1)
self._test_cifar10model_shape(resnet_version=1)

def test_cifar10model_shape_v2(self):
self._test_cifar10model_shape(version=2)
self._test_cifar10model_shape(resnet_version=2)

def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1']
extra_flags=['-resnet_version', '1']
)

def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2']
extra_flags=['-resnet_version', '2']
)


Expand Down
10 changes: 5 additions & 5 deletions official/resnet/imagenet_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class ImagenetModel(resnet_model.Model):
"""Model class with appropriate defaults for Imagenet data."""

def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION,
resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for Imagenet data.
Expand All @@ -207,8 +207,8 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
data format to use when setting up the model.
num_classes: The number of output classes needed from the model. This
enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
resnet_version: Integer representing which version of the ResNet network
to use. See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations.
"""

Expand All @@ -232,7 +232,7 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
block_sizes=_get_block_sizes(resnet_size),
block_strides=[1, 2, 2, 2],
final_size=final_size,
version=version,
resnet_version=resnet_version,
data_format=data_format,
dtype=dtype
)
Expand Down Expand Up @@ -289,7 +289,7 @@ def imagenet_model_fn(features, labels, mode, params):
learning_rate_fn=learning_rate_fn,
momentum=0.9,
data_format=params['data_format'],
version=params['version'],
resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'],
loss_filter_fn=None,
dtype=params['dtype']
Expand Down
Loading

0 comments on commit 5be3c06

Please sign in to comment.