Skip to content

Commit

Permalink
Remove all hard BUILD dependencies on TF from JAX and NumPy substrates.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 314401801
  • Loading branch information
sharadmv authored and tensorflower-gardener committed Jun 2, 2020
1 parent 1fe7109 commit e188f92
Show file tree
Hide file tree
Showing 46 changed files with 602 additions and 187 deletions.
15 changes: 9 additions & 6 deletions tensorflow_probability/python/build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,9 @@
# [internal] load python3.bzl

NO_REWRITE_NEEDED = [
"internal:cache_util",
"internal:docstring_util",
"internal:name_util",
"internal:reparameterization",
"layers",
"optimizer/convergence_criteria",
"optimizer:sgld",
"optimizer:variational_sgld",
"platform_google",
]

Expand Down Expand Up @@ -109,9 +104,17 @@ def multi_substrate_py_library(
srcs_version = srcs_version,
testonly = testonly,
)
remove_deps = [
"//third_party/py/tensorflow",
"//third_party/py/tensorflow:tensorflow",
]

trimmed_deps = [dep for dep in deps if dep not in substrates_omit_deps]
resolved_omit_deps = [_resolve_omit_dep(dep) for dep in substrates_omit_deps]
resolved_omit_deps = [
_resolve_omit_dep(dep)
for dep in substrates_omit_deps
if dep not in remove_deps
]
for src in srcs:
native.genrule(
name = "rewrite_{}_numpy".format(src.replace(".", "_")),
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_probability/python/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ exports_files(["LICENSE"])
multi_substrate_py_library(
name = "distributions",
srcs = ["__init__.py"],
substrates_omit_deps = [
":pixel_cnn",
],
deps = [
":autoregressive",
":batch_reshape",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# ============================================================================
# Tests for the backend integration.

# [internal] load python3.bzl

# py_test rules are rewritten to py3_test on export to OSS, for reasons of
# reversibility of import<->export transforms. This is only necessary until we
# eliminate all py2and3_test targets, after which everything can safely use
Expand All @@ -31,13 +29,14 @@ package(
],
)

py_test(
py3_test(
name = "numpy_integration_test",
size = "small",
srcs = ["numpy_integration_test.py"],
python_version = "PY3",
tags = ["tfp_numpy"],
deps = [
"//tensorflow_probability",
"//tensorflow_probability/python/experimental/substrates/numpy",
],
)

Expand All @@ -48,7 +47,7 @@ py3_test(
tags = ["tfp_jax"],
deps = [
# jax dep,
"//tensorflow_probability",
"//tensorflow_probability/python/experimental/substrates/jax",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
import jax
import jax.numpy as np

import tensorflow_probability as tfp
tfp = tfp.experimental.substrates.jax
from tensorflow_probability.python.experimental.substrates import jax as tfp

tfb = tfp.bijectors
tfd = tfp.distributions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@

import numpy as np

import tensorflow_probability as tfp
tfp = tfp.experimental.substrates.numpy
from tensorflow_probability.python.experimental.substrates import numpy as tfp

tfb = tfp.bijectors
tfd = tfp.distributions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,36 @@
'from tensorflow.python.framework import tensor_shape':
('from tensorflow_probability.python.internal.backend.numpy.gen '
'import tensor_shape'),
'from tensorflow.python.util import nest':
'from tensorflow.python.framework import ops':
('from tensorflow_probability.python.internal.backend.numpy '
'import nest'),
'import ops'),
'from tensorflow.python.framework import tensor_util':
('from tensorflow_probability.python.internal.backend.numpy '
'import ops'),
'from tensorflow.python.util import':
'from tensorflow_probability.python.internal.backend.numpy import',
'from tensorflow.python.util.all_util':
'from tensorflow_probability.python.internal.backend.numpy.private',
'from tensorflow.python.ops.linalg':
'from tensorflow_probability.python.internal.backend.numpy.gen',
'from tensorflow.python.ops import parallel_for':
'from tensorflow_probability.python.internal.backend.numpy '
'import functional_ops as parallel_for',
'from tensorflow.python.ops import control_flow_ops':
'from tensorflow_probability.python.internal.backend.numpy '
'import control_flow as control_flow_ops',
'from tensorflow.python.eager import context':
'from tensorflow_probability.python.internal.backend.numpy '
'import private',
('from tensorflow.python.client '
'import pywrap_tf_session as c_api'):
'pass',
('from tensorflow.python '
'import pywrap_tensorflow as c_api'):
'pass'
}

DISABLED_BY_PKG = {
'distributions':
('internal.moving_stats',),
'mcmc':
('nuts', 'sample_annealed_importance', 'sample_halton_sequence',
'slice_sampler_kernel'),
Expand All @@ -67,16 +84,30 @@
}
LIBS = ('bijectors', 'distributions', 'experimental', 'math', 'mcmc',
'optimizer', 'stats', 'util')
INTERNALS = ('assert_util', 'batched_rejection_sampler', 'distribution_util',
'dtype_util', 'hypothesis_testlib', 'implementation_selection',
'nest_util', 'prefer_static', 'samplers', 'special_math',
'tensor_util', 'tensorshape_util', 'test_combinations',
'test_util')
INTERNALS = (
'assert_util',
'batched_rejection_sampler',
'cache_util',
'distribution_util',
'dtype_util',
'hypothesis_testlib',
'implementation_selection',
'name_util',
'nest_util',
'prefer_static',
'samplers',
'special_math',
'tensor_util',
'tensorshape_util',
'test_combinations',
'test_util'
)
OPTIMIZERS = ('linesearch',)
LINESEARCH = ('internal',)
SAMPLERS = ('categorical', 'normal', 'poisson', 'uniform', 'shuffle')

PRIVATE_TF_PKGS = ('array_ops', 'random_ops')
PRIVATE_TF_PKGS = ('array_ops', 'control_flow_util', 'gradient_checker_v2',
'numpy_text', 'random_ops')


def main(argv):
Expand Down Expand Up @@ -190,6 +221,13 @@ def main(argv):
' as {}'.format(private)
for private in PRIVATE_TF_PKGS
})
replacements.update({
'tensorflow.python.framework.ops import {}'.format(
private):
'tensorflow_probability.python.internal.backend.numpy import private'
' as {}'.format(private)
for private in PRIVATE_TF_PKGS
})
# pylint: enable=g-complex-comprehension

# TODO(bjp): Delete this block after TFP uses stateless samplers.
Expand Down
5 changes: 2 additions & 3 deletions tensorflow_probability/python/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
# numpy dep,
# tensorflow dep,
],
)

Expand Down Expand Up @@ -187,6 +185,7 @@ multi_substrate_py_test(
srcs = ["implementation_selection_test.py"],
deps = [
":implementation_selection",
# tensorflow dep,
"//tensorflow_probability/python/internal:test_util",
# tensorflow/compiler/jit dep,
],
Expand All @@ -213,7 +212,7 @@ py_library(
],
)

py_library(
multi_substrate_py_library(
name = "name_util",
srcs = ["name_util.py"],
srcs_version = "PY2AND3",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_probability/python/internal/backend/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ FILENAMES = [
"compat",
"control_flow",
"debugging",
"deprecation",
"dtype",
"errors",
"functional_ops",
Expand All @@ -62,6 +63,7 @@ FILENAMES = [
"tensor_array_ops",
"tensor_array_ops_test",
"test_lib",
"tf_inspect",
"v1",
"v2",
"_utils",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
'framework import dtypes': 'dtype as dtypes',
'framework import errors': 'errors',
'framework import ops': 'ops',
'framework import common_shapes': 'ops as common_shapes',
'framework import tensor_shape': 'tensor_shape',
'module import module': 'ops as module',
'ops import array_ops': 'numpy_array as array_ops',
Expand All @@ -60,11 +61,11 @@
]

UTIL_IMPORTS = """
from tensorflow.python.util import lazy_loader
distribution_util = lazy_loader.LazyLoader(
from tensorflow_probability.python.internal.backend.numpy import private
distribution_util = private.LazyLoader(
"distribution_util", globals(),
"tensorflow_probability.python.internal._numpy.distribution_util")
tensorshape_util = lazy_loader.LazyLoader(
tensorshape_util = private.LazyLoader(
"tensorshape_util", globals(),
"tensorflow_probability.python.internal._numpy.tensorshape_util")
"""
Expand Down Expand Up @@ -109,9 +110,11 @@ def gen_module(module_name):
'from tensorflow.python.ops.linalg '
'import {}'.format(f))
code = code.replace(
'tensorflow.python.ops.linalg import ',
'tensorflow_probability.python.internal.backend.numpy.gen import ')

'tensorflow.python.ops.linalg import',
'tensorflow_probability.python.internal.backend.numpy.gen import')
code = code.replace(
'tensorflow.python.util import',
'tensorflow_probability.python.internal.backend.numpy import')
code = code.replace('tensor_util.constant_value(', '(')
code = code.replace('tensor_util.is_tensor(', 'ops.is_tensor(')
code = code.replace(
Expand Down
42 changes: 39 additions & 3 deletions tensorflow_probability/python/internal/backend/numpy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ py_library(
":compat",
":control_flow",
":debugging",
":deprecation",
":dtype",
":errors",
":functional_ops",
Expand All @@ -56,6 +57,7 @@ py_library(
":static_rewrites",
":tensor_array_ops",
":test_lib",
":tf_inspect",
],
)

Expand All @@ -82,6 +84,7 @@ py_library(
srcs = ["control_flow.py"],
deps = [
":_utils",
":dtype",
":ops",
# numpy dep,
],
Expand All @@ -98,6 +101,12 @@ py_library(
],
)

py_library(
name = "deprecation",
srcs = ["deprecation.py"],
deps = [],
)

py_library(
name = "dtype",
srcs = ["dtype.py"],
Expand Down Expand Up @@ -335,11 +344,18 @@ py_library(
name = "test_lib",
srcs = ["test_lib.py"],
deps = [
":_utils",
# tensorflow dep,
# absl/logging dep,
# absl/testing:absltest dep,
# numpy dep,
],
)

py_library(
name = "tf_inspect",
srcs = ["tf_inspect.py"],
deps = [],
)

py_test(
name = "numpy_test",
size = "small",
Expand Down Expand Up @@ -378,14 +394,14 @@ py_library(
":nest",
":ops",
":tensor_array_ops",
# tensorflow dep,
],
)

py_library(
name = "_utils",
srcs = ["_utils.py"],
deps = [
":nest",
# wrapt dep,
],
)
Expand Down Expand Up @@ -431,6 +447,26 @@ LINOP_FILES = [
tools = ["//tensorflow_probability/python/internal/backend/meta:gen_linear_operators"],
) for filename in LINOP_FILES]

# Rules helpful for generating new rewritten files.
[genrule(
name = "generate_{}".format(filename),
testonly = 1,
srcs = [],
outs = ["gen_new/{}.py".format(filename)],
cmd = ("$(location //tensorflow_probability/python/internal/backend/meta:gen_linear_operators) " +
"--module_name={} --whitelist={} > $@").format(
filename,
",".join(LINOP_FILES),
),
tools = ["//tensorflow_probability/python/internal/backend/meta:gen_linear_operators"],
) for filename in LINOP_FILES]

py_library(
name = "generated_files",
testonly = 1,
srcs = ["gen_new/{}.py".format(filename) for filename in LINOP_FILES],
)

py_library(
name = "linear_operator_gen",
testonly = 1,
Expand Down
Loading

0 comments on commit e188f92

Please sign in to comment.