Skip to content

Commit

Permalink
Add attribute setting and getting support to TF_Function
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 169337159
  • Loading branch information
iganichev authored and tensorflower-gardener committed Sep 20, 2017
1 parent ed89a2b commit 7ad8e25
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 5 deletions.
18 changes: 18 additions & 0 deletions tensorflow/c/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,24 @@ TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func,
TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef(
const TF_Buffer* func_def, TF_Status* status);

// Sets function attribute named `attr_name` to value stored in `proto`.
// If this attribute is already set to another value, it is overriden.
// `proto` should point to a sequence of bytes of length `proto_len`
// representing a binary serialization of an AttrValue protocol
// buffer.
TF_CAPI_EXPORT extern void TF_FunctionSetAttrValueProto(TF_Function* func,
const char* attr_name,
const void* proto,
size_t proto_len,
TF_Status* status);

// Sets `output_attr_value` to the binary-serialized AttrValue proto
// representation of the value of the `attr_name` attr of `func`.
// If `attr_name` attribute is not present, status is set to an error.
TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto(
TF_Function* func, const char* attr_name, TF_Buffer* output_attr_value,
TF_Status* status);

// Frees the memory used by the `func` struct.
// TF_DeleteFunction is a noop if `func` is null.
// Deleting a function does not remove it from any graphs it was copied to.
Expand Down
27 changes: 27 additions & 0 deletions tensorflow/c/c_api_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -545,4 +545,31 @@ TF_Function* TF_FunctionImportFunctionDef(const TF_Buffer* func_def,
return func;
}

void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
const void* proto, size_t proto_len,
TF_Status* status) {
tensorflow::AttrValue attr_value;
if (!attr_value.ParseFromArray(proto, proto_len)) {
status->status = InvalidArgument(
"Unparseable AttrValue proto passed to "
"TF_FunctionSetAttrValueProto");
return;
}
(*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
status->status = tensorflow::Status::OK();
}

void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
TF_Buffer* output_attr_value,
TF_Status* status) {
const auto& it = func->fdef.attr().find(attr_name);
if (it == func->fdef.attr().end()) {
status->status =
InvalidArgument("Function '", func->fdef.signature().name(),
"' has no attr named '", attr_name, "'.");
return;
}
status->status = MessageToBuffer(it->second, output_attr_value);
}

void TF_DeleteFunction(TF_Function* func) { delete func; }
39 changes: 39 additions & 0 deletions tensorflow/c/c_api_function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,13 @@ class CApiFunctionTest : public ::testing::Test {
TF_DeleteBuffer(buf);
}

void GetAttr(const char* attr_name, AttrValue* out_attr) {
TF_Buffer* attr_buf = TF_NewBuffer();
TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_);
ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length));
TF_DeleteBuffer(attr_buf);
}

const char* func_name_ = "MyFunc";
const char* func_node_name_ = "MyFunc_0";
TF_Status* s_;
Expand Down Expand Up @@ -1406,5 +1413,37 @@ TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
string(TF_Message(s_)));
}

TEST_F(CApiFunctionTest, Attribute) {
DefineFunction(func_name_, &func_);

// Get non existent attribute
TF_Buffer* attr_buf = TF_NewBuffer();
TF_FunctionGetAttrValueProto(func_, "foo_attr", attr_buf, s_);
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."),
string(TF_Message(s_)));
TF_DeleteBuffer(attr_buf);

// Set attr
tensorflow::AttrValue attr;
attr.set_s("test_attr_value");
string bytes;
attr.SerializeToString(&bytes);
TF_FunctionSetAttrValueProto(func_, "test_attr_name", bytes.data(),
bytes.size(), s_);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

// Get attr
AttrValue read_attr;
GetAttr("test_attr_name", &read_attr);
ASSERT_EQ(attr.DebugString(), read_attr.DebugString());

// Retrieve the same attr after save/restore
Reincarnate();
AttrValue read_attr2;
GetAttr("test_attr_name", &read_attr2);
ASSERT_EQ(attr.DebugString(), read_attr2.DebugString());
}

} // namespace
} // namespace tensorflow
17 changes: 17 additions & 0 deletions tensorflow/python/framework/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,25 @@ def _create_definition_if_needed_impl(self):
output_names,
None, # opts
status)
self._set_c_attrs(kwargs_attr)
# pylint: enable=protected-access

def _set_c_attrs(self, attrs):
"""Sets `attrs` as attributes of self._c_func.
Requires that self._c_func is not None.
Args:
attrs: a dictionary from attribute name to attribute proto value
"""
for name, attr_value in attrs.items():
serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status.
with errors.raise_exception_on_not_ok_status() as status:
c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
serialized, status)

def _create_hash_str(self, input_arg, output_arg, node_def):
"""Creates an 8-character string unique to this input.
Expand Down
5 changes: 0 additions & 5 deletions tensorflow/python/framework/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,6 @@ def Forward(x):
out, = sess.run(dx, feed)
self.assertAllClose(1 - np.square(np.tanh(inp)), out)

# C API functions don't support all optimizer options on cuda yet
@test_util.skip_if(test_util.c_api_and_cuda_enabled)
def testCustomGradient(self):
dtype = dtypes.float32

Expand Down Expand Up @@ -285,9 +283,6 @@ def testSymGradShape(self):
self.assertEqual(x.get_shape(), dx.get_shape())
self.assertEqual(y.get_shape(), dy.get_shape())

# C API functions don't support attributes yet (i.e. noinline).
# This attribute is required to run sucessfully with cuda.
@test_util.skip_if(test_util.c_api_and_cuda_enabled)
def testSymGradAttr(self):

@function.Defun(noinline=True)
Expand Down

0 comments on commit 7ad8e25

Please sign in to comment.