Skip to content

Commit aa103f3

Browse files
edlopertensorflower-gardener
authored andcommitted
Improve the error message if a non-hashable Python object that is not supported by tf.nest (such as a set) is passed to a tf.function.
PiperOrigin-RevId: 307085105 Change-Id: I044debafd1eeb509c189cf2160c2deb2221a2228
1 parent 1dd42cf commit aa103f3

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

tensorflow/python/eager/function.py

+10
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,16 @@ def _make_input_signature_hashable(elem, variable_map=None):
113113
except TypeError:
114114
assert isinstance(elem, weakref.ReferenceType)
115115
v = elem()
116+
117+
# Check if v is a Variable. Note that we can't use isinstance to check if
118+
# it's a variable, since not all variable types are subclass of Variable.
119+
# TODO(mdan) Update this to use a generic "Variable" superclass once we
120+
# create one.
121+
if not (hasattr(v, "shape") and hasattr(v, "dtype")):
122+
raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
123+
"or hashable Python objects (or nested structures of "
124+
"these types).\nGot type: %s" % type(v).__name__)
125+
116126
idx = variable_map.get(id(v))
117127
if idx is None:
118128
idx = len(variable_map)

tensorflow/python/eager/function_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def testNoHash(self):
442442
def f(_):
443443
return 1.0
444444

445-
with self.assertRaisesRegexp(AttributeError, 'set'):
445+
with self.assertRaisesRegexp(ValueError, r'Got type: set'):
446446
f(set([]))
447447

448448
def testFuncName(self):

0 commit comments

Comments
 (0)