Skip to content

Commit 80a3d01

Browse files
rjpowertensorflower-gardener
authored andcommitted
Add a model comparison function to TPU test utilities.
PiperOrigin-RevId: 175620458
1 parent 8997ae6 commit 80a3d01

File tree

2 files changed

+130
-8
lines changed

2 files changed

+130
-8
lines changed

tensorflow/contrib/tpu/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package(
1616
"//cloud/vmm/testing/tests/tpu:__subpackages__",
1717
"//learning/brain:__subpackages__",
1818
"//tensorflow:__subpackages__",
19+
"//third_party/cloud_tpu:__subpackages__",
1920
],
2021
)
2122

tensorflow/contrib/tpu/python/tpu/test_util.py

+129-8
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,26 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
from tensorflow.contrib.tpu.python.tpu import tpu
21+
import os.path
22+
import pickle
23+
import tempfile
24+
25+
import numpy as np
2226

23-
from tensorflow.python.client import session
27+
from tensorflow.contrib.tpu.python.tpu import tpu
28+
from tensorflow.contrib.tpu.python.tpu import tpu_config
29+
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
30+
from tensorflow.core.protobuf import config_pb2
31+
from tensorflow.python.client import session as tf_session
32+
from tensorflow.python.estimator import model_fn as model_fn_lib
2433
from tensorflow.python.framework import errors
2534
from tensorflow.python.framework import ops
2635
from tensorflow.python.framework import test_util
2736
from tensorflow.python.ops import gen_array_ops
2837
from tensorflow.python.ops import variables
38+
from tensorflow.python.platform import gfile
39+
from tensorflow.python.platform import tf_logging as logging
40+
from tensorflow.python.training import saver as tf_saver
2941

3042

3143
def has_tpu():
@@ -38,8 +50,9 @@ def has_tpu():
3850
Returns:
3951
boolean, True if a TPU device is available, otherwise False.
4052
"""
53+
4154
def _check():
42-
with session.Session() as sess:
55+
with tf_session.Session() as sess:
4356
sess.run(tpu.initialize_system())
4457
sess.run(tpu.shutdown_system())
4558

@@ -61,14 +74,119 @@ def _available_devices():
6174
return tuple(devices)
6275

6376

77+
def copy_dir(src, tgt):
78+
"""Copy src to tgt."""
79+
gfile.MakeDirs(tgt)
80+
seen_dirs = set()
81+
for dirname, _, files in gfile.Walk(src):
82+
for f in files:
83+
src_f = os.path.join(dirname, f)
84+
tgt_f = src_f.replace(src, tgt)
85+
tgt_d = os.path.dirname(tgt_f)
86+
if tgt_d not in seen_dirs:
87+
gfile.MkDir(tgt_d)
88+
seen_dirs.add(tgt_d)
89+
gfile.Copy(src_f, tgt_f, overwrite=True)
90+
91+
92+
def compare_model(model_fn, input_fn, params, master="local", temp_dir=None,
93+
tolerance=1e-4):
94+
"""Compare the results of running `model_fn` on the TPU and CPU."""
95+
if not temp_dir:
96+
temp_dir = tempfile.mkdtemp()
97+
98+
cpu_model_dir = "%s/cpu-model" % temp_dir
99+
tpu_model_dir = "%s/tpu-model" % temp_dir
100+
initial_model_dir = "%s/initial-model" % temp_dir
101+
102+
logging.info("Checkpoints and weights will be written to %s", temp_dir)
103+
104+
num_steps = 1
105+
num_shards = 8
106+
107+
def _make_run_config(model_dir):
108+
return tpu_config.RunConfig(
109+
master=master,
110+
model_dir=model_dir,
111+
save_checkpoints_secs=10000,
112+
session_config=config_pb2.ConfigProto(
113+
allow_soft_placement=True, log_device_placement=False),
114+
tpu_config=tpu_config.TPUConfig(
115+
iterations_per_loop=num_steps,
116+
num_shards=num_shards,
117+
),
118+
)
119+
120+
def _make_estimator(use_tpu, model_dir):
121+
return tpu_estimator.TPUEstimator(
122+
model_fn=model_fn,
123+
use_tpu=use_tpu,
124+
config=_make_run_config(model_dir),
125+
train_batch_size=num_shards,
126+
params=dict(params, use_tpu=use_tpu),
127+
)
128+
129+
def _extract_weights(checkpoint):
130+
"""Extract model weights from the given checkpoint file."""
131+
weights = {}
132+
graph = ops.Graph()
133+
with graph.as_default():
134+
model_fn(
135+
*input_fn(params),
136+
params=dict(params, use_tpu=False),
137+
mode=model_fn_lib.ModeKeys.TRAIN)
138+
saver = tf_saver.Saver()
139+
with tf_session.Session(graph=graph) as sess:
140+
saver.restore(sess, checkpoint)
141+
all_vars = []
142+
all_vars.extend(graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
143+
all_vars.extend(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
144+
all_vars.extend(graph.get_collection(ops.GraphKeys.MODEL_VARIABLES))
145+
146+
for var in all_vars:
147+
weights[var.name] = sess.run(var)
148+
return weights
149+
150+
def _run_step(use_tpu, model_dir):
151+
est = _make_estimator(use_tpu=use_tpu, model_dir=model_dir)
152+
est.train(input_fn=input_fn, steps=num_steps)
153+
weights = _extract_weights(est.latest_checkpoint())
154+
with gfile.Open(temp_dir + "tpu-%d.weights" % use_tpu, "wb") as f:
155+
f.write(pickle.dumps(weights))
156+
return weights
157+
158+
# initialize models to the same weights by running a single step on the CPU
159+
_run_step(use_tpu=False, model_dir=initial_model_dir)
160+
161+
copy_dir(initial_model_dir, cpu_model_dir)
162+
cpu_weights = _run_step(use_tpu=False, model_dir=cpu_model_dir)
163+
164+
copy_dir(initial_model_dir, tpu_model_dir)
165+
tpu_weights = _run_step(use_tpu=True, model_dir=tpu_model_dir)
166+
167+
bad_weights = False
168+
for k in cpu_weights:
169+
if k not in tpu_weights:
170+
raise KeyError("Missing weight %s from TPU checkpoint.", k)
171+
172+
if not np.allclose(
173+
cpu_weights[k], tpu_weights[k], rtol=tolerance, atol=tolerance):
174+
bad_weights = True
175+
logging.error("Weights for layer %s have diverged.", k)
176+
177+
if bad_weights:
178+
raise ValueError("Some weights have diverged. Output pickle files have "
179+
"been written to %s for inspection." % temp_dir)
180+
181+
64182
class TPUTestCase(test_util.TensorFlowTestCase):
65183
"""Adds helpers for testing on TPU devices to `TensorFlowTestCase`.
66184
67185
Example usage:
68186
69187
```
70188
def model_fn(features):
71-
return tf.reduce_sum(features * 2)
189+
return tf.reduce_sum(features * 2)
72190
73191
class ModelTests(test_util.TPUTestCase):
74192
def test_sum(self):
@@ -97,10 +215,10 @@ def run_on_device(self, model_fn, model_inputs, device):
97215
Returns:
98216
Output from the model function.
99217
"""
218+
100219
def _make_placeholders():
101-
return dict(
102-
[(gen_array_ops.placeholder_with_default(v, v.shape), v)
103-
for v in model_inputs])
220+
return dict([(gen_array_ops.placeholder_with_default(v, v.shape), v)
221+
for v in model_inputs])
104222

105223
if device == "tpu":
106224
with self.test_session(graph=ops.Graph()) as sess:
@@ -133,7 +251,10 @@ def _compare_values(self, actual_outputs, expected_outputs):
133251
else:
134252
self.assertAllCloseAccordingToType(actual_outputs, expected_outputs)
135253

136-
def assert_device_output(self, model_fn, model_inputs, expected_outputs,
254+
def assert_device_output(self,
255+
model_fn,
256+
model_inputs,
257+
expected_outputs,
137258
devices=("cpu", "gpu", "tpu")):
138259
"""Run `model_fn` on the given devices.
139260

0 commit comments

Comments
 (0)