Skip to content

Commit

Permalink
Merge pull request tensorflow#21425 from saeta/fix_tpu
Browse files Browse the repository at this point in the history
Refactor dependencies so keras_support can be imported directly.
  • Loading branch information
Amit Patankar authored Aug 7, 2018
2 parents b7127e5 + 14b8b8b commit 656e7a2
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 10 deletions.
1 change: 0 additions & 1 deletion tensorflow/contrib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ py_library(
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/timeseries",
"//tensorflow/contrib/tpu",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
Expand Down
1 change: 1 addition & 0 deletions tensorflow/contrib/cmake/python_protos.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ tensorflow/core
tensorflow/core/kernels/boosted_trees
tensorflow/core/profiler
tensorflow/python
tensorflow/compiler/xla
tensorflow/contrib/boosted_trees/proto
tensorflow/contrib/cloud/kernels
tensorflow/contrib/decision_trees/proto
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/contrib/distribute/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,7 @@ py_library(
deps = [
":one_device_strategy",
":values",
"//tensorflow/contrib/tpu",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
Expand Down
17 changes: 12 additions & 5 deletions tensorflow/contrib/tpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":tpu_lib",
":tpu_py",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
"//tensorflow/contrib/training:training_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
Expand Down Expand Up @@ -133,7 +134,7 @@ py_library(

tf_custom_op_py_library(
name = "tpu_py",
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
srcs = glob(["python/ops/*.py"]),
dso = [":python/ops/_tpu_ops.so"],
kernels = [
":all_ops",
Expand All @@ -152,9 +153,13 @@ tf_custom_op_py_library(

py_library(
name = "tpu",
srcs = ["python/tpu/__init__.py"],
srcs = [
"__init__.py",
"python/tpu/__init__.py",
],
srcs_version = "PY2AND3",
deps = [
":keras_support", # split out to avoid cycle with tpu_strategy
":tpu_estimator",
":tpu_lib",
],
Expand All @@ -166,11 +171,13 @@ py_library(
"python/tpu/keras_support.py",
],
srcs_version = "PY2AND3",
visibility = [
"//tensorflow:__subpackages__",
],
deps = [
":tpu_lib",
":tpu_py",
"//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/contrib/distribute",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
"//tensorflow/core:protos_all_py",
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/contrib/tpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
@@TPUConfig
@@bfloat16_scope
@@TPUDistributionStrategy
@@keras_to_tpu_model
"""

from __future__ import absolute_import
Expand All @@ -58,6 +61,8 @@
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
from tensorflow.contrib.tpu.python.tpu.bfloat16 import *
from tensorflow.contrib.tpu.python.tpu.device_assignment import *
from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model
from tensorflow.contrib.tpu.python.tpu.keras_support import TPUDistributionStrategy
from tensorflow.contrib.tpu.python.tpu.topology import *
from tensorflow.contrib.tpu.python.tpu.tpu import *
from tensorflow.contrib.tpu.python.tpu.tpu_config import *
Expand Down
7 changes: 5 additions & 2 deletions tensorflow/contrib/tpu/python/tpu/keras_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
import numpy as np

from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver
from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
Expand All @@ -82,7 +81,11 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging

TPUDistributionStrategy = tpu_strategy.TPUStrategy # pylint: disable=invalid-name

# Work-around dependency cycle between DistributionStrategy and TPU lib.
def TPUDistributionStrategy(*args, **kw): # pylint: disable=invalid-name
from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
return tpu_strategy.TPUStrategy(*args, **kw)


class TPUEmbedding(embeddings.Embedding):
Expand Down

0 comments on commit 656e7a2

Please sign in to comment.