
Remove special handling of host-memory/device-memory for int32 arguments in type lists.

The current special case behavior is preserved only for functions and their gradients.
Change: 130100547
tensorflower-gardener committed Aug 12, 2016
1 parent 55b44e6 commit 84cefad
Showing 4 changed files with 27 additions and 42 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
@@ -7,6 +7,9 @@
   default, simply pass the argument `state_is_tuple=False`.
 * DeviceFactory's AddDevices and CreateDevices functions now return
   a Status instead of void.
+* Int32 elements of list(type) arguments are no longer placed in host memory by
+  default. If necessary, a list(type) argument to a kernel can be placed in host
+  memory using a HostMemory annotation.
 
 # Release 0.10.0
 
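
As a sketch of the HostMemory annotation mentioned in the release note above: a
kernel that still needs a list(type) argument in host memory must now request
that explicitly when it is registered. The op name "MyListOp", the kernel class
MyListOpKernel, and the argument names "values"/"output" are hypothetical; the
builder calls mirror the _ListToArray registration further down in this commit.

// Hypothetical registration sketch (would live next to the kernel definition,
// inside namespace tensorflow). After this change, int32 elements of a
// list(type) argument default to device memory, so host placement has to be
// requested with HostMemory().
REGISTER_KERNEL_BUILDER(Name("MyListOp")
                            .Device(DEVICE_GPU)
                            .HostMemory("values")  // list(type) input kept on the host
                            .HostMemory("output")  // list(type) output kept on the host
                            .TypeConstraint<int32>("T"),
                        MyListOpKernel);
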
42 changes: 7 additions & 35 deletions tensorflow/core/framework/memory_types.cc
@@ -61,29 +61,6 @@ MemoryType MTypeFromDType(const DataType dtype) {
   return (dtype == DT_INT32) ? HOST_MEMORY : DEVICE_MEMORY;
 }
 
-// Initialize the default memory types for type list arguments from the data
-// types. (The default can be overridden by an explicit HostMemory()
-// declaration.)
-Status SetTypeListMTypesFromDTypes(
-    const NameRangeMap& name_ranges,
-    const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
-    const DataTypeVector& dtypes, MemoryTypeVector* mtypes) {
-  for (const auto& a : args) {
-    if (!a.type_list_attr().empty()) {
-      auto it = name_ranges.find(a.name());
-      if (it == name_ranges.end()) {
-        return errors::InvalidArgument("Name range for argument ", a.name(),
-                                       " not found.");
-      }
-
-      for (int i = it->second.first; i < it->second.second; ++i) {
-        (*mtypes)[i] = MTypeFromDType(dtypes[i]);
-      }
-    }
-  }
-  return Status::OK();
-}
-
 }  // namespace
 
 Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
@@ -107,12 +84,13 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
   inp_mtypes->clear();
   out_mtypes->clear();
 
-  if (!status.ok()) {
-    // When there is no kernel def for this op, we can only best-effort derive
-    // the memory type from the data type. For now, we assume int32 is always
-    // on host memory and other types are always on device memory. We should
-    // do type inference over function body to derive the correct
-    // input/output memory types.
+  // For functions (which have no KernelDef) and their gradients, we can only
+  // best-effort derive the memory type from the data type. For now, we assume
+  // int32 is always on host memory and other types are always on device memory.
+  // TODO(zhifengc,phawkins): We should do type inference over function bodies
+  // to derive the correct input/output memory types. We should also split
+  // host-memory and non host-memory arguments into separate type lists.
+  if (!status.ok() || ndef.op() == "SymbolicGradient") {
     for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
     for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
     return Status::OK();
@@ -127,12 +105,6 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
   inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
   out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
 
-  // For type list arguments, mark int32 arguments as host memory.
-  TF_RETURN_IF_ERROR(SetTypeListMTypesFromDTypes(inp_names, op_def->input_arg(),
-                                                 inp_dtypes, inp_mtypes));
-  TF_RETURN_IF_ERROR(SetTypeListMTypesFromDTypes(
-      out_names, op_def->output_arg(), out_dtypes, out_mtypes));
-
   // Fills in host memory types based on the kernel def.
   const auto& from_proto = kdef->host_memory_arg();
   std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
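
The best-effort fallback that this file keeps for functions and SymbolicGradient
nodes boils down to a per-dtype mapping. A minimal standalone sketch of that
path (MTypeFromDType is copied from the code above; the wrapper name
BestEffortMemoryTypes is invented for illustration):

#include "tensorflow/core/framework/types.h"

namespace {

// int32 is assumed to live in host memory; every other dtype stays on device.
tensorflow::MemoryType MTypeFromDType(const tensorflow::DataType dtype) {
  return (dtype == tensorflow::DT_INT32) ? tensorflow::HOST_MEMORY
                                         : tensorflow::DEVICE_MEMORY;
}

// Derive memory types purely from data types, as done when no KernelDef is
// available (a function) or the node is a SymbolicGradient.
void BestEffortMemoryTypes(const tensorflow::DataTypeVector& inp_dtypes,
                           const tensorflow::DataTypeVector& out_dtypes,
                           tensorflow::MemoryTypeVector* inp_mtypes,
                           tensorflow::MemoryTypeVector* out_mtypes) {
  for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
  for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
}

}  // namespace
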
8 changes: 4 additions & 4 deletions tensorflow/core/framework/memory_types_test.cc
@@ -63,11 +63,11 @@ TEST(MemoryTypesForNode, Simple) {
   TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
                                   &input, &output));
   EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
-                              DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
-                              DEVICE_MEMORY, HOST_MEMORY}),
+                              DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
+                              DEVICE_MEMORY, DEVICE_MEMORY}),
             input);
   EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
-                              HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY}),
+                              DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
             output);
 
   TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
@@ -77,7 +77,7 @@
                         HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
       input);
   EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
-                              HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY}),
+                              DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
             output);
 }
 
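
The placement a node actually gets can be queried the same way the test does. A
usage sketch, assuming a registered op "ExampleOp" whose single int32 input is
fed from a producer node "x" (both names hypothetical):

#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status InspectMemoryTypes() {
  // Build a NodeDef for the (hypothetical) op "ExampleOp".
  tensorflow::NodeDef node_def;
  TF_RETURN_IF_ERROR(tensorflow::NodeDefBuilder("n", "ExampleOp")
                         .Input("x", 0, tensorflow::DT_INT32)
                         .Finalize(&node_def));

  // Ask which inputs/outputs the GPU kernel expects in host vs. device memory.
  tensorflow::MemoryTypeVector inp, out;
  TF_RETURN_IF_ERROR(tensorflow::MemoryTypesForNode(
      tensorflow::OpRegistry::Global(), tensorflow::DEVICE_GPU, node_def, &inp,
      &out));

  // With this commit, an int32 element of a list(type) argument reports
  // DEVICE_MEMORY here unless the kernel's registration pins it with
  // HostMemory(); host placement otherwise comes from the KernelDef's
  // host_memory_arg() entries.
  return tensorflow::Status::OK();
}
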
16 changes: 13 additions & 3 deletions tensorflow/core/kernels/function_ops.cc
@@ -109,12 +109,20 @@ REGISTER_KERNEL_BUILDER(Name("_Retval")
 
 class PassOn : public OpKernel {
  public:
-  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
-  void Compute(OpKernelContext* ctx) override {
+  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
     OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
                 errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
                                  " vs. ", ctx->num_outputs()));
     for (int i = 0; i < ctx->num_inputs(); ++i) {
+      OP_REQUIRES(
+          ctx, input_type(i) == output_type(i),
+          errors::Internal("Input and output types for position ", i,
+                           " do not match: ", DataTypeString(input_type(i)),
+                           " vs. ", DataTypeString(output_type(i))));
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    for (int i = 0; i < ctx->num_inputs(); ++i) {
       ctx->set_output(i, ctx->input(i));
     }
@@ -140,12 +148,14 @@ REGISTER_GPU_KERNELS(double);
 
 REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                             .Device(DEVICE_GPU)
+                            .HostMemory("input")
                             .HostMemory("output")
                             .TypeConstraint<int32>("T"),
                         PassOn);
 REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                             .Device(DEVICE_GPU)
+                            .HostMemory("input")
                             .HostMemory("output")
                             .TypeConstraint<int32>("T"),
                         PassOn);
 
