Commit 2c26c98

Yao Zhang authored and tensorflower-gardener committed
Add a checkpoint compatibility test for layout optimizer.
PiperOrigin-RevId: 175637014
1 parent: ec60b0b

1 file changed: +54 −4 lines

tensorflow/python/grappler/layout_optimizer_test.py (+54 −4)
@@ -18,8 +18,11 @@
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.core.protobuf import saver_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -32,9 +35,10 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training import gradient_descent
-from tensorflow.python.training import saver
+from tensorflow.python.training import saver as saver_lib
 
 
 def weight(shape):
@@ -83,9 +87,9 @@ def loop():
   return outputs
 
 
-def get_config():
+def get_config(layout_optimizer=True):
   rewrite_options = rewriter_config_pb2.RewriterConfig(
-      optimize_tensor_layout=True)
+      optimize_tensor_layout=layout_optimizer)
   graph_options = config_pb2.GraphOptions(
       rewrite_options=rewrite_options, build_cost_model=1)
   config = config_pb2.ConfigProto(graph_options=graph_options)
@@ -95,6 +99,41 @@ def get_config():
 class LayoutOptimizerTest(test.TestCase):
   """Tests the Grappler layout optimizer."""
 
+  def _train(self, checkpoint_path, layout_optimizer=False, restore=False):
+    ops.reset_default_graph()
+    graph = ops.get_default_graph()
+    with session.Session(
+        config=get_config(layout_optimizer), graph=graph) as sess:
+      batch = 2
+      height = 6
+      width = 7
+      input_channels = 3
+      shape = [batch, height, width, input_channels]
+      image = array_ops.placeholder(dtype='float32', shape=shape)
+      conv1 = conv_layers.conv2d(image, 32, [3, 3])
+      conv2 = conv_layers.conv2d(conv1, 32, [3, 3])
+      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
+      loss = math_ops.reduce_mean(conv2)
+      train_op = optimizer.minimize(loss)
+      saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
+
+      if restore:
+        saver.restore(sess, checkpoint_path)
+      else:
+        sess.run(variables.global_variables_initializer())
+
+      np.random.seed(0)
+      for _ in range(2):
+        image_val = np.random.rand(*shape).astype(np.float32)
+        sess.run([loss, train_op], feed_dict={image: image_val})
+
+      if restore:
+        all_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+        all_vars_values = [var.eval(session=sess) for var in all_vars]
+        return all_vars_values
+      else:
+        saver.save(sess, checkpoint_path)
+
   def testTwoConvLayers(self):
     if test.is_gpu_available(cuda_only=True):
       random_seed.set_random_seed(0)
@@ -152,7 +191,7 @@ def testGradient(self):
       train_op = optimizer.minimize(loss)
       graph = ops.get_default_graph()
       graph.add_to_collection('train_op', train_op)
-      meta_graph = saver.export_meta_graph(graph_def=graph.as_graph_def())
+      meta_graph = saver_lib.export_meta_graph(graph_def=graph.as_graph_def())
 
       rewrite_options = rewriter_config_pb2.RewriterConfig(
           optimize_tensor_layout=True)
@@ -165,6 +204,17 @@ def testGradient(self):
           self.assertEqual(node.attr['data_format'].s, 'NCHW')
       self.assertEqual(found, 5)
 
+  def testCheckpointCompatibility(self):
+    checkpoint_path = self.get_temp_dir()
+    self._train(checkpoint_path)
+    vars_expected = self._train(checkpoint_path, restore=True)
+    vars_layout_optimized = self._train(
+        checkpoint_path, restore=True, layout_optimizer=True)
+
+    for var_expected, var_layout_optimized in zip(vars_expected,
+                                                  vars_layout_optimized):
+      self.assertAllClose(var_expected, var_layout_optimized, atol=1e-6)
+
 
 if __name__ == '__main__':
   test.main()
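
For context on what the new test exercises: `get_config()` builds a session `ConfigProto` whose Grappler `RewriterConfig` toggles the layout optimizer, and `testCheckpointCompatibility` checks that a V2 checkpoint written with the optimizer off restores cleanly with it on, with variables matching after identical seeded training steps. Below is a minimal standalone sketch of the config toggle, mirroring the `get_config` helper in the diff; the function name `make_session_config` is illustrative and not part of this commit.

# Minimal sketch mirroring get_config() above; assumes the TF 1.x
# protobuf APIs used by this test.
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

def make_session_config(layout_optimizer=True):
  # Toggle Grappler's layout optimizer via RewriterConfig.
  rewrite_options = rewriter_config_pb2.RewriterConfig(
      optimize_tensor_layout=layout_optimizer)
  # build_cost_model=1 asks the session to build a cost model, matching
  # the GraphOptions the test constructs.
  graph_options = config_pb2.GraphOptions(
      rewrite_options=rewrite_options, build_cost_model=1)
  return config_pb2.ConfigProto(graph_options=graph_options)

Note that the test compares restored variables with `assertAllClose(..., atol=1e-6)` rather than exact equality, presumably to allow for small floating-point differences introduced by the optimizer's data-format rewrites.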
