Skip to content

Commit 69492ad

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
remove tuple logic in constant propagation (pytorch#31840)
Summary: Pull Request resolved: pytorch#31840 The next PR in this stack makes tuples insertable as constants, so we can remove special handling of tuples in constant propagation. Test Plan: Imported from OSS Differential Revision: D19439515 Pulled By: eellison fbshipit-source-id: c58f153157f1d4eee4c1242decc4f36e41c1aa05
1 parent b01d824 commit 69492ad

File tree

3 files changed

+1
-140
lines changed

3 files changed

+1
-140
lines changed

test/cpp/jit/test_constant_propagation.cpp

Lines changed: 0 additions & 73 deletions
This file was deleted.

test/cpp/jit/tests.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ namespace jit {
4242
_(MemoryDAG) \
4343
_(IRParser) \
4444
_(ConstantPooling) \
45-
_(ConstantPropagation) \
4645
_(NetDefConverter) \
4746
_(THNNConv) \
4847
_(ATenNativeBatchNorm) \

torch/csrc/jit/passes/constant_propagation.cpp

Lines changed: 1 addition & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,6 @@ std::unordered_set<Symbol> skip_list = {
4848
// where the constant tensor would be large but cheap to create.
4949
};
5050

51-
std::unordered_set<Symbol> tuple_ops = {
52-
prim::TupleSlice,
53-
prim::TupleIndex,
54-
prim::TupleUnpack,
55-
prim::TupleConstruct,
56-
};
57-
5851
struct ConstantPropagator {
5952
// Runs constant propagation with an aliasing db and checks if inputs or
6053
// outputs might be mutated in the graph
@@ -82,20 +75,11 @@ struct ConstantPropagator {
8275
}
8376
}
8477

85-
void pushIValue(Value* v, Stack& stack) {
86-
if (tuples.count(v)) {
87-
const auto& ival = tuples[v];
88-
stack.push_back(ival);
89-
} else {
90-
stack.push_back(*toIValue(v));
91-
}
92-
}
93-
9478
std::vector<IValue> runNode(Node* n) {
9579
auto op = getOperation(n);
9680
Stack stack;
9781
for (auto input : n->inputs()) {
98-
pushIValue(input, stack);
82+
stack.push_back(*toIValue(input));
9983
}
10084
op(stack);
10185
auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
@@ -117,34 +101,6 @@ struct ConstantPropagator {
117101
return var_outputs;
118102
}
119103

120-
// Tuples are not representable as constants, however
121-
// we can try to insert each tuple element and then create a TupleConstruct
122-
// from the elements
123-
Value* tryInsertTuple(const IValue& tuple, Value* tuple_to_replace) {
124-
auto type = tuple_to_replace->type();
125-
TupleTypePtr tup_type;
126-
if (auto opt = type->cast<OptionalType>()) {
127-
tup_type = opt->getElementType()->expect<TupleType>();
128-
} else {
129-
tup_type = type->expect<TupleType>();
130-
}
131-
auto type_elements = tup_type->elements();
132-
const auto& tuple_elements = tuple.toTuple()->elements();
133-
std::vector<Value*> inputs;
134-
for (size_t i = 0; i < type_elements.size(); ++i) {
135-
auto inp = tryInsertConstant(*graph_, tuple_elements[i]);
136-
if (inp) {
137-
inputs.push_back(*inp);
138-
} else {
139-
return nullptr;
140-
}
141-
}
142-
auto new_tuple = graph_->insertNode(graph_->createTuple(inputs));
143-
tuple_to_replace->replaceAllUsesWith(new_tuple->output());
144-
new_tuple->output()->copyMetadata(tuple_to_replace);
145-
return new_tuple->output();
146-
}
147-
148104
void propagateNode(Node* n) {
149105
std::vector<IValue> outputs;
150106
try {
@@ -168,19 +124,6 @@ struct ConstantPropagator {
168124
(*new_output)->setType(n->outputs()[i]->type());
169125
}
170126
n->outputs()[i]->replaceAllUsesWith(*new_output);
171-
} else if (outputs[i].isTuple()) {
172-
// we save the new Tuple ivalue in case it is used in an op that
173-
// forwards tuples later in the graph, such as a Tuple index
174-
auto tuple_val = n->outputs()[i];
175-
if (auto new_tup = tryInsertTuple(outputs[i], tuple_val)) {
176-
GRAPH_UPDATE(
177-
"Folding tuple %",
178-
n->outputs()[i]->debugName(),
179-
" with ",
180-
getHeader(new_tup->node()));
181-
tuple_val = new_tup;
182-
}
183-
tuples[tuple_val] = std::move(outputs[i]);
184127
}
185128
// If we cannot insert the IValue as a constant, give up replacing the
186129
// node and let DCE remove it
@@ -322,12 +265,6 @@ struct ConstantPropagator {
322265
})) {
323266
return true;
324267
}
325-
if (tuple_ops.count(n->kind())) {
326-
return (
327-
std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
328-
return v->node()->kind() == prim::Constant || tuples.count(v);
329-
}));
330-
}
331268
return false;
332269
};
333270

@@ -390,8 +327,6 @@ struct ConstantPropagator {
390327

391328
std::shared_ptr<Graph> graph_;
392329
std::unique_ptr<AliasDb> aliasDb_;
393-
// these are tuples which we know the computed IValue for
394-
std::unordered_map<Value*, IValue> tuples;
395330
};
396331
} // anonymous namespace
397332

0 commit comments

Comments
 (0)