Skip to content

Commit 4470dbe

Browse files
committed
Moved tpu_config.RunConfig check to inside of use_tpu block
1 parent 2336cdf commit 4470dbe

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

tensorflow/contrib/tpu/python/tpu/tpu_estimator.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -146,27 +146,25 @@ def end(self, session):
146146
class TpuEstimator(estimator_lib.Estimator):
147147
"""Estimator with TPU support.
148148
149-
The only difference is a wrapped model_fn is set in the constructor.
149+
The only difference is a wrapped model_fn is set in the constructor.
150150
"""
151-
152151
def __init__(self,
153152
model_fn=None,
154153
model_dir=None,
155154
config=None,
156155
params=None,
157156
use_tpu=True):
158157
if use_tpu:
158+
if not isinstance(config, tpu_config.RunConfig):
159+
raise ValueError('`config` must be `tpu_config.RunConfig`')
159160
model_function = wrapped_model_fn(model_fn, config)
160161
else:
161162
model_function = model_fn
162-
163163
super(TpuEstimator, self).__init__(
164164
model_fn=model_function,
165165
model_dir=model_dir,
166166
config=config,
167167
params=params)
168-
if not isinstance(config, tpu_config.RunConfig):
169-
raise ValueError('`config` must be `tpu_config.RunConfig`')
170168

171169
def _create_global_step(self, graph):
172170
"""Creates a global step suitable for TPUs.

0 commit comments

Comments
 (0)