Skip to content

Commit

Permalink
remove dead code, add insertAt helper
Browse files Browse the repository at this point in the history
  • Loading branch information
zdevito authored and soumith committed Sep 20, 2017
1 parent 6e495f5 commit 2996aad
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
2 changes: 1 addition & 1 deletion torch/csrc/jit/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ void Graph::lint() const {
}
}

void Graph::dump() const {
void Graph::dump() {
std::cout << *this << "\n";
}

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ friend struct Node;
// Checks well-formedness and invariants of graph
void lint() const;
// for use in debugger
void dump() const;
void dump();

~Graph() {
for (const Node * n : all_nodes)
Expand Down
19 changes: 9 additions & 10 deletions torch/csrc/jit/passes/graph_fuser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@ struct GraphFuser {
topological_index[n] = topological_index[after];
}

void insertAt(Node ** insertion_point, Node * n) {
insertAfter(n, *insertion_point);
*insertion_point = n;
}

Node * fuse(Node * consumer, Node * producer) {
auto group = consumer;
if(group->kind() != kFusionGroup) {
Expand Down Expand Up @@ -238,7 +243,7 @@ struct GraphFuser {

// Make sure we lay out the nodes in the correct topological order.
// TODO: There should be some more enshrined way to do this
Node *last_node = chunk;
Node * insertion_point = chunk;

// apply chunk to each of op's operands
// chunked_inputs[input_nr][chunk_output_idx]
Expand All @@ -253,8 +258,7 @@ struct GraphFuser {
input_chunk->setType(multiType());
input_chunk->copyAttributes(*chunk);
input_chunk->addInput(input);
insertAfter(input_chunk, last_node);
last_node = input_chunk;
insertAt(&insertion_point, input_chunk);
// TODO: Make this go away when we make helper function for
// setting up Selects.
size_t i = 0;
Expand All @@ -265,27 +269,22 @@ struct GraphFuser {
input_chunk_sel->setType(
input_type->withSizesStrides(chunk_sel_type->sizes(),
chunk_sel_type->strides()));
insertAfter(input_chunk_sel, last_node);
last_node = input_chunk_sel;
insertAt(&insertion_point, input_chunk_sel);
chunked_inputs.back().push_back(input_chunk_sel);
}
}

// apply the op to each chunk of the chunked operands,
// and then rewrite the graph to use them!
// NB: as we replace/remove the selects the use list changes, so copy it first
auto chunk_outputs = chunk->outputs();
for (auto chunk_sel : chunk->outputs()) {
auto chunk_sel_type = chunk_sel->type()->cast<TensorType>();
Node * chunked_op = graph->create(producer_for_chunk->kind());
chunked_op->copyAttributes(*producer_for_chunk);
// Invariant: mappable operators always produce contiguous output
chunked_op->setType(chunk_sel->type()->cast<TensorType>()->contiguous());
for (auto by_chunk_output_idx : chunked_inputs) {
chunked_op->addInput(by_chunk_output_idx.at(chunk_sel->offset()));
}
insertAfter(chunked_op, last_node);
last_node = chunked_op;
insertAt(&insertion_point, chunked_op);
chunk_sel->replaceAllUsesWith(chunked_op);
// NB: Temporarily breaking the Select invariant as we clean up
chunk_sel->destroy();
Expand Down

0 comments on commit 2996aad

Please sign in to comment.