Skip to content

Commit

Permalink
Expose a test related internal API for keras.
Browse files Browse the repository at this point in the history
The layer_test will be used by keras-cv/nlp and tf-addons.

PiperOrigin-RevId: 427277621
  • Loading branch information
qlzh727 authored and tensorflower-gardener committed Feb 8, 2022
1 parent 373ad97 commit 08872b3
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ py_library(
"//keras/preprocessing",
"//keras/saving",
"//keras/testing_infra:keras_doctest_lib",
"//keras/testing_infra:test_utils", # For keras.__internal__ API
"//keras/utils",
"//keras/wrappers",
],
Expand Down
1 change: 1 addition & 0 deletions keras/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ keras_packages = [
"keras.saving.model_config",
"keras.saving.save",
"keras.saving.saved_model_experimental",
"keras.testing_infra.test_utils",
"keras.utils.data_utils",
"keras.utils.generic_utils",
"keras.utils.io_utils",
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.__internal__.utils.pbtxt
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
path: "tensorflow.keras.__internal__.utils"
tf_module {
member_method {
name: "layer_test"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "register_symbolic_tensor_type"
argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None"
Expand Down
2 changes: 2 additions & 0 deletions keras/testing_infra/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow.python.framework import test_util as tf_test_utils # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util.tf_export import keras_export # pylint: disable=g-direct-tensorflow-import


def string_test(actual, expected):
Expand Down Expand Up @@ -76,6 +77,7 @@ def get_test_data(train_samples,
(x[train_samples:], y[train_samples:]))


@keras_export('keras.__internal__.utils.layer_test', v1=[])
@tf_test_utils.disable_cudnn_autotune
def layer_test(layer_cls,
kwargs=None,
Expand Down

0 comments on commit 08872b3

Please sign in to comment.