18
18
from __future__ import division
19
19
from __future__ import print_function
20
20
21
+ import numpy as np
22
+
21
23
from tensorflow .core .protobuf import config_pb2
22
24
from tensorflow .core .protobuf import rewriter_config_pb2
25
+ from tensorflow .core .protobuf import saver_pb2
23
26
from tensorflow .python .client import session
24
27
from tensorflow .python .framework import constant_op
25
28
from tensorflow .python .framework import dtypes
32
35
from tensorflow .python .ops import math_ops
33
36
from tensorflow .python .ops import nn
34
37
from tensorflow .python .ops import random_ops
38
+ from tensorflow .python .ops import variables
35
39
from tensorflow .python .platform import test
36
40
from tensorflow .python .training import gradient_descent
37
- from tensorflow .python .training import saver
41
+ from tensorflow .python .training import saver as saver_lib
38
42
39
43
40
44
def weight (shape ):
@@ -83,9 +87,9 @@ def loop():
83
87
return outputs
84
88
85
89
86
- def get_config ():
90
+ def get_config (layout_optimizer = True ):
87
91
rewrite_options = rewriter_config_pb2 .RewriterConfig (
88
- optimize_tensor_layout = True )
92
+ optimize_tensor_layout = layout_optimizer )
89
93
graph_options = config_pb2 .GraphOptions (
90
94
rewrite_options = rewrite_options , build_cost_model = 1 )
91
95
config = config_pb2 .ConfigProto (graph_options = graph_options )
@@ -95,6 +99,41 @@ def get_config():
95
99
class LayoutOptimizerTest (test .TestCase ):
96
100
"""Tests the Grappler layout optimizer."""
97
101
102
+ def _train (self , checkpoint_path , layout_optimizer = False , restore = False ):
103
+ ops .reset_default_graph ()
104
+ graph = ops .get_default_graph ()
105
+ with session .Session (
106
+ config = get_config (layout_optimizer ), graph = graph ) as sess :
107
+ batch = 2
108
+ height = 6
109
+ width = 7
110
+ input_channels = 3
111
+ shape = [batch , height , width , input_channels ]
112
+ image = array_ops .placeholder (dtype = 'float32' , shape = shape )
113
+ conv1 = conv_layers .conv2d (image , 32 , [3 , 3 ])
114
+ conv2 = conv_layers .conv2d (conv1 , 32 , [3 , 3 ])
115
+ optimizer = gradient_descent .GradientDescentOptimizer (0.01 )
116
+ loss = math_ops .reduce_mean (conv2 )
117
+ train_op = optimizer .minimize (loss )
118
+ saver = saver_lib .Saver (write_version = saver_pb2 .SaverDef .V2 )
119
+
120
+ if restore :
121
+ saver .restore (sess , checkpoint_path )
122
+ else :
123
+ sess .run (variables .global_variables_initializer ())
124
+
125
+ np .random .seed (0 )
126
+ for _ in range (2 ):
127
+ image_val = np .random .rand (* shape ).astype (np .float32 )
128
+ sess .run ([loss , train_op ], feed_dict = {image : image_val })
129
+
130
+ if restore :
131
+ all_vars = ops .get_collection (ops .GraphKeys .GLOBAL_VARIABLES )
132
+ all_vars_values = [var .eval (session = sess ) for var in all_vars ]
133
+ return all_vars_values
134
+ else :
135
+ saver .save (sess , checkpoint_path )
136
+
98
137
def testTwoConvLayers (self ):
99
138
if test .is_gpu_available (cuda_only = True ):
100
139
random_seed .set_random_seed (0 )
@@ -152,7 +191,7 @@ def testGradient(self):
152
191
train_op = optimizer .minimize (loss )
153
192
graph = ops .get_default_graph ()
154
193
graph .add_to_collection ('train_op' , train_op )
155
- meta_graph = saver .export_meta_graph (graph_def = graph .as_graph_def ())
194
+ meta_graph = saver_lib .export_meta_graph (graph_def = graph .as_graph_def ())
156
195
157
196
rewrite_options = rewriter_config_pb2 .RewriterConfig (
158
197
optimize_tensor_layout = True )
@@ -165,6 +204,17 @@ def testGradient(self):
165
204
self .assertEqual (node .attr ['data_format' ].s , 'NCHW' )
166
205
self .assertEqual (found , 5 )
167
206
207
+ def testCheckpointCompatibility (self ):
208
+ checkpoint_path = self .get_temp_dir ()
209
+ self ._train (checkpoint_path )
210
+ vars_expected = self ._train (checkpoint_path , restore = True )
211
+ vars_layout_optimized = self ._train (
212
+ checkpoint_path , restore = True , layout_optimizer = True )
213
+
214
+ for var_expected , var_layout_optimized in zip (vars_expected ,
215
+ vars_layout_optimized ):
216
+ self .assertAllClose (var_expected , var_layout_optimized , atol = 1e-6 )
217
+
168
218
169
219
# Standard TensorFlow test entry point: discover and run the test cases above.
if __name__ == '__main__':
  test.main()
0 commit comments