
Commit

Reinitialize TPU system for model parallelism (except Mesh-TensorFlow)
PiperOrigin-RevId: 244983361
toponado-zz authored and tensorflower-gardener committed Apr 24, 2019
1 parent 6def3a9 commit c4638b8
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions tensorflow/python/tpu/tpu_estimator.py
@@ -479,10 +479,16 @@ def __init__(self,
         ctx.config.tpu_config.initial_infeed_sleep_secs)
 
     # When using model parallelism, the TPU is pre-initialized at startup to
-    # fetch mesh information. We skip re-initializing it here to avoid
-    # suspected issues due to the mesh layout changing on the second
-    # initialization.
-    self._should_initialize_tpu = not ctx.model_parallelism_enabled
+    # fetch mesh information. We skip re-initializing it here for
+    # MeshTensorFlow, since it places variables directly on the TPU.
+    # Re-initializing the TPU would corrupt those variables, because the
+    # previously allocated memory might be overwritten for other purposes.
+    if (ctx.model_parallelism_enabled and
+        ctx.is_input_broadcast_with_iterators()):
+      self._should_initialize_tpu = False
+    else:
+      self._should_initialize_tpu = True
 
     self._tpu_compile_op = tpu_compile_op
 
   def begin(self):
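For context, a hedged sketch of how a flag like _should_initialize_tpu is typically consumed: when it is True, the session hooks emit tpu.initialize_system() / tpu.shutdown_system() ops; in the Mesh-TensorFlow case (model parallelism with broadcast iterator input) the TPU system was already initialized at startup and is left untouched. The helper below is illustrative only, not the actual tpu_estimator.py hook; ctx is assumed to expose the two predicates used in the diff.

import tensorflow.compat.v1 as tf


def make_tpu_lifecycle_ops(ctx):
  """Sketch of the decision in the diff above (not the real TPUEstimator hook)."""
  if (ctx.model_parallelism_enabled and
      ctx.is_input_broadcast_with_iterators()):
    # Mesh-TensorFlow: variables already live on the TPU from the startup
    # initialization; re-initializing could overwrite that memory.
    should_initialize_tpu = False
  else:
    # Every other configuration: (re)initialize the TPU system here.
    should_initialize_tpu = True

  if should_initialize_tpu:
    init_op = tf.tpu.initialize_system()
    shutdown_op = tf.tpu.shutdown_system()
  else:
    init_op = tf.no_op()
    shutdown_op = tf.no_op()
  return init_op, shutdown_op

In the real TPUEstimator these ops are created and run by the infeed/outfeed session hooks; the sketch only mirrors the new condition introduced by this commit.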
