Skip to content

Commit

Permalink
Update TFP examples to Python 3.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 292195703
  • Loading branch information
jburnim authored and tensorflower-gardener committed Jan 29, 2020
1 parent 425395d commit aac70e6
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 54 deletions.
118 changes: 67 additions & 51 deletions tensorflow_probability/examples/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,30 @@
# Description:
# TensorFlow Probability examples.

# [internal] load python3.bzl

licenses(["notice"]) # Apache 2.0

package(
default_visibility = [
"//tensorflow_probability:__subpackages__",
],
)

licenses(["notice"]) # Apache 2.0

exports_files(["LICENSE"])

py_binary(
name = "bayesian_neural_network",
srcs = ["bayesian_neural_network.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [":bayesian_neural_network_lib"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":bayesian_neural_network_lib",
],
)

py_library(
name = "bayesian_neural_network_lib",
srcs = ["bayesian_neural_network.py"],
srcs_version = "PY2AND3",
srcs_version = "PY3",
deps = [
# absl/flags dep,
# matplotlib dep,
Expand All @@ -58,7 +58,7 @@ py_test(
"--max_steps=5",
],
main = "bayesian_neural_network.py",
srcs_version = "PY2AND3",
python_version = "PY3",
tags = [
"tf2-kokoro-broken",
# TODO(b/147689726) Re-enable this test after contrib references are
Expand All @@ -73,15 +73,17 @@ py_test(
py_binary(
name = "grammar_vae",
srcs = ["grammar_vae.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [":grammar_vae_lib"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":grammar_vae_lib",
],
)

py_library(
name = "grammar_vae_lib",
srcs = ["grammar_vae.py"],
srcs_version = "PY2AND3",
srcs_version = "PY3",
deps = [
# absl/flags dep,
# numpy dep,
Expand All @@ -102,7 +104,7 @@ py_test(
"--num_units=3",
],
main = "grammar_vae.py",
srcs_version = "PY2AND3",
python_version = "PY3",
deps = [
":grammar_vae_lib",
],
Expand All @@ -111,15 +113,17 @@ py_test(
py_binary(
name = "disentangled_vae",
srcs = ["disentangled_vae.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [":disentangled_vae_lib"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":disentangled_vae_lib",
],
)

py_library(
name = "disentangled_vae_lib",
srcs = ["disentangled_vae.py"],
srcs_version = "PY2AND3",
srcs_version = "PY3",
deps = [
":sprites_dataset",
# absl:app dep,
Expand Down Expand Up @@ -147,7 +151,7 @@ py_test(
"--enable_debug_logging",
],
main = "disentangled_vae.py",
srcs_version = "PY2AND3",
python_version = "PY3",
tags = [
"tf2-kokoro-broken",
# TODO(b/147689726) Re-enable this test after contrib references are
Expand All @@ -164,8 +168,8 @@ py_test(
size = "medium",
srcs = ["disentangled_vae_test.py"],
main = "disentangled_vae_test.py",
python_version = "PY3",
shard_count = 2,
srcs_version = "PY2AND3",
tags = ["tf2-kokoro-broken"],
deps = [
":disentangled_vae_lib",
Expand All @@ -178,15 +182,17 @@ py_test(
py_binary(
name = "deep_exponential_family",
srcs = ["deep_exponential_family.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [":deep_exponential_family_lib"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":deep_exponential_family_lib",
],
)

py_library(
name = "deep_exponential_family_lib",
srcs = ["deep_exponential_family.py"],
srcs_version = "PY2AND3",
srcs_version = "PY3",
deps = [
# absl/flags dep,
# numpy dep,
Expand All @@ -206,7 +212,7 @@ py_test(
"--layer_sizes=5,3,2",
],
main = "deep_exponential_family.py",
srcs_version = "PY2AND3",
python_version = "PY3",
tags = ["tf2-kokoro-broken"],
deps = [
":deep_exponential_family_lib",
Expand All @@ -216,15 +222,17 @@ py_test(
py_binary(
name = "latent_dirichlet_allocation_distributions",
srcs = ["latent_dirichlet_allocation_distributions.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [":latent_dirichlet_allocation_distributions_lib"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":latent_dirichlet_allocation_distributions_lib",
],
)

py_library(
name = "latent_dirichlet_allocation_distributions_lib",
srcs = ["latent_dirichlet_allocation_distributions.py"],
srcs_version = "PY2AND3",
srcs_version = "PY3",
deps = [
# absl/flags dep,
# numpy dep,
Expand All @@ -248,7 +256,7 @@ py_test(
"--learning_rate=1e-7",
],
main = "latent_dirichlet_allocation_distributions.py",
srcs_version = "PY2AND3",
python_version = "PY3",
deps = [
":latent_dirichlet_allocation_distributions_lib",
],
Expand All @@ -257,15 +265,17 @@ py_test(
py_binary(
name = "logistic_regression",
srcs = ["logistic_regression.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [":logistic_regression_lib"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":logistic_regression_lib",
],
)

py_library(
name = "logistic_regression_lib",
srcs = ["logistic_regression.py"],
srcs_version = "PY2AND3",
srcs_version = "PY3",
deps = [
# absl/flags dep,
# matplotlib dep,
Expand All @@ -287,7 +297,7 @@ py_test(
"--max_steps=50",
],
main = "logistic_regression.py",
srcs_version = "PY2AND3",
python_version = "PY3",
tags = ["tf2-kokoro-broken"],
deps = [
":logistic_regression_lib",
Expand All @@ -297,7 +307,7 @@ py_test(
py_library(
name = "sprites_dataset",
srcs = ["sprites_dataset.py"],
srcs_version = "PY2AND3",
srcs_version = "PY3",
deps = [
# absl/flags dep,
# six dep,
Expand All @@ -309,15 +319,17 @@ py_library(
py_binary(
name = "vae",
srcs = ["vae.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [":vae_lib"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":vae_lib",
],
)

py_library(
name = "vae_lib",
srcs = ["vae.py"],
srcs_version = "PY2AND3",
srcs_version = "PY3",
deps = [
# absl/flags dep,
# numpy dep,
Expand All @@ -340,7 +352,7 @@ py_test(
"--learning_rate=1e-7",
],
main = "vae.py",
srcs_version = "PY2AND3",
python_version = "PY3",
deps = [
":vae_lib",
],
Expand All @@ -349,15 +361,17 @@ py_test(
py_binary(
name = "vq_vae",
srcs = ["vq_vae.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [":vq_vae_lib"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":vq_vae_lib",
],
)

py_library(
name = "vq_vae_lib",
srcs = ["vq_vae.py"],
srcs_version = "PY2AND3",
srcs_version = "PY3",
deps = [
# absl/flags dep,
# matplotlib dep,
Expand All @@ -377,7 +391,7 @@ py_test(
"--base_depth=2",
],
main = "vq_vae.py",
srcs_version = "PY2AND3",
python_version = "PY3",
tags = [
"tf2-kokoro-broken",
# TODO(b/147689726) Re-enable this test after contrib references are
Expand All @@ -392,15 +406,17 @@ py_test(
py_binary(
name = "cifar10_bnn",
srcs = ["cifar10_bnn.py"],
python_version = "PY2",
srcs_version = "PY2AND3",
deps = [":cifar10_bnn_lib"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":cifar10_bnn_lib",
],
)

py_library(
name = "cifar10_bnn_lib",
srcs = ["cifar10_bnn.py"],
srcs_version = "PY2AND3",
srcs_version = "PY3",
deps = [
# absl/flags dep,
# matplotlib dep,
Expand All @@ -421,7 +437,7 @@ py_test(
"--batch_size=5",
],
main = "cifar10_bnn.py",
srcs_version = "PY2AND3",
python_version = "PY3",
tags = ["tf2-kokoro-broken"],
deps = [
":cifar10_bnn_lib",
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_probability/examples/grammar_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def main(argv):
variables = (probabilistic_grammar.variables
+ probabilistic_grammar_variational.variables)
grads = tape.gradient(loss, variables)
grads_and_vars = zip(grads, variables)
grads_and_vars = list(zip(grads, variables))
optimizer.apply_gradients(grads_and_vars, global_step)

if step % 500 == 0:
Expand Down
7 changes: 5 additions & 2 deletions tensorflow_probability/examples/models/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@
# Description:
# Models for CIFAR10 BNN example

licenses(["notice"]) # Apache 2.0

package(
default_visibility = [
"//tensorflow_probability:__subpackages__",
],
)

licenses(["notice"]) # Apache 2.0

exports_files(["LICENSE"])

py_library(
name = "models",
srcs = ["__init__.py"],
srcs_version = "PY3",
deps = [
":bayesian_resnet",
":bayesian_vgg",
Expand All @@ -37,6 +38,7 @@ py_library(
py_library(
name = "bayesian_resnet",
srcs = ["bayesian_resnet.py"],
srcs_version = "PY3",
deps = [
# tensorflow dep,
"//tensorflow_probability",
Expand All @@ -46,6 +48,7 @@ py_library(
py_library(
name = "bayesian_vgg",
srcs = ["bayesian_vgg.py"],
srcs_version = "PY3",
deps = [
# tensorflow dep,
"//tensorflow_probability",
Expand Down

0 comments on commit aac70e6

Please sign in to comment.