18
18
from __future__ import division
19
19
from __future__ import print_function
20
20
21
- from tensorflow .contrib .tpu .python .tpu import tpu
21
+ import os .path
22
+ import pickle
23
+ import tempfile
24
+
25
+ import numpy as np
22
26
23
- from tensorflow .python .client import session
27
+ from tensorflow .contrib .tpu .python .tpu import tpu
28
+ from tensorflow .contrib .tpu .python .tpu import tpu_config
29
+ from tensorflow .contrib .tpu .python .tpu import tpu_estimator
30
+ from tensorflow .core .protobuf import config_pb2
31
+ from tensorflow .python .client import session as tf_session
32
+ from tensorflow .python .estimator import model_fn as model_fn_lib
24
33
from tensorflow .python .framework import errors
25
34
from tensorflow .python .framework import ops
26
35
from tensorflow .python .framework import test_util
27
36
from tensorflow .python .ops import gen_array_ops
28
37
from tensorflow .python .ops import variables
38
+ from tensorflow .python .platform import gfile
39
+ from tensorflow .python .platform import tf_logging as logging
40
+ from tensorflow .python .training import saver as tf_saver
29
41
30
42
31
43
def has_tpu ():
@@ -38,8 +50,9 @@ def has_tpu():
38
50
Returns:
39
51
boolean, True if a TPU device is available, otherwise False.
40
52
"""
53
+
41
54
def _check ():
42
- with session .Session () as sess :
55
+ with tf_session .Session () as sess :
43
56
sess .run (tpu .initialize_system ())
44
57
sess .run (tpu .shutdown_system ())
45
58
@@ -61,14 +74,119 @@ def _available_devices():
61
74
return tuple (devices )
62
75
63
76
77
def copy_dir(src, tgt):
    """Recursively copy the directory tree rooted at `src` into `tgt`.

    Existing files under `tgt` are overwritten. Every directory visited under
    `src` (including ones that contain no files) is recreated under `tgt`.

    Args:
      src: Path of the directory to copy from.
      tgt: Path of the directory to copy into; created if it does not exist.
    """
    gfile.MakeDirs(tgt)
    for dirname, _, files in gfile.Walk(src):
        # Strip only the leading `src` prefix. A plain `str.replace(src, tgt)`
        # would also rewrite later occurrences of `src` inside the path (for
        # example a file or sub-directory whose name contains `src`).
        # NOTE(review): assumes every `dirname` yielded by `gfile.Walk(src)`
        # begins with `src` exactly as passed in -- confirm.
        rel_dir = dirname[len(src):].lstrip("/")
        tgt_d = os.path.join(tgt, rel_dir) if rel_dir else tgt

        # MakeDirs (rather than MkDir) so that intermediate directories which
        # hold no files of their own still get created before their children.
        gfile.MakeDirs(tgt_d)

        for f in files:
            gfile.Copy(
                os.path.join(dirname, f), os.path.join(tgt_d, f),
                overwrite=True)
90
+
91
+
92
def compare_model(model_fn, input_fn, params, master="local", temp_dir=None,
                  tolerance=1e-4):
    """Compare the results of running `model_fn` on the TPU and CPU.

    A single CPU training step is run first to produce an initial checkpoint.
    That checkpoint directory is then copied twice; one copy is trained for a
    further step on the CPU and the other for a further step on the TPU. The
    variables restored from the two resulting checkpoints are compared with
    `np.allclose`.

    Args:
      model_fn: Model function. It is handed to `TPUEstimator` and is also
        called directly as
        `model_fn(*input_fn(params), params=..., mode=ModeKeys.TRAIN)` to
        rebuild the graph when reading weights back from a checkpoint.
      input_fn: Input function. It is passed to `Estimator.train` and is also
        called directly as `input_fn(params)`; its result is unpacked into the
        positional arguments of `model_fn`.
      params: Dict of parameters. A `use_tpu` entry is added to a copy of it
        before it is handed to the model.
      master: Value forwarded as `master` to `tpu_config.RunConfig`.
      temp_dir: Directory for checkpoints and pickled weights. A fresh
        directory is created with `tempfile.mkdtemp()` when not given.
      tolerance: Used as both `rtol` and `atol` in `np.allclose`.

    Raises:
      KeyError: If a variable in the CPU checkpoint is absent from the TPU
        checkpoint.
      ValueError: If any variable differs by more than `tolerance`.
    """
    if not temp_dir:
        temp_dir = tempfile.mkdtemp()

    cpu_model_dir = os.path.join(temp_dir, "cpu-model")
    tpu_model_dir = os.path.join(temp_dir, "tpu-model")
    initial_model_dir = os.path.join(temp_dir, "initial-model")

    logging.info("Checkpoints and weights will be written to %s", temp_dir)

    num_steps = 1
    num_shards = 8

    def _make_run_config(model_dir):
        """Build the RunConfig shared by the CPU and TPU estimators."""
        return tpu_config.RunConfig(
            master=master,
            model_dir=model_dir,
            save_checkpoints_secs=10000,
            session_config=config_pb2.ConfigProto(
                allow_soft_placement=True, log_device_placement=False),
            tpu_config=tpu_config.TPUConfig(
                iterations_per_loop=num_steps,
                num_shards=num_shards,
            ),
        )

    def _make_estimator(use_tpu, model_dir):
        """Build a TPUEstimator that runs on the TPU iff `use_tpu` is true."""
        return tpu_estimator.TPUEstimator(
            model_fn=model_fn,
            use_tpu=use_tpu,
            config=_make_run_config(model_dir),
            train_batch_size=num_shards,
            params=dict(params, use_tpu=use_tpu),
        )

    def _extract_weights(checkpoint):
        """Extract model weights from the given checkpoint file.

        Returns a dict mapping variable name to its restored value.
        """
        weights = {}
        graph = ops.Graph()
        with graph.as_default():
            # Rebuild the model graph so a Saver can restore into it.
            model_fn(
                *input_fn(params),
                params=dict(params, use_tpu=False),
                mode=model_fn_lib.ModeKeys.TRAIN)
            saver = tf_saver.Saver()
            with tf_session.Session(graph=graph) as sess:
                saver.restore(sess, checkpoint)
                all_vars = []
                all_vars.extend(
                    graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
                all_vars.extend(
                    graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
                all_vars.extend(
                    graph.get_collection(ops.GraphKeys.MODEL_VARIABLES))

                for var in all_vars:
                    # The same variable can sit in several collections; fetch
                    # each one only once.
                    if var.name not in weights:
                        weights[var.name] = sess.run(var)
        return weights

    def _run_step(use_tpu, model_dir):
        """Train for `num_steps`, pickle the resulting weights, return them."""
        est = _make_estimator(use_tpu=use_tpu, model_dir=model_dir)
        est.train(input_fn=input_fn, steps=num_steps)
        weights = _extract_weights(est.latest_checkpoint())
        # Join with a separator so the pickle lands inside `temp_dir`, which
        # is where the final error message tells the user to look. The file
        # is named by device only, so the initial CPU run's pickle is
        # overwritten by the later CPU run.
        weights_path = os.path.join(temp_dir, "tpu-%d.weights" % use_tpu)
        with gfile.Open(weights_path, "wb") as f:
            f.write(pickle.dumps(weights))
        return weights

    # initialize models to the same weights by running a single step on the CPU
    _run_step(use_tpu=False, model_dir=initial_model_dir)

    copy_dir(initial_model_dir, cpu_model_dir)
    cpu_weights = _run_step(use_tpu=False, model_dir=cpu_model_dir)

    copy_dir(initial_model_dir, tpu_model_dir)
    tpu_weights = _run_step(use_tpu=True, model_dir=tpu_model_dir)

    bad_weights = False
    for k in cpu_weights:
        if k not in tpu_weights:
            # Format the name into the message; passing it as a second
            # positional argument would leave the "%s" unexpanded.
            raise KeyError("Missing weight %s from TPU checkpoint." % k)

        if not np.allclose(
                cpu_weights[k], tpu_weights[k], rtol=tolerance, atol=tolerance):
            bad_weights = True
            logging.error("Weights for layer %s have diverged.", k)

    if bad_weights:
        raise ValueError("Some weights have diverged. Output pickle files have "
                         "been written to %s for inspection." % temp_dir)
180
+
181
+
64
182
class TPUTestCase (test_util .TensorFlowTestCase ):
65
183
"""Adds helpers for testing on TPU devices to `TensorFlowTestCase`.
66
184
67
185
Example usage:
68
186
69
187
```
70
188
def model_fn(features):
71
- return tf.reduce_sum(features * 2)
189
+ return tf.reduce_sum(features * 2)
72
190
73
191
class ModelTests(test_util.TPUTestCase):
74
192
def test_sum(self):
@@ -97,10 +215,10 @@ def run_on_device(self, model_fn, model_inputs, device):
97
215
Returns:
98
216
Output from the model function.
99
217
"""
218
+
100
219
def _make_placeholders ():
101
- return dict (
102
- [(gen_array_ops .placeholder_with_default (v , v .shape ), v )
103
- for v in model_inputs ])
220
+ return dict ([(gen_array_ops .placeholder_with_default (v , v .shape ), v )
221
+ for v in model_inputs ])
104
222
105
223
if device == "tpu" :
106
224
with self .test_session (graph = ops .Graph ()) as sess :
@@ -133,7 +251,10 @@ def _compare_values(self, actual_outputs, expected_outputs):
133
251
else :
134
252
self .assertAllCloseAccordingToType (actual_outputs , expected_outputs )
135
253
136
- def assert_device_output (self , model_fn , model_inputs , expected_outputs ,
254
+ def assert_device_output (self ,
255
+ model_fn ,
256
+ model_inputs ,
257
+ expected_outputs ,
137
258
devices = ("cpu" , "gpu" , "tpu" )):
138
259
"""Run `model_fn` on the given devices.
139
260
0 commit comments