diff --git a/include/taco/format.h b/include/taco/format.h
index 7a46f2410..81bdadda4 100644
--- a/include/taco/format.h
+++ b/include/taco/format.h
@@ -97,7 +97,8 @@ class ModeFormat {
   /// Properties of a mode format
   enum Property {
     FULL, NOT_FULL, ORDERED, NOT_ORDERED, UNIQUE, NOT_UNIQUE, BRANCHLESS,
-    NOT_BRANCHLESS, COMPACT, NOT_COMPACT, ZEROLESS, NOT_ZEROLESS
+    NOT_BRANCHLESS, COMPACT, NOT_COMPACT, ZEROLESS, NOT_ZEROLESS, PADDED,
+    NOT_PADDED
   };
 
   /// Instantiates an undefined mode format
@@ -129,6 +130,7 @@ class ModeFormat {
   bool isBranchless() const;
   bool isCompact() const;
   bool isZeroless() const;
+  bool isPadded() const;
 
   /// Returns true if a mode format has a specific capability, false otherwise
   bool hasCoordValIter() const;
diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h
index f7c7fad37..7923d22ca 100644
--- a/include/taco/index_notation/index_notation.h
+++ b/include/taco/index_notation/index_notation.h
@@ -637,6 +637,23 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
   /// reorder takes a new ordering for a set of index variables that are directly nested in the iteration order
   IndexStmt reorder(std::vector<IndexVar> reorderedvars) const;
 
+  /// The mergeby transformation specifies how to merge the iterators of the
+  /// given index variable. By default, an iterator that is used for windowing
+  /// is merged with the "gallop" strategy, and all other iterators are merged
+  /// with the "two finger" strategy. The two finger strategy merges by
+  /// advancing each iterator one position at a time, while the gallop
+  /// strategy advances iterators by exponential search.
+  ///
+  /// Preconditions:
+  /// This command applies only to variables that involve sparse iterators;
+  /// it is a no-op if the variable involves any dense iterators.
+  /// Any variable can be merged with the two finger strategy, whereas gallop
+  /// only applies to a variable whose merge lattice has a single point
+  /// (i.e. an intersection). For example, a variable that involves only
+  /// multiplications can be merged with gallop.
+  /// Furthermore, all iterators must be ordered for gallop to apply.
+  IndexStmt mergeby(IndexVar i, MergeStrategy strategy) const;
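An illustrative use of the new primitive (editor's sketch, not part of the patch; the tensor names, the dimension n, and the format are hypothetical):

    Tensor<double> a("a", {n}, Format({Sparse}));
    Tensor<double> b("b", {n}, Format({Sparse}));
    Tensor<double> c("c", {n}, Format({Sparse}));
    IndexVar i("i");
    a(i) = b(i) * c(i);                             // pure intersection: one lattice point
    IndexStmt stmt = a.getAssignment().concretize();
    stmt = stmt.mergeby(i, MergeStrategy::Gallop);  // gallop instead of the two-finger default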
 
   /// The parallelize
   /// transformation tags an index variable for parallel execution. The
   /// transformation takes as an argument the type of parallel hardware
@@ -835,13 +852,14 @@ class Forall : public IndexStmt {
   Forall() = default;
   Forall(const ForallNode*);
   Forall(IndexVar indexVar, IndexStmt stmt);
-  Forall(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
+  Forall(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
 
   IndexVar getIndexVar() const;
   IndexStmt getStmt() const;
 
   ParallelUnit getParallelUnit() const;
   OutputRaceStrategy getOutputRaceStrategy() const;
+  MergeStrategy getMergeStrategy() const;
 
   size_t getUnrollFactor() const;
 
@@ -850,7 +868,7 @@ class Forall : public IndexStmt {
 
 /// Create a forall index statement.
 Forall forall(IndexVar i, IndexStmt stmt);
-Forall forall(IndexVar i, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
+Forall forall(IndexVar i, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
 
 /// A where statement has a producer statement that binds a tensor variable in
diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h
index 5289ff069..0feea404e 100644
--- a/include/taco/index_notation/index_notation_nodes.h
+++ b/include/taco/index_notation/index_notation_nodes.h
@@ -398,8 +398,8 @@ struct YieldNode : public IndexStmtNode {
 };
 
 struct ForallNode : public IndexStmtNode {
-  ForallNode(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0)
-      : indexVar(indexVar), stmt(stmt), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {}
+  ForallNode(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0)
+      : indexVar(indexVar), stmt(stmt), merge_strategy(merge_strategy), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {}
 
   void accept(IndexStmtVisitorStrict* v) const {
     v->visit(this);
@@ -407,6 +407,7 @@ struct ForallNode : public IndexStmtNode {
 
   IndexVar indexVar;
   IndexStmt stmt;
+  MergeStrategy merge_strategy;
   ParallelUnit parallel_unit;
   OutputRaceStrategy output_race_strategy;
   size_t unrollFactor = 0;
diff --git a/include/taco/index_notation/transformations.h b/include/taco/index_notation/transformations.h
index f898c92b9..b750e3961 100644
--- a/include/taco/index_notation/transformations.h
+++ b/include/taco/index_notation/transformations.h
@@ -22,6 +22,7 @@ class AddSuchThatPredicates;
 class Parallelize;
 class TopoReorder;
 class SetAssembleStrategy;
+class SetMergeStrategy;
 
 /// A transformation is an optimization that transforms a statement in the
 /// concrete index notation into a new statement that computes the same result
@@ -36,6 +37,7 @@ class Transformation {
   Transformation(TopoReorder);
   Transformation(AddSuchThatPredicates);
   Transformation(SetAssembleStrategy);
+  Transformation(SetMergeStrategy);
 
   IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
 
@@ -206,6 +208,25 @@ class SetAssembleStrategy : public TransformationInterface {
 /// Print a SetAssembleStrategy command.
 std::ostream &operator<<(std::ostream &, const SetAssembleStrategy&);
 
+class SetMergeStrategy : public TransformationInterface {
+public:
+  SetMergeStrategy(IndexVar i, MergeStrategy strategy);
+
+  IndexVar geti() const;
+  MergeStrategy getMergeStrategy() const;
+
+  IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;
+
+  void print(std::ostream &os) const;
+
+private:
+  struct Content;
+  std::shared_ptr<Content> content;
+};
+
+/// Print a SetMergeStrategy command.
+std::ostream &operator<<(std::ostream &, const SetMergeStrategy&);
+
 // Autoscheduling functions
 
 /**
diff --git a/include/taco/ir_tags.h b/include/taco/ir_tags.h
index 5858a13e3..be5d8f6bb 100644
--- a/include/taco/ir_tags.h
+++ b/include/taco/ir_tags.h
@@ -33,6 +33,13 @@ enum class AssembleStrategy {
 };
 extern const char *AssembleStrategy_NAMES[];
 
+/// MergeStrategy::TwoFinger merges iterators by incrementing one at a time
+/// MergeStrategy::Gallop merges iterators by exponential search (galloping)
+enum class MergeStrategy {
+  TwoFinger, Gallop
+};
+extern const char *MergeStrategy_NAMES[];
+
 }
 
 #endif //TACO_IR_TAGS_H
diff --git a/include/taco/lower/lowerer_impl_imperative.h b/include/taco/lower/lowerer_impl_imperative.h
index 4498b37f0..fa97e3cd9 100644
--- a/include/taco/lower/lowerer_impl_imperative.h
+++ b/include/taco/lower/lowerer_impl_imperative.h
@@ -146,15 +146,18 @@ class LowererImplImperative : public LowererImpl {
    * \param statement
    *    A concrete index notation statement to compute at the points in the
    *    sparse iteration space described by the merge lattice.
+   * \param mergeStrategy
+   *    A strategy for merging iterators. One of TwoFinger or Gallop.
    *
    * \return
    *    IR code to compute the forall loop.
    */
   virtual ir::Stmt lowerMergeLattice(MergeLattice lattice, IndexVar coordinateVar,
                                      IndexStmt statement,
-                                     const std::set<Access>& reducedAccesses);
+                                     const std::set<Access>& reducedAccesses,
+                                     MergeStrategy mergeStrategy);
 
-  virtual ir::Stmt resolveCoordinate(std::vector<Iterator> mergers, ir::Expr coordinate, bool emitVarDecl);
+  virtual ir::Stmt resolveCoordinate(std::vector<Iterator> mergers, ir::Expr coordinate, bool emitVarDecl, bool mergeWithMax);
 
   /**
    * Lower the merge point at the top of the given lattice to code that iterates
@@ -169,15 +172,20 @@ class LowererImplImperative : public LowererImpl {
    *    coordinate the merge point is at.
    *    A concrete index notation statement to compute at the points in the
    *    sparse iteration space region described by the merge point.
+   * \param mergeWithMax
+   *    A boolean indicating whether coordinates should be combined with MAX instead of MIN.
+   *    MAX is needed when the iterators are merged with the Gallop strategy.
    */
   virtual ir::Stmt lowerMergePoint(MergeLattice pointLattice,
                                    ir::Expr coordinate, IndexVar coordinateVar,
                                    IndexStmt statement,
-                                   const std::set<Access>& reducedAccesses, bool resolvedCoordDeclared);
+                                   const std::set<Access>& reducedAccesses, bool resolvedCoordDeclared,
+                                   MergeStrategy mergeStrategy);
 
   /// Lower a merge lattice to cases.
   virtual ir::Stmt lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar,
                                    IndexStmt stmt, MergeLattice lattice,
-                                   const std::set<Access>& reducedAccesses);
+                                   const std::set<Access>& reducedAccesses,
+                                   MergeStrategy mergeStrategy);
 
   /// Lower a forall loop body.
   virtual ir::Stmt lowerForallBody(ir::Expr coordinate, IndexStmt stmt,
@@ -185,7 +193,8 @@ class LowererImplImperative : public LowererImpl {
                                    std::vector<Iterator> inserters,
                                    std::vector<Iterator> appenders,
                                    MergeLattice caseLattice,
-                                   const std::set<Access>& reducedAccesses);
+                                   const std::set<Access>& reducedAccesses,
+                                   MergeStrategy mergeStrategy);
 
 
   /// Lower a where statement.
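For intuition, roughly the shape of the merge loops this lowering produces for a two-way sparse intersection (editor's sketch; the variable names and the exact emitted structure are illustrative, not taken from the generator):

    /* Two-finger: resolve the candidate coordinate with MIN and advance
       each matching iterator by one position. */
    while (pB < B1_pos[1] && pC < C1_pos[1]) {
      int iB = B1_crd[pB];
      int iC = C1_crd[pC];
      int i = TACO_MIN(iB, iC);
      if (iB == i && iC == i) {
        a_vals[i] = B_vals[pB] * C_vals[pC];
      }
      pB += (int)(iB == i);
      pC += (int)(iC == i);
    }

    /* Gallop: resolve the coordinate with MAX and jump each lagging
       iterator forward with taco_gallop (exponential search). */
    while (pB < B1_pos[1] && pC < C1_pos[1]) {
      int iB = B1_crd[pB];
      int iC = C1_crd[pC];
      int i = TACO_MAX(iB, iC);
      if (iB == i && iC == i) {
        a_vals[i] = B_vals[pB] * C_vals[pC];
        pB++;
        pC++;
      }
      else {
        pB = taco_gallop(B1_crd, pB, B1_pos[1], i);
        pC = taco_gallop(C1_crd, pC, C1_pos[1], i);
      }
    }

This is why lowerMergePoint takes mergeWithMax: under galloping, every iterator is at a position whose coordinate is at least the previously resolved one, so the next candidate is the maximum rather than the minimum.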
@@ -375,7 +384,7 @@ class LowererImplImperative : public LowererImpl {
   /// Conditionally increment iterator position variables.
   ir::Stmt codeToIncIteratorVars(ir::Expr coordinate, IndexVar coordinateVar,
-                                 std::vector<Iterator> iterators, std::vector<Iterator> mergers);
+                                 std::vector<Iterator> iterators, std::vector<Iterator> mergers, MergeStrategy strategy);
 
   ir::Stmt codeToLoadCoordinatesFromPosIterators(std::vector<Iterator> iterators, bool declVars);
 
@@ -410,7 +419,8 @@ class LowererImplImperative : public LowererImpl {
   /// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt.
   /// Will emit checks for explicit zeros for each mode iterator and each locator in the lattice.
   ir::Stmt lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt,
-                                                 MergeLattice lattice, const std::set<Access>& reducedAccesses);
+                                                 MergeLattice lattice, const std::set<Access>& reducedAccesses,
+                                                 MergeStrategy mergeStrategy);
 
   /// Constructs cases comparing the coordVar for each iterator to the resolved coordinate.
   /// Returns a vector where coordComparisons[i] corresponds to a case for iters[i]
@@ -444,7 +454,7 @@ class LowererImplImperative : public LowererImpl {
   /// The map must be of iterators to exprs of boolean types
   std::vector<ir::Stmt> lowerCasesFromMap(std::map<Iterator, ir::Expr> iteratorToCondition,
                                           ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice,
-                                          const std::set<Access>& reducedAccesses);
+                                          const std::set<Access>& reducedAccesses, MergeStrategy mergeStrategy);
 
   /// Constructs an expression which checks if this access is "zero"
   ir::Expr constructCheckForAccessZero(Access);
diff --git a/include/taco/lower/mode_format_impl.h b/include/taco/lower/mode_format_impl.h
index 3e2cbad66..824802286 100644
--- a/include/taco/lower/mode_format_impl.h
+++ b/include/taco/lower/mode_format_impl.h
@@ -106,9 +106,10 @@ class ModeFormatImpl {
 public:
   ModeFormatImpl(std::string name, bool isFull, bool isOrdered, bool isUnique,
                  bool isBranchless, bool isCompact, bool isZeroless,
-                 bool hasCoordValIter, bool hasCoordPosIter, bool hasLocate,
-                 bool hasInsert, bool hasAppend, bool hasSeqInsertEdge,
-                 bool hasInsertCoord, bool isYieldPosPure);
+                 bool isPadded, bool hasCoordValIter, bool hasCoordPosIter,
+                 bool hasLocate, bool hasInsert, bool hasAppend,
+                 bool hasSeqInsertEdge, bool hasInsertCoord,
+                 bool isYieldPosPure);
 
   virtual ~ModeFormatImpl();
 
@@ -246,6 +247,7 @@ class ModeFormatImpl {
   const bool isBranchless;
   const bool isCompact;
   const bool isZeroless;
+  const bool isPadded;
 
   const bool hasCoordValIter;
   const bool hasCoordPosIter;
diff --git a/include/taco/lower/mode_format_singleton.h b/include/taco/lower/mode_format_singleton.h
index 65bfac2c0..1be6ce53a 100644
--- a/include/taco/lower/mode_format_singleton.h
+++ b/include/taco/lower/mode_format_singleton.h
@@ -10,8 +10,9 @@ class SingletonModeFormat : public ModeFormatImpl {
   using ModeFormatImpl::getInsertCoord;
 
   SingletonModeFormat();
-  SingletonModeFormat(bool isFull, bool isOrdered,
-                      bool isUnique, bool isZeroless, long long allocSize = DEFAULT_ALLOC_SIZE);
+  SingletonModeFormat(bool isFull, bool isOrdered, bool isUnique,
+                      bool isZeroless, bool isPadded,
+                      long long allocSize = DEFAULT_ALLOC_SIZE);
 
   ~SingletonModeFormat() override {}
 
diff --git a/src/codegen/codegen_c.cpp b/src/codegen/codegen_c.cpp
index b57eb51bd..d53e3b06c 100644
--- a/src/codegen/codegen_c.cpp
+++ b/src/codegen/codegen_c.cpp
@@ -62,6 +62,28 @@ const string cHeaders =
   "int cmp(const void *a, const void *b) {\n"
   "  return *((const int*)a) - *((const int*)b);\n"
   "}\n"
+  // Returns the first position in [arrayStart, arrayEnd] at which
+  // array[position] >= target (arrayEnd if there is none), found with an
+  // exponential search: https://en.wikipedia.org/wiki/Exponential_search.
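A worked example of the search implemented below (editor's note, not part of the patch):

    /* array = {1, 3, 5, 8, 13, 21, 34}, arrayStart = 0, arrayEnd = 7, target = 13.
       array[0] = 1 < 13, so the doubling phase runs: it probes array[1] = 3 and
       array[3] = 8 (both < 13), then stops because index 7 is out of range,
       leaving curr = 3 and step = 4. The refinement phase halves step: step 2
       probes array[5] = 21 (not < 13), step 1 probes array[4] = 13 (not < 13),
       so curr stays 3 and the function returns curr + 1 = 4 -- the first
       position whose value is >= 13. */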
+ "int taco_gallop(int *array, int arrayStart, int arrayEnd, int target) {\n" + " if (array[arrayStart] >= target || arrayStart >= arrayEnd) {\n" + " return arrayStart;\n" + " }\n" + " int step = 1;\n" + " int curr = arrayStart;\n" + " while (curr + step < arrayEnd && array[curr + step] < target) {\n" + " curr += step;\n" + " step = step * 2;\n" + " }\n" + "\n" + " step = step / 2;\n" + " while (step > 0) {\n" + " if (curr + step < arrayEnd && array[curr + step] < target) {\n" + " curr += step;\n" + " }\n" + " step = step / 2;\n" + " }\n" + " return curr+1;\n" + "}\n" "int taco_binarySearchAfter(int *array, int arrayStart, int arrayEnd, int target) {\n" " if (array[arrayStart] >= target) {\n" " return arrayStart;\n" diff --git a/src/format.cpp b/src/format.cpp index a5144ae24..0424a1a52 100644 --- a/src/format.cpp +++ b/src/format.cpp @@ -187,6 +187,11 @@ bool ModeFormat::hasProperties(const std::vector& properties) const { return false; } break; + case PADDED: + if (!isPadded()) { + return false; + } + break; case NOT_FULL: if (isFull()) { return false; @@ -217,6 +222,11 @@ bool ModeFormat::hasProperties(const std::vector& properties) const { return false; } break; + case NOT_PADDED: + if (isPadded()) { + return false; + } + break; } } return true; @@ -252,6 +262,11 @@ bool ModeFormat::isZeroless() const { return impl->isZeroless; } +bool ModeFormat::isPadded() const { + taco_iassert(defined()); + return impl->isPadded; +} + bool ModeFormat::hasCoordValIter() const { taco_iassert(defined()); return impl->hasCoordValIter; diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index f2d877367..4afd0ed94 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -1907,6 +1907,15 @@ IndexStmt IndexStmt::reorder(std::vector reorderedvars) const { return transformed; } +IndexStmt IndexStmt::mergeby(IndexVar i, MergeStrategy strategy) const { + string reason; + IndexStmt transformed = SetMergeStrategy(i, strategy).apply(*this, &reason); + if (!transformed.defined()) { + taco_uerror << reason; + } + return transformed; +} + IndexStmt IndexStmt::parallelize(IndexVar i, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy) const { string reason; IndexStmt transformed = Parallelize(i, parallel_unit, output_race_strategy).apply(*this, &reason); @@ -2017,7 +2026,7 @@ IndexStmt IndexStmt::unroll(IndexVar i, size_t unrollFactor) const { void visit(const ForallNode* node) { if (node->indexVar == i) { - stmt = Forall(i, rewrite(node->stmt), node->parallel_unit, node->output_race_strategy, unrollFactor); + stmt = Forall(i, rewrite(node->stmt), node->merge_strategy, node->parallel_unit, node->output_race_strategy, unrollFactor); } else { IndexNotationRewriter::visit(node); @@ -2169,11 +2178,11 @@ Forall::Forall(const ForallNode* n) : IndexStmt(n) { } Forall::Forall(IndexVar indexVar, IndexStmt stmt) - : Forall(indexVar, stmt, ParallelUnit::NotParallel, OutputRaceStrategy::IgnoreRaces) { + : Forall(indexVar, stmt, MergeStrategy::TwoFinger, ParallelUnit::NotParallel, OutputRaceStrategy::IgnoreRaces) { } -Forall::Forall(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) - : Forall(new ForallNode(indexVar, stmt, parallel_unit, output_race_strategy, unrollFactor)) { +Forall::Forall(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) + : Forall(new 
ForallNode(indexVar, stmt, merge_strategy, parallel_unit, output_race_strategy, unrollFactor)) { } IndexVar Forall::getIndexVar() const { @@ -2192,6 +2201,10 @@ OutputRaceStrategy Forall::getOutputRaceStrategy() const { return getNode(*this)->output_race_strategy; } +MergeStrategy Forall::getMergeStrategy() const { + return getNode(*this)->merge_strategy; +} + size_t Forall::getUnrollFactor() const { return getNode(*this)->unrollFactor; } @@ -2200,8 +2213,8 @@ Forall forall(IndexVar i, IndexStmt stmt) { return Forall(i, stmt); } -Forall forall(IndexVar i, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) { - return Forall(i, stmt, parallel_unit, output_race_strategy, unrollFactor); +Forall forall(IndexVar i, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) { + return Forall(i, stmt, merge_strategy, parallel_unit, output_race_strategy, unrollFactor); } template <> bool isa(IndexStmt s) { @@ -3982,7 +3995,7 @@ struct Zero : public IndexNotationRewriterStrict { stmt = op; } else { - stmt = new ForallNode(op->indexVar, body, op->parallel_unit, op->output_race_strategy, op->unrollFactor); + stmt = new ForallNode(op->indexVar, body, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor); } } @@ -4199,6 +4212,22 @@ IndexStmt generatePackStmt(TensorVar tensor, packStmt = forall(indexVars[mode], packStmt); } + bool doAppend = true; + const Format lhsFormat = otherIsOnRight ? format : otherFormat; + for (int i = lhsFormat.getOrder() - 1; i >= 0; --i) { + const auto modeFormat = lhsFormat.getModeFormats()[i]; + if (modeFormat.isBranchless() && i != 0) { + const auto parentModeFormat = lhsFormat.getModeFormats()[i - 1]; + if (parentModeFormat.isUnique() || !parentModeFormat.hasAppend()) { + doAppend = false; + break; + } + } + } + if (!doAppend) { + packStmt = packStmt.assemble(otherIsOnRight ? 
diff --git a/src/index_notation/index_notation_rewriter.cpp b/src/index_notation/index_notation_rewriter.cpp
index 5caa2da4b..e8c123af1 100644
--- a/src/index_notation/index_notation_rewriter.cpp
+++ b/src/index_notation/index_notation_rewriter.cpp
@@ -185,7 +185,7 @@ void IndexNotationRewriter::visit(const ForallNode* op) {
     stmt = op;
   }
   else {
-    stmt = new ForallNode(op->indexVar, s, op->parallel_unit, op->output_race_strategy, op->unrollFactor);
+    stmt = new ForallNode(op->indexVar, s, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor);
   }
 }
 
@@ -406,7 +406,7 @@ struct ReplaceIndexVars : public IndexNotationRewriter {
     stmt = op;
   }
   else {
-    stmt = new ForallNode(iv, s, op->parallel_unit, op->output_race_strategy,
+    stmt = new ForallNode(iv, s, op->merge_strategy, op->parallel_unit, op->output_race_strategy,
                           op->unrollFactor);
   }
 }
 
diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp
index 1dd0cc71c..3ac83240d 100644
--- a/src/index_notation/transformations.cpp
+++ b/src/index_notation/transformations.cpp
@@ -22,1974 +22,1969 @@ using namespace std;
 namespace taco {
 
 // class Transformation
-  Transformation::Transformation(Reorder reorder)
-          : transformation(new Reorder(reorder)) {
-  }
-
-  Transformation::Transformation(Precompute precompute)
-          : transformation(new Precompute(precompute)) {
-  }
-
-  Transformation::Transformation(ForAllReplace forallreplace)
-          : transformation(new ForAllReplace(forallreplace)) {
-  }
-
-  Transformation::Transformation(Parallelize parallelize)
-          : transformation(new Parallelize(parallelize)) {
-  }
-
-  Transformation::Transformation(AddSuchThatPredicates addsuchthatpredicates)
-          : transformation(new AddSuchThatPredicates(addsuchthatpredicates)) {
-  }
-
-  IndexStmt Transformation::apply(IndexStmt stmt, string* reason) const {
-    return transformation->apply(stmt, reason);
-  }
-
-  std::ostream& operator<<(std::ostream& os, const Transformation& t) {
-    t.transformation->print(os);
-    return os;
-  }
+Transformation::Transformation(Reorder reorder)
+    : transformation(new Reorder(reorder)) {
+}
+
+Transformation::Transformation(Precompute precompute)
+    : transformation(new Precompute(precompute)) {
+}
+
+Transformation::Transformation(ForAllReplace forallreplace)
+    : transformation(new ForAllReplace(forallreplace)) {
+}
+
+Transformation::Transformation(SetMergeStrategy setmergestrategy)
+    : transformation(new SetMergeStrategy(setmergestrategy)) {
+}
+
+Transformation::Transformation(Parallelize parallelize)
+    : transformation(new Parallelize(parallelize)) {
+}
+
+Transformation::Transformation(AddSuchThatPredicates addsuchthatpredicates)
+    : transformation(new AddSuchThatPredicates(addsuchthatpredicates)) {
+}
+
+IndexStmt Transformation::apply(IndexStmt stmt, string* reason) const {
+  return transformation->apply(stmt, reason);
+}
+
+std::ostream& operator<<(std::ostream& os, const Transformation& t) {
+  t.transformation->print(os);
+  return os;
+}
 
 
 // class Reorder
-  struct Reorder::Content {
-    std::vector<IndexVar> replacePattern;
-    bool pattern_ordered; // In case of Reorder(i, j) need to change replacePattern ordering to actually reorder
-  };
-
-  Reorder::Reorder(IndexVar i, IndexVar j) : content(new Content) {
-    content->replacePattern = {i, j};
-    content->pattern_ordered = false;
-  }
-
-  Reorder::Reorder(std::vector<IndexVar> replacePattern) : content(new Content) {
-    content->replacePattern = replacePattern;
-    content->pattern_ordered = true;
-  }
-
-  IndexVar Reorder::geti() const {
-    return content->replacePattern[0];
-  }
-
-  IndexVar Reorder::getj() const {
-    if (content->replacePattern.size() == 1) {
-      return geti();
-    }
-    return content->replacePattern[1];
-  }
-
-  const std::vector<IndexVar>& Reorder::getreplacepattern() const {
-    return content->replacePattern;
-  }
-
-  IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
-    INIT_REASON(reason);
-
-    string r;
-    if (!isConcreteNotation(stmt, &r)) {
-      *reason = "The index statement is not valid concrete index notation: " + r;
-      return IndexStmt();
-    }
-
-    // collect current ordering of IndexVars
-    bool startedMatch = false;
-    std::vector<IndexVar> currentOrdering;
-    bool matchFailed = false;
-
-    match(stmt,
-          std::function<void(const ForallNode*)>([&](const ForallNode* op) {
-            bool matches = std::find (getreplacepattern().begin(), getreplacepattern().end(), op->indexVar) != getreplacepattern().end();
-            if (matches) {
-              currentOrdering.push_back(op->indexVar);
-              startedMatch = true;
-            }
-            else if (startedMatch && currentOrdering.size() != getreplacepattern().size()) {
-              matchFailed = true;
-            }
-          })
-    );
-
-    if (!content->pattern_ordered && currentOrdering == getreplacepattern()) {
-      taco_iassert(getreplacepattern().size() == 2);
-      content->replacePattern = {getreplacepattern()[1], getreplacepattern()[0]};
-    }
-
-    if (matchFailed || currentOrdering.size() != getreplacepattern().size()) {
-      *reason = "The foralls of reorder pattern: " + util::join(getreplacepattern()) + " were not directly nested.";
-      return IndexStmt();
-    }
-    return ForAllReplace(currentOrdering, getreplacepattern()).apply(stmt, reason);
-  }
-
-  void Reorder::print(std::ostream& os) const {
-    os << "reorder(" << util::join(getreplacepattern()) << ")";
-  }
-
-  std::ostream& operator<<(std::ostream& os, const Reorder& reorder) {
-    reorder.print(os);
-    return os;
-  }
+struct Reorder::Content {
+  std::vector<IndexVar> replacePattern;
+  bool pattern_ordered; // In case of Reorder(i, j) need to change replacePattern ordering to actually reorder
+};
+
+Reorder::Reorder(IndexVar i, IndexVar j) : content(new Content) {
+  content->replacePattern = {i, j};
+  content->pattern_ordered = false;
+}
+
+Reorder::Reorder(std::vector<IndexVar> replacePattern) : content(new Content) {
+  content->replacePattern = replacePattern;
+  content->pattern_ordered = true;
+}
+
+IndexVar Reorder::geti() const {
+  return content->replacePattern[0];
+}
+
+IndexVar Reorder::getj() const {
+  if (content->replacePattern.size() == 1) {
+    return geti();
+  }
+  return content->replacePattern[1];
+}
+
+const std::vector<IndexVar>& Reorder::getreplacepattern() const {
+  return content->replacePattern;
+}
+
+IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
+  INIT_REASON(reason);
+
+  string r;
+  if (!isConcreteNotation(stmt, &r)) {
+    *reason = "The index statement is not valid concrete index notation: " + r;
+    return IndexStmt();
+  }
+
+  // collect current ordering of IndexVars
+  bool startedMatch = false;
+  std::vector<IndexVar> currentOrdering;
+  bool matchFailed = false;
+
+  match(stmt,
+        std::function<void(const ForallNode*)>([&](const ForallNode* op) {
+          bool matches = std::find (getreplacepattern().begin(), getreplacepattern().end(), op->indexVar) != getreplacepattern().end();
+          if (matches) {
+            currentOrdering.push_back(op->indexVar);
+            startedMatch = true;
+          }
+          else if (startedMatch && currentOrdering.size() != getreplacepattern().size()) {
+            matchFailed = true;
+          }
+        })
+  );
+
+  if (!content->pattern_ordered && currentOrdering == getreplacepattern()) {
+    taco_iassert(getreplacepattern().size() == 2);
+    content->replacePattern = {getreplacepattern()[1], getreplacepattern()[0]};
+  }
+
+  if (matchFailed || currentOrdering.size() != getreplacepattern().size()) {
+    *reason = "The foralls of reorder pattern: " + util::join(getreplacepattern()) + " were not directly nested.";
+    return IndexStmt();
+  }
+  return ForAllReplace(currentOrdering, getreplacepattern()).apply(stmt, reason);
+}
+
+void Reorder::print(std::ostream& os) const {
+  os << "reorder(" << util::join(getreplacepattern()) << ")";
+}
+
+std::ostream& operator<<(std::ostream& os, const Reorder& reorder) {
+  reorder.print(os);
+  return os;
+}
+
+
+// class SetMergeStrategy
+struct SetMergeStrategy::Content {
+  IndexVar i_var;
+  MergeStrategy strategy;
+};
+
+SetMergeStrategy::SetMergeStrategy(IndexVar i, MergeStrategy strategy) : content(new Content) {
+  content->i_var = i;
+  content->strategy = strategy;
+}
+
+IndexVar SetMergeStrategy::geti() const {
+  return content->i_var;
+}
+
+MergeStrategy SetMergeStrategy::getMergeStrategy() const {
+  return content->strategy;
+}
+
+IndexStmt SetMergeStrategy::apply(IndexStmt stmt, string* reason) const {
+  INIT_REASON(reason);
+
+  string r;
+  if (!isConcreteNotation(stmt, &r)) {
+    *reason = "The index statement is not valid concrete index notation: " + r;
+    return IndexStmt();
+  }
+
+  struct SetMergeStrategyRewriter : public IndexNotationRewriter {
+    using IndexNotationRewriter::visit;
+
+    ProvenanceGraph provGraph;
+    map<TensorVar, ir::Expr> tensorVars;
+    set<IndexVar> definedIndexVars;
+
+    SetMergeStrategy transformation;
+    string reason;
+    SetMergeStrategyRewriter(SetMergeStrategy transformation)
+        : transformation(transformation) {}
+
+    IndexStmt setmergestrategy(IndexStmt stmt) {
+      provGraph = ProvenanceGraph(stmt);
+      tensorVars = createIRTensorVars(stmt);
+      return rewrite(stmt);
+    }
+
+    void visit(const ForallNode* node) {
+      Forall foralli(node);
+      IndexVar i = transformation.geti();
+
+      definedIndexVars.insert(foralli.getIndexVar());
+
+      if (foralli.getIndexVar() == i) {
+        Iterators iterators(foralli, tensorVars);
+        MergeLattice lattice = MergeLattice::make(foralli, iterators, provGraph,
+                                                  definedIndexVars);
+        for (auto iterator : lattice.iterators()) {
+          if (!iterator.isOrdered()) {
+            reason = "Precondition failed: Variable "
+                     + i.getName() +
+                     " is not ordered and cannot be galloped.";
+            return;
+          }
+        }
+        if (lattice.points().size() != 1) {
+          reason = "Precondition failed: The merge lattice of variable "
+                   + i.getName() +
+                   " has more than 1 point and cannot be merged by galloping";
+          return;
+        }
+
+        MergeStrategy strategy = transformation.getMergeStrategy();
+        stmt = rewrite(foralli.getStmt());
+        stmt = Forall(node->indexVar, stmt, strategy, node->parallel_unit,
+                      node->output_race_strategy, node->unrollFactor);
+        return;
+      }
+      IndexNotationRewriter::visit(node);
+    }
+  };
+
+  SetMergeStrategyRewriter rewriter = SetMergeStrategyRewriter(*this);
+  IndexStmt rewritten = rewriter.setmergestrategy(stmt);
+  if (!rewriter.reason.empty()) {
+    *reason = rewriter.reason;
+    return IndexStmt();
+  }
+  return rewritten;
+}
+
+void SetMergeStrategy::print(std::ostream& os) const {
+  os << "mergeby(" << geti() << ", "
+     << MergeStrategy_NAMES[(int)getMergeStrategy()] << ")";
+}
+
+std::ostream& operator<<(std::ostream& os, const SetMergeStrategy& setmergestrategy) {
+  setmergestrategy.print(os);
+  return os;
+}
 
 
 // class Precompute
+struct Precompute::Content {
+  IndexExpr expr;
+  std::vector<IndexVar> i_vars;
+  std::vector<IndexVar> iw_vars;
+  TensorVar workspace;
+};
+
+Precompute::Precompute() : content(nullptr) {
+}
+
+Precompute::Precompute(IndexExpr expr, IndexVar i, IndexVar iw,
+                       TensorVar workspace) : content(new Content) {
+  std::vector<IndexVar> i_vars{i};
+  std::vector<IndexVar> iw_vars{iw};
+  content->expr = expr;
+  content->i_vars = i_vars;
+  content->iw_vars = iw_vars;
+  content->workspace = workspace;
+}
+
+Precompute::Precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
+                       std::vector<IndexVar> iw_vars,
+                       TensorVar workspace) : content(new Content) {
+  content->expr = expr;
+  content->i_vars = i_vars;
+  content->iw_vars = iw_vars;
+  content->workspace = workspace;
+}
+
+IndexExpr Precompute::getExpr() const {
+  return content->expr;
+}
+
+std::vector<IndexVar>& Precompute::getIVars() const {
+  return content->i_vars;
+}
+
+std::vector<IndexVar>& Precompute::getIWVars() const {
+  return content->iw_vars;
+}
+
+TensorVar Precompute::getWorkspace() const {
+  return content->workspace;
+}
+
+static bool containsExpr(Assignment assignment, IndexExpr expr) {
+  struct ContainsVisitor : public IndexNotationVisitor {
+    using IndexNotationVisitor::visit;
+
+    IndexExpr expr;
+    bool contains = false;
+
+    void visit(const AccessNode* node) {
+      if (equals(IndexExpr(node), expr)) {
+        contains = true;
+      }
+    }
+
+    void visit(const UnaryExprNode* node) {
+      if (equals(IndexExpr(node), expr)) {
+        contains = true;
+      }
+      else {
+        IndexNotationVisitor::visit(node);
+      }
+    }
+
+    void visit(const BinaryExprNode* node) {
+      if (equals(IndexExpr(node), expr)) {
+        contains = true;
+      }
+      else {
+        IndexNotationVisitor::visit(node);
+      }
+    }
+
+    void visit(const ReductionNode* node) {
+      taco_ierror << "Reduction node in concrete index notation.";
+    }
+  };
+
+  ContainsVisitor visitor;
+  visitor.expr = expr;
+  visitor.visit(assignment);
+  return visitor.contains;
+}
+
+static Assignment getAssignmentContainingExpr(IndexStmt stmt, IndexExpr expr) {
+  Assignment assignment;
+  match(stmt,
+        function<void(const AssignmentNode*, Matcher*)>([&assignment, &expr](
+            const AssignmentNode* node, Matcher* ctx) {
+          if (containsExpr(node, expr)) {
+            assignment = node;
+          }
+        })
+  );
+  return assignment;
+}
+
+static IndexStmt eliminateRedundantReductions(IndexStmt stmt,
+    const std::set<TensorVar>* const candidates = nullptr) {
+
+  struct ReduceToAssign : public IndexNotationRewriter {
+    using IndexNotationRewriter::visit;
+
+    const std::set<TensorVar>* const candidates;
+    std::map<TensorVar, std::set<IndexVar>> availableVars;
+
+    ReduceToAssign(const std::set<TensorVar>* const candidates) :
+        candidates(candidates) {}
+
+    IndexStmt rewrite(IndexStmt stmt) {
+      for (const auto& result : getResults(stmt)) {
+        availableVars[result] = {};
+      }
+      return IndexNotationRewriter::rewrite(stmt);
+    }
+
+    void visit(const ForallNode* op) {
+      for (auto& it : availableVars) {
+        it.second.insert(op->indexVar);
+      }
+      IndexNotationRewriter::visit(op);
+      for (auto& it : availableVars) {
+        it.second.erase(op->indexVar);
+      }
+    }
+
+    void visit(const WhereNode* op) {
+      const auto workspaces = getResults(op->producer);
+      for (const auto& workspace : workspaces) {
+        availableVars[workspace] = {};
+      }
+      IndexNotationRewriter::visit(op);
+      for (const auto& workspace : workspaces) {
+        availableVars.erase(workspace);
+      }
+    }
+
+    void visit(const AssignmentNode* op) {
+      const auto result = op->lhs.getTensorVar();
+      if (op->op.defined() &&
+          util::toSet(op->lhs.getIndexVars()) == availableVars[result] &&
+          (!candidates || util::contains(*candidates, result))) {
+        stmt = Assignment(op->lhs, op->rhs);
+        return;
+      }
+      stmt = op;
+    }
+  };
+  return ReduceToAssign(candidates).rewrite(stmt);
+}
+
+IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
+  INIT_REASON(reason);
+
+  // Precondition: The expr to precompute is in `stmt`
+  Assignment assignment = getAssignmentContainingExpr(stmt, getExpr());
+  if (!assignment.defined()) {
+    *reason = "The expression (" + util::toString(getExpr()) + ") " +
+              "is not in " + util::toString(stmt);
+    return IndexStmt();
+  }
+
+  vector<IndexVar> forallIndexVars;
+  match(stmt,
+        function<void(const ForallNode*)>([&](const ForallNode* op) {
+          forallIndexVars.push_back(op->indexVar);
+        })
+  );
+
+  ProvenanceGraph provGraph = ProvenanceGraph(stmt);
+
+  struct PrecomputeRewriter : public IndexNotationRewriter {
+    using IndexNotationRewriter::visit;
+    Precompute precompute;
+    ProvenanceGraph provGraph;
+    vector<IndexVar> forallIndexVarList;
+
+    Assignment getConsumerAssignment(IndexStmt stmt, TensorVar& ws) {
+      Assignment a = Assignment();
+      match(stmt,
+            function<void(const AssignmentNode*, Matcher*)>([&](const AssignmentNode* op, Matcher* ctx) {
+              a = Assignment(op);
+            }),
+            function<void(const WhereNode*, Matcher*)>([&](const WhereNode* op, Matcher* ctx) {
+              ctx->match(op->consumer);
+              ctx->match(op->producer);
+            }),
+            function<void(const AccessNode*, Matcher*)>([&](const AccessNode* op, Matcher* ctx) {
+              if (op->tensorVar == ws) {
+                return;
+              }
+            })
+      );
+
+      if (!a.getReductionVars().empty()) {
+        a = Assignment(a.getLhs(), a.getRhs(), Add());
+      } else {
+        a = Assignment(a.getLhs(), a.getRhs());
+      }
+      return a;
+    }
+
+    Assignment getProducerAssignment(TensorVar& ws,
+                                     const std::vector<IndexVar>& i_vars,
+                                     const std::vector<IndexVar>& iw_vars,
+                                     const IndexExpr& e,
+                                     map<IndexVar, IndexVar> substitutions) {
+
+      auto assignment = ws(iw_vars) = replace(e, substitutions);
+      if (!assignment.getReductionVars().empty())
+        assignment = Assignment(assignment.getLhs(), assignment.getRhs(), Add());
+      return assignment;
+    }
+
+    IndexStmt generateForalls(IndexStmt stmt, vector<IndexVar> indexVars) {
+      auto returnStmt = stmt;
+      for (auto &i : indexVars) {
+        returnStmt = forall(i, returnStmt);
+      }
+      return returnStmt;
+    }
+
+    bool containsIndexVarScheduled(vector<IndexVar> indexVars,
+                                   IndexVar indexVar) {
+      bool contains = false;
+      for (auto &i : indexVars) {
+        if (i == indexVar) {
+          contains = true;
+        } else if (provGraph.isFullyDerived(indexVar) && !provGraph.isFullyDerived(i)) {
+          for (auto &child : provGraph.getFullyDerivedDescendants(i)) {
+            if (child == indexVar)
+              contains = true;
+          }
+        } else if (!provGraph.isFullyDerived(indexVar) && provGraph.isFullyDerived(i)) {
+          for (auto &child : provGraph.getFullyDerivedDescendants(indexVar)) {
+            if (child == i)
+              contains = true;
+          }
+        }
+      }
+      return contains;
+    }
+
+    void visit(const ForallNode* node) {
+      Forall foralli(node);
+      std::vector<IndexVar> i_vars = precompute.getIVars();
+
+      bool containsWhere = false;
+      match(foralli,
+            function<void(const WhereNode*)>([&](const WhereNode* op) {
+              containsWhere = true;
+            })
+      );
+
+      if (!containsWhere) {
+        vector<IndexVar> forallIndexVars;
+        match(foralli,
+              function<void(const ForallNode*)>([&](const ForallNode* op) {
+                forallIndexVars.push_back(op->indexVar);
+              })
+        );
+
+        IndexStmt s = foralli.getStmt();
+        TensorVar ws = precompute.getWorkspace();
+        IndexExpr e = precompute.getExpr();
+        std::vector<IndexVar> iw_vars = precompute.getIWVars();
+
+        map<IndexVar, IndexVar> substitutions;
+        taco_iassert(i_vars.size() == iw_vars.size()) << "i_vars and iw_vars lists must be the same size";
+
+        for (int index = 0; index < (int)i_vars.size(); index++) {
+          substitutions[i_vars[index]] = iw_vars[index];
+        }
+
+        // Build consumer by replacing with temporary (in replacedStmt)
+        IndexStmt replacedStmt = replace(s, {{e, ws(i_vars) }});
+        if (replacedStmt != s) {
+          // Then modify the replacedStmt to have the correct foralls
+          // by concretizing the consumer assignment
+
+          auto consumerAssignment = getConsumerAssignment(replacedStmt, ws);
+          auto consumerIndexVars = consumerAssignment.getIndexVars();
+
+          auto producerAssignment = getProducerAssignment(ws, i_vars, iw_vars, e, substitutions);
+          auto producerIndexVars = producerAssignment.getIndexVars();
+
+          vector<IndexVar> producerForallIndexVars;
+          vector<IndexVar> consumerForallIndexVars;
+          vector<IndexVar> outerForallIndexVars;
+
+          bool stopForallDistribution = false;
+          for (auto &i : util::reverse(forallIndexVars)) {
+            if (!stopForallDistribution && containsIndexVarScheduled(i_vars, i)) {
+              producerForallIndexVars.push_back(substitutions[i]);
+              consumerForallIndexVars.push_back(i);
+            } else {
+              auto consumerContains = containsIndexVarScheduled(consumerIndexVars, i);
+              auto producerContains = containsIndexVarScheduled(producerIndexVars, i);
+              if (stopForallDistribution || (producerContains && consumerContains)) {
+                outerForallIndexVars.push_back(i);
+                stopForallDistribution = true;
+              } else if (!stopForallDistribution && consumerContains) {
+                consumerForallIndexVars.push_back(i);
+              } else if (!stopForallDistribution && producerContains) {
+                producerForallIndexVars.push_back(i);
+              }
+            }
+          }
+
+          IndexStmt consumer = generateForalls(consumerAssignment, consumerForallIndexVars);
+          IndexStmt producer = generateForalls(producerAssignment, producerForallIndexVars);
+          Where where(consumer, producer);
+
+          stmt = generateForalls(where, outerForallIndexVars);
+          return;
+        }
+      }
+      IndexNotationRewriter::visit(node);
+    }
+  };
+
+  PrecomputeRewriter rewriter;
+  rewriter.precompute = *this;
+  rewriter.provGraph = provGraph;
+  rewriter.forallIndexVarList = forallIndexVars;
+  stmt = rewriter.rewrite(stmt);
+
+  return stmt;
+}
+
+void Precompute::print(std::ostream& os) const {
+  os << "precompute(" << getExpr() << ", " << getIVars() << ", "
+     << getIWVars() << ", " << getWorkspace() << ")";
+}
+
+bool Precompute::defined() const {
+  return content != nullptr;
+}
+
+std::ostream& operator<<(std::ostream& os, const Precompute& precompute) {
+  precompute.print(os);
+  return os;
+}
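Concretely, what the rewrite above produces (editor's sketch; A, B, C, w, i, and iw are illustrative names):

    // Given      forall(i, A(i) = B(i) * C(i))
    // and        precompute(B(i)*C(i), i, iw, w),
    // the rewriter yields a where statement whose producer fills the
    // workspace and whose consumer reads it back:
    //
    //   where( forall(i,  A(i) = w(i)),           // consumer
    //          forall(iw, w(iw) = B(iw) * C(iw))  // producer
    //        )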
-  struct Precompute::Content {
-    IndexExpr expr;
-    std::vector<IndexVar> i_vars;
-    std::vector<IndexVar> iw_vars;
-    TensorVar workspace;
-  };
-
-  Precompute::Precompute() : content(nullptr) {
-  }
-
-  Precompute::Precompute(IndexExpr expr, IndexVar i, IndexVar iw,
-                         TensorVar workspace) : content(new Content) {
-    std::vector<IndexVar> i_vars{i};
-    std::vector<IndexVar> iw_vars{iw};
-    content->expr = expr;
-    content->i_vars = i_vars;
-    content->iw_vars = iw_vars;
-    content->workspace = workspace;
-  }
-
-  Precompute::Precompute(IndexExpr expr, std::vector<IndexVar> i_vars,
-                         std::vector<IndexVar> iw_vars,
-                         TensorVar workspace) : content(new Content) {
-    content->expr = expr;
-    content->i_vars = i_vars;
-    content->iw_vars = iw_vars;
-    content->workspace = workspace;
-  }
-
-  IndexExpr Precompute::getExpr() const {
-    return content->expr;
-  }
-
-  std::vector<IndexVar>& Precompute::getIVars() const {
-    return content->i_vars;
-  }
-
-  std::vector<IndexVar>& Precompute::getIWVars() const {
-    return content->iw_vars;
-  }
-
-  TensorVar Precompute::getWorkspace() const {
-    return content->workspace;
-  }
-
-  static bool containsExpr(Assignment assignment, IndexExpr expr) {
-    struct ContainsVisitor : public IndexNotationVisitor {
-      using IndexNotationVisitor::visit;
-
-      IndexExpr expr;
-      bool contains = false;
-
-      void visit(const AccessNode* node) {
-        if (equals(IndexExpr(node), expr)) {
-          contains = true;
-        }
-      }
-
-      void visit(const UnaryExprNode* node) {
-        if (equals(IndexExpr(node), expr)) {
-          contains = true;
-        }
-        else {
-          IndexNotationVisitor::visit(node);
-        }
-      }
-
-      void visit(const BinaryExprNode* node) {
-        if (equals(IndexExpr(node), expr)) {
-          contains = true;
-        }
-        else {
-          IndexNotationVisitor::visit(node);
-        }
-      }
-
-      void visit(const ReductionNode* node) {
-        taco_ierror << "Reduction node in concrete index notation.";
-      }
-    };
-
-    ContainsVisitor visitor;
-    visitor.expr = expr;
-    visitor.visit(assignment);
-    return visitor.contains;
-  }
-
-  static Assignment getAssignmentContainingExpr(IndexStmt stmt, IndexExpr expr) {
-    Assignment assignment;
-    match(stmt,
-          function<void(const AssignmentNode*, Matcher*)>([&assignment, &expr](
-              const AssignmentNode* node, Matcher* ctx) {
-            if (containsExpr(node, expr)) {
-              assignment = node;
-            }
-          })
-    );
-    return assignment;
-  }
-
-  static IndexStmt eliminateRedundantReductions(IndexStmt stmt,
-      const std::set<TensorVar>* const candidates = nullptr) {
-
-    struct ReduceToAssign : public IndexNotationRewriter {
-      using IndexNotationRewriter::visit;
-
-      const std::set<TensorVar>* const candidates;
-      std::map<TensorVar, std::set<IndexVar>> availableVars;
-
-      ReduceToAssign(const std::set<TensorVar>* const candidates) :
-          candidates(candidates) {}
-
-      IndexStmt rewrite(IndexStmt stmt) {
-        for (const auto& result : getResults(stmt)) {
-          availableVars[result] = {};
-        }
-        return IndexNotationRewriter::rewrite(stmt);
-      }
-
-      void visit(const ForallNode* op) {
-        for (auto& it : availableVars) {
-          it.second.insert(op->indexVar);
-        }
-        IndexNotationRewriter::visit(op);
-        for (auto& it : availableVars) {
-          it.second.erase(op->indexVar);
-        }
-      }
-
-      void visit(const WhereNode* op) {
-        const auto workspaces = getResults(op->producer);
-        for (const auto& workspace : workspaces) {
-          availableVars[workspace] = {};
-        }
-        IndexNotationRewriter::visit(op);
-        for (const auto& workspace : workspaces) {
-          availableVars.erase(workspace);
-        }
-      }
-
-      void visit(const AssignmentNode* op) {
-        const auto result = op->lhs.getTensorVar();
-        if (op->op.defined() &&
-            util::toSet(op->lhs.getIndexVars()) == availableVars[result] &&
-            (!candidates || util::contains(*candidates, result))) {
-          stmt = Assignment(op->lhs, op->rhs);
-          return;
-        }
-        stmt = op;
-      }
-    };
-    return ReduceToAssign(candidates).rewrite(stmt);
-  }
-
-  IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
-    INIT_REASON(reason);
-
-    // Precondition: The expr to precompute is in `stmt`
-    Assignment assignment = getAssignmentContainingExpr(stmt, getExpr());
-    if (!assignment.defined()) {
-      *reason = "The expression (" + util::toString(getExpr()) + ") " +
-                "is not in " + util::toString(stmt);
-      return IndexStmt();
-    }
-
-    vector<IndexVar> forallIndexVars;
-    match(stmt,
-          function<void(const ForallNode*)>([&](const ForallNode* op) {
-            forallIndexVars.push_back(op->indexVar);
-          })
-    );
-
-    ProvenanceGraph provGraph = ProvenanceGraph(stmt);
-
-
-
-    struct PrecomputeRewriter : public IndexNotationRewriter {
-      using IndexNotationRewriter::visit;
-      Precompute precompute;
-      ProvenanceGraph provGraph;
-      vector<IndexVar> forallIndexVarList;
-
-      Assignment getConsumerAssignment(IndexStmt stmt, TensorVar& ws) {
-        Assignment a = Assignment();
-        match(stmt,
-              function<void(const AssignmentNode*, Matcher*)>([&](const AssignmentNode* op, Matcher* ctx) {
-                a = Assignment(op);
-              }),
-              function<void(const WhereNode*, Matcher*)>([&](const WhereNode* op, Matcher* ctx) {
-                ctx->match(op->consumer);
-                ctx->match(op->producer);
-              }),
-              function<void(const AccessNode*, Matcher*)>([&](const AccessNode* op, Matcher* ctx) {
-                if (op->tensorVar == ws) {
-                  return;
-                }
-              })
-        );
-
-        IndexSetRel rel = a.getIndexSetRel();
-        switch (rel) {
-          case none: a = Assignment(a.getLhs(), a.getRhs());break; // =
-          case rcl: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
-          case lcr: a = Assignment(a.getLhs(), a.getRhs());break; // =
-          case inter: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
-          case equal: a = Assignment(a.getLhs(), a.getRhs());break;// = OR +=
-        }
-        return a;
-      }
-
-      Assignment getProducerAssignment(TensorVar& ws,
-                                       const std::vector<IndexVar>& i_vars,
-                                       const std::vector<IndexVar>& iw_vars,
-                                       const IndexExpr& e,
-                                       map<IndexVar, IndexVar> substitutions) {
-
-        auto a = ws(iw_vars) = replace(e, substitutions);
-        IndexSetRel rel = a.getIndexSetRel();
-        switch (rel) {
-          case none: a = Assignment(a.getLhs(), a.getRhs());break; // =
-          case rcl: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
-          case lcr: a = Assignment(a.getLhs(), a.getRhs());break; // =
-          case inter: a = Assignment(a.getLhs(), a.getRhs(), Add());break; // +=
-          case equal: a = Assignment(a.getLhs(), a.getRhs());break;// = OR +=
-        }
-
-        return a;
-      }
-
-      IndexStmt generateForalls(IndexStmt stmt, vector<IndexVar> indexVars) {
-        auto returnStmt = stmt;
-        for (auto &i : indexVars) {
-          returnStmt = forall(i, returnStmt);
-        }
-
-        return returnStmt;
-      }
-
-      bool containsIndexVarScheduled(vector<IndexVar> indexVars,
-                                     IndexVar indexVar) {
-        bool contains = false;
-        for (auto &i : indexVars) {
-          if (i == indexVar) {
-            contains = true;
-          } else if (provGraph.isFullyDerived(indexVar) && !provGraph.isFullyDerived(i)) {
-            for (auto &child : provGraph.getFullyDerivedDescendants(i)) {
-              if (child == indexVar)
-                contains = true;
-            }
-          }
-        }
-        return contains;
-      }
-
-      void visit(const ForallNode* node) {
-        Forall foralli(node);
-        std::vector<IndexVar> i_vars = precompute.getIVars();
-
-        bool containsWhere = false;
-        match(foralli,
-              function<void(const WhereNode*)>([&](const WhereNode* op) {
-                containsWhere = true;
-              })
-        );
-
-        if (!containsWhere) {
-          vector<IndexVar> forallIndexVars;
-          match(foralli,
-                function<void(const ForallNode*)>([&](const ForallNode* op) {
-                  forallIndexVars.push_back(op->indexVar);
-                })
-          );
-
-          IndexStmt s = foralli.getStmt();
-          TensorVar ws = precompute.getWorkspace();
-          IndexExpr e = precompute.getExpr();
-          std::vector<IndexVar> iw_vars = precompute.getIWVars();
-
-          map<IndexVar, IndexVar> substitutions;
-          taco_iassert(i_vars.size() == iw_vars.size()) << "i_vars and iw_vars lists must be the same size";
-
-          for (int index = 0; index < (int)i_vars.size(); index++) {
-            substitutions[i_vars[index]] = iw_vars[index];
-          }
-
-          // Build consumer by replacing with temporary (in replacedStmt)
-          IndexStmt replacedStmt = replace(s, {{e, ws(i_vars) }});
-          if (replacedStmt != s) {
-            // Then modify the replacedStmt to have the correct foralls
-            // by concretizing the consumer assignment
-
-            auto consumerAssignment = getConsumerAssignment(replacedStmt, ws);
-            auto consumerIndexVars = consumerAssignment.getIndexVars();
-
-            auto producerAssignment = getProducerAssignment(ws, i_vars, iw_vars, e, substitutions);
-            auto producerIndexVars = producerAssignment.getIndexVars();
-
-            vector<IndexVar> producerForallIndexVars;
-            vector<IndexVar> consumerForallIndexVars;
-            vector<IndexVar> outerForallIndexVars;
-
-            bool stopForallDistribution = false;
-            for (auto &i : util::reverse(forallIndexVars)) {
-              if (!stopForallDistribution && containsIndexVarScheduled(i_vars, i)) {
-                producerForallIndexVars.push_back(substitutions[i]);
-                consumerForallIndexVars.push_back(i);
-              } else {
-                auto consumerContains = containsIndexVarScheduled(consumerIndexVars, i);
-                auto producerContains = containsIndexVarScheduled(producerIndexVars, i);
-                if (stopForallDistribution || (producerContains && consumerContains)) {
-                  outerForallIndexVars.push_back(i);
-                  stopForallDistribution = true;
-                } else if (!stopForallDistribution && consumerContains) {
-                  consumerForallIndexVars.push_back(i);
-                } else if (!stopForallDistribution && producerContains) {
-                  producerForallIndexVars.push_back(i);
-                }
-              }
-            }
-            IndexStmt consumer = generateForalls(consumerAssignment, consumerForallIndexVars);
-
-            IndexStmt producer = generateForalls(producerAssignment, producerForallIndexVars);
-            Where where(consumer, producer);
-
-            stmt = generateForalls(where, outerForallIndexVars);
-
-            return;
-          }
-        }
-        IndexNotationRewriter::visit(node);
-      }
-    };
-
-    struct RedundentVisitor: public IndexNotationVisitor {
-      using IndexNotationVisitor::visit;
-
-      std::vector<Assignment>& to_change;
-      std::vector<IndexVar> ctx_stack;
-      std::vector<int> num_stack;
-      int ctx_num;
-      const ProvenanceGraph& provGraph;
-
-      RedundentVisitor(std::vector<Assignment>& to_change, const ProvenanceGraph& provGraph):to_change(to_change), ctx_num(0), provGraph(provGraph){}
-
-      void visit(const ForallNode* node) {
-        Forall foralli(node);
-        IndexVar var = foralli.getIndexVar();
-        ctx_stack.push_back(var);
-        if (!num_stack.empty()) {
-          num_stack.back()++;
-        }
-        if (num_stack.empty()) {
-          num_stack.push_back(1);
-        }
-        IndexNotationVisitor::visit(node);
-      }
-      void visit(const WhereNode* node) {
-        num_stack.push_back(0);
-        IndexNotationVisitor::visit(node->consumer);
-        ctx_num = num_stack.back();
-        for (int i = 0; i < ctx_num; i++){
-          ctx_stack.pop_back();
-        }
-        num_stack.pop_back();
-        num_stack.push_back(0);
-        IndexNotationVisitor::visit(node->producer);
-        ctx_num = num_stack.back();
-        for (int i = 0; i < ctx_num; i++){
-          ctx_stack.pop_back();
-        }
-        num_stack.pop_back();
-      }
-      void visit(const AssignmentNode* node) {
-        Assignment a(node->lhs, node->rhs, node->op);
-        vector<IndexVar> freeVars = a.getLhs().getIndexVars();
-        set<IndexVar> seen(freeVars.begin(), freeVars.end());
-        bool has_sibling = false;
-        match(a.getRhs(),
-              std::function<void(const AccessNode*)>([&](const AccessNode* op) {
-                for (auto& var : op->indexVars) {
-                  for (auto& svar : ctx_stack) {
-                    if ((provGraph.getUnderivedAncestors(var)[0] == provGraph.getUnderivedAncestors(svar)[0]) && svar != var) {
-                      has_sibling = true;
-                    }
-                  }
-                }
-              }));
-        bool is_equal = (a.getIndexSetRel() == equal);
-        bool is_none = (a.getIndexSetRel() == none);
-        if (is_equal && has_sibling) {
-          to_change.push_back(a);
-        }
-        if (is_none && has_sibling && ctx_num > 1) {
-          to_change.push_back(a);
-        }
-        /*
-        bool has_outside = false;
-        for (auto & var : seen) {
-          if (var!=ctx_stack.back()){
-            has_outside = true;
-            break;
-          }
-        }
-        if (is_none && has_sibling && ctx_num == 1 && has_outside) {
-          to_change.push_back(a);
-        }
-        */
-      }
-    };
-
-    struct RedundentRewriter: public IndexNotationRewriter {
-      using IndexNotationRewriter::visit;
-      std::set<Assignment> to_change;
-      RedundentRewriter(std::vector<Assignment>& to_change):to_change(to_change.begin(),to_change.end()){}
-
-      void visit(const AssignmentNode* node) {
-        Assignment a(node->lhs, node->rhs, node->op);
-        for (auto & v: to_change) {
-          if ((v.getLhs() == a.getLhs()) && (v.getRhs() == a.getRhs()) ) {
-            stmt = Assignment(a.getLhs(), a.getRhs(), Add());
-            return;
-          }
-        }
-        IndexNotationRewriter::visit(node);
-      }
-
-    };
-
-    PrecomputeRewriter rewriter;
-    rewriter.precompute = *this;
-    rewriter.provGraph = provGraph;
-    rewriter.forallIndexVarList = forallIndexVars;
-    stmt = rewriter.rewrite(stmt);
-    std::vector<Assignment> to_change;
-    RedundentVisitor findVisitor(to_change, provGraph);
-    stmt.accept(&findVisitor);
-    RedundentRewriter ReRewriter(to_change);
-    stmt = ReRewriter.rewrite(stmt);
-    return stmt;
-  }
-
-  void Precompute::print(std::ostream& os) const {
-    os << "precompute(" << getExpr() << ", " << getIVars() << ", "
-       << getIWVars() << ", " << getWorkspace() << ")";
-  }
-
-  bool Precompute::defined() const {
-    return content != nullptr;
-  }
-
-  std::ostream& operator<<(std::ostream& os, const Precompute& precompute) {
-    precompute.print(os);
-    return os;
-  }
 
 
 // class ForAllReplace
-  struct ForAllReplace::Content {
-    std::vector<IndexVar> pattern;
-    std::vector<IndexVar> replacement;
-  };
-
-  ForAllReplace::ForAllReplace() : content(nullptr) {
-  }
-
-  ForAllReplace::ForAllReplace(std::vector<IndexVar> pattern, std::vector<IndexVar> replacement) : content(new Content) {
-    taco_iassert(!pattern.empty());
-    content->pattern = pattern;
-    content->replacement = replacement;
-  }
-
-  std::vector<IndexVar> ForAllReplace::getPattern() const {
-    return content->pattern;
-  }
-
-  std::vector<IndexVar> ForAllReplace::getReplacement() const {
-    return content->replacement;
-  }
-
-  IndexStmt ForAllReplace::apply(IndexStmt stmt, string* reason) const {
-    INIT_REASON(reason);
-
-    string r;
-    if (!isConcreteNotation(stmt, &r)) {
-      *reason = "The index statement is not valid concrete index notation: " + r;
-      return IndexStmt();
-    }
-
-    /// Since all IndexVars can only appear once, assume replacement will work and error if it doesn't
-    struct ForAllReplaceRewriter : public IndexNotationRewriter {
-      using IndexNotationRewriter::visit;
-
-      ForAllReplace transformation;
-      string* reason;
-      int elementsMatched = 0;
-      ForAllReplaceRewriter(ForAllReplace transformation, string* reason)
-              : transformation(transformation), reason(reason) {}
-
-      IndexStmt forallreplace(IndexStmt stmt) {
-        IndexStmt replaced = rewrite(stmt);
-
-        // Precondition: Did not find pattern
-        if (replaced == stmt || elementsMatched == -1) {
-          *reason = "The pattern of ForAlls: " +
-                    util::join(transformation.getPattern()) +
-                    " was not found while attempting to replace with: " +
-                    util::join(transformation.getReplacement());
-          return IndexStmt();
-        }
-        return replaced;
-      }
-
-      void visit(const ForallNode* node) {
-        Forall foralli(node);
-        vector<IndexVar> pattern = transformation.getPattern();
-        if (elementsMatched == -1) {
-          return; // pattern did not match
-        }
-
-        if(elementsMatched >= (int) pattern.size()) {
-          IndexNotationRewriter::visit(node);
-          return;
-        }
-
-        if (foralli.getIndexVar() == pattern[elementsMatched]) {
-          if (elementsMatched + 1 < (int) pattern.size() && !isa<Forall>(foralli.getStmt())) {
-            // child is not a forallnode (not directly nested)
-            elementsMatched = -1;
-            return;
-          }
-          // assume rest of pattern matches
-          vector<IndexVar> replacement = transformation.getReplacement();
-          bool firstMatch = (elementsMatched == 0);
-          elementsMatched++;
-          stmt = rewrite(foralli.getStmt());
-          if (firstMatch) {
-            // add replacement nodes and cut out this node
-            for (auto i = replacement.rbegin(); i != replacement.rend(); ++i ) {
-              stmt = forall(*i, stmt);
-            }
-          }
-          // else cut out this node
-          return;
-        }
-        else if (elementsMatched > 0) {
-          elementsMatched = -1; // pattern did not match
-          return;
-        }
-        // before pattern match
-        IndexNotationRewriter::visit(node);
-      }
-    };
-    return ForAllReplaceRewriter(*this, reason).forallreplace(stmt);
-  }
-
-  void ForAllReplace::print(std::ostream& os) const {
-    os << "forallreplace(" << util::join(getPattern()) << ", " << util::join(getReplacement()) << ")";
-  }
-
-  std::ostream& operator<<(std::ostream& os, const ForAllReplace& forallreplace) {
-    forallreplace.print(os);
-    return os;
-  }
+struct ForAllReplace::Content {
+  std::vector<IndexVar> pattern;
+  std::vector<IndexVar> replacement;
+};
+
+ForAllReplace::ForAllReplace() : content(nullptr) {
+}
+
+ForAllReplace::ForAllReplace(std::vector<IndexVar> pattern, std::vector<IndexVar> replacement) : content(new Content) {
+  taco_iassert(!pattern.empty());
+  content->pattern = pattern;
+  content->replacement = replacement;
+}
+
+std::vector<IndexVar> ForAllReplace::getPattern() const {
+  return content->pattern;
+}
+
+std::vector<IndexVar> ForAllReplace::getReplacement() const {
+  return content->replacement;
+}
+
+IndexStmt ForAllReplace::apply(IndexStmt stmt, string* reason) const {
+  INIT_REASON(reason);
+
+  string r;
+  if (!isConcreteNotation(stmt, &r)) {
+    *reason = "The index statement is not valid concrete index notation: " + r;
+    return IndexStmt();
+  }
+
+  /// Since all IndexVars can only appear once, assume replacement will work and error if it doesn't
+  struct ForAllReplaceRewriter : public IndexNotationRewriter {
+    using IndexNotationRewriter::visit;
+
+    ForAllReplace transformation;
+    string* reason;
+    int elementsMatched = 0;
+    ForAllReplaceRewriter(ForAllReplace transformation, string* reason)
+        : transformation(transformation), reason(reason) {}
+
+    IndexStmt forallreplace(IndexStmt stmt) {
+      IndexStmt replaced = rewrite(stmt);
+
+      // Precondition: Did not find pattern
+      if (replaced == stmt || elementsMatched == -1) {
+        *reason = "The pattern of ForAlls: " +
+                  util::join(transformation.getPattern()) +
+                  " was not found while attempting to replace with: " +
+                  util::join(transformation.getReplacement());
+        return IndexStmt();
+      }
+      return replaced;
+    }
+
+    void visit(const ForallNode* node) {
+      Forall foralli(node);
+      vector<IndexVar> pattern = transformation.getPattern();
+      if (elementsMatched == -1) {
+        return; // pattern did not match
+      }
+
+      if(elementsMatched >= (int) pattern.size()) {
+        IndexNotationRewriter::visit(node);
+        return;
+      }
+
+      if (foralli.getIndexVar() == pattern[elementsMatched]) {
+        if (elementsMatched + 1 < (int) pattern.size() && !isa<Forall>(foralli.getStmt())) {
+          // child is not a forallnode (not directly nested)
+          elementsMatched = -1;
+          return;
+        }
+        // assume rest of pattern matches
+        vector<IndexVar> replacement = transformation.getReplacement();
+        bool firstMatch = (elementsMatched == 0);
+        elementsMatched++;
+        stmt = rewrite(foralli.getStmt());
+        if (firstMatch) {
+          // add replacement nodes and cut out this node
+          for (auto i = replacement.rbegin(); i != replacement.rend(); ++i ) {
+            stmt = forall(*i, stmt);
+          }
+        }
+        // else cut out this node
+        return;
+      }
+      else if (elementsMatched > 0) {
+        elementsMatched = -1; // pattern did not match
+        return;
+      }
+      // before pattern match
+      IndexNotationRewriter::visit(node);
+    }
+  };
+  return ForAllReplaceRewriter(*this, reason).forallreplace(stmt);
+}
+
+void ForAllReplace::print(std::ostream& os) const {
+  os << "forallreplace(" << util::join(getPattern()) << ", " << util::join(getReplacement()) << ")";
+}
+
+std::ostream& operator<<(std::ostream& os, const ForAllReplace& forallreplace) {
+  forallreplace.print(os);
+  return os;
+}
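What the rewriter above does, in miniature (editor's sketch; i, j, s, and reason are illustrative, with s some inner statement and reason a std::string):

    IndexVar i("i"), j("j");
    IndexStmt nest = forall(i, forall(j, s));
    IndexStmt swapped = ForAllReplace({i, j}, {j, i}).apply(nest, &reason);
    // swapped == forall(j, forall(i, s)): the first pattern match cuts out
    // the matched nest and re-wraps the rewritten body in replacement order.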
 
 
 // class AddSuchThatRels
-  struct AddSuchThatPredicates::Content {
-    std::vector<IndexVarRel> predicates;
-  };
-
-  AddSuchThatPredicates::AddSuchThatPredicates() : content(nullptr) {
-  }
-
-  AddSuchThatPredicates::AddSuchThatPredicates(std::vector<IndexVarRel> predicates) : content(new Content) {
-    taco_iassert(!predicates.empty());
-    content->predicates = predicates;
-  }
-
-  std::vector<IndexVarRel> AddSuchThatPredicates::getPredicates() const {
-    return content->predicates;
-  }
-
-  IndexStmt AddSuchThatPredicates::apply(IndexStmt stmt, string* reason) const {
-    INIT_REASON(reason);
-
-    string r;
-    if (!isConcreteNotation(stmt, &r)) {
-      *reason = "The index statement is not valid concrete index notation: " + r;
-      return IndexStmt();
-    }
-
-    if (isa<SuchThat>(stmt)) {
-      SuchThat suchThat = to<SuchThat>(stmt);
-      vector<IndexVarRel> predicate = suchThat.getPredicate();
-      vector<IndexVarRel> predicates = getPredicates();
-      predicate.insert(predicate.end(), predicates.begin(), predicates.end());
-      return SuchThat(suchThat.getStmt(), predicate);
-    }
-    else{
-      return SuchThat(stmt, content->predicates);
-    }
-  }
-
-  void AddSuchThatPredicates::print(std::ostream& os) const {
-    os << "addsuchthatpredicates(" << util::join(getPredicates()) << ")";
-  }
-
-  std::ostream& operator<<(std::ostream& os, const AddSuchThatPredicates& addSuchThatPredicates) {
-    addSuchThatPredicates.print(os);
-    return os;
-  }
+struct AddSuchThatPredicates::Content {
+  std::vector<IndexVarRel> predicates;
+};
+
+AddSuchThatPredicates::AddSuchThatPredicates() : content(nullptr) {
+}
+
+AddSuchThatPredicates::AddSuchThatPredicates(std::vector<IndexVarRel> predicates) : content(new Content) {
+  taco_iassert(!predicates.empty());
+  content->predicates = predicates;
+}
+
+std::vector<IndexVarRel> AddSuchThatPredicates::getPredicates() const {
+  return content->predicates;
+}
+
+IndexStmt AddSuchThatPredicates::apply(IndexStmt stmt, string* reason) const {
+  INIT_REASON(reason);
+
+  string r;
+  if (!isConcreteNotation(stmt, &r)) {
+    *reason = "The index statement is not valid concrete index notation: " + r;
+    return IndexStmt();
+  }
+
+  if (isa<SuchThat>(stmt)) {
+    SuchThat suchThat = to<SuchThat>(stmt);
+    vector<IndexVarRel> predicate = suchThat.getPredicate();
+    vector<IndexVarRel> predicates = getPredicates();
+    predicate.insert(predicate.end(), predicates.begin(), predicates.end());
+    return SuchThat(suchThat.getStmt(), predicate);
+  }
+  else{
+    return SuchThat(stmt, content->predicates);
+  }
+}
+
+void AddSuchThatPredicates::print(std::ostream& os) const {
+  os << "addsuchthatpredicates(" << util::join(getPredicates()) << ")";
+}
+
+std::ostream& operator<<(std::ostream& os, const AddSuchThatPredicates& addSuchThatPredicates) {
+  addSuchThatPredicates.print(os);
+  return os;
+}
+
+struct ReplaceReductionExpr : public IndexNotationRewriter {
+  const std::map<Access, Access>& substitutions;
+  ReplaceReductionExpr(const std::map<Access, Access>& substitutions)
+      : substitutions(substitutions) {}
+  using IndexNotationRewriter::visit;
+  void visit(const AssignmentNode* node) {
+    if (util::contains(substitutions, node->lhs)) {
+      stmt = Assignment(substitutions.at(node->lhs), rewrite(node->rhs), node->op);
+    }
+    else {
+      IndexNotationRewriter::visit(node);
+    }
+  }
+};
+
+IndexStmt scalarPromote(IndexStmt stmt, ProvenanceGraph provGraph,
+                        bool isWholeStmt, bool promoteScalar);
 
 
 // class Parallelize
-  struct Parallelize::Content {
-    IndexVar i;
-    ParallelUnit parallel_unit;
-    OutputRaceStrategy output_race_strategy;
-  };
-
-  Parallelize::Parallelize() : content(nullptr) {
-  }
-
-  Parallelize::Parallelize(IndexVar i) : Parallelize(i, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces) {}
-
-  Parallelize::Parallelize(IndexVar i, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy) : content(new Content) {
-    content->i = i;
-    content->parallel_unit = parallel_unit;
-    content->output_race_strategy = output_race_strategy;
-  }
-
-
-  IndexVar Parallelize::geti() const {
-    return content->i;
-  }
-
-  ParallelUnit Parallelize::getParallelUnit() const {
-    return content->parallel_unit;
-  }
-
-  OutputRaceStrategy Parallelize::getOutputRaceStrategy() const {
-    return content->output_race_strategy;
-  }
-
-  IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
-    INIT_REASON(reason);
-
-    struct ParallelizeRewriter : public IndexNotationRewriter {
-      using IndexNotationRewriter::visit;
-
-      Parallelize parallelize;
-      ProvenanceGraph provGraph;
-      map<TensorVar, ir::Expr> tensorVars;
-      vector<ir::Expr> assembledByUngroupedInsert;
-      set<IndexVar> definedIndexVars;
-      set<IndexVar> reductionIndexVars;
-      set<ParallelUnit> parentParallelUnits;
-      std::string reason = "";
-
-      IndexStmt rewriteParallel(IndexStmt stmt) {
-        provGraph = ProvenanceGraph(stmt);
-
-        const auto reductionVars = getReductionVars(stmt);
-
-        reductionIndexVars.clear();
-        for (const auto& iv : stmt.getIndexVars()) {
-          if (util::contains(reductionVars, iv)) {
-            for (const auto& rv : provGraph.getFullyDerivedDescendants(iv)) {
-              reductionIndexVars.insert(rv);
-            }
-          }
-        }
-
-        tensorVars = createIRTensorVars(stmt);
-
-        assembledByUngroupedInsert.clear();
-        for (const auto& result : getAssembledByUngroupedInsertion(stmt)) {
-          assembledByUngroupedInsert.push_back(tensorVars[result]);
-        }
-
-        return rewrite(stmt);
-      }
-
-      void visit(const ForallNode* node) {
-        Forall foralli(node);
-        IndexVar i = parallelize.geti();
-
-        definedIndexVars.insert(foralli.getIndexVar());
-
-        if (foralli.getIndexVar() == i) {
-          // Precondition 1: No parallelization of reduction variables
-          if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::NoRaces &&
-              util::contains(reductionIndexVars, i)) {
-            reason = "Precondition failed: Cannot parallelize reduction loops "
-                     "without synchronization";
-            return;
-          }
-
-          Iterators iterators(foralli, tensorVars);
-          MergeLattice lattice = MergeLattice::make(foralli, iterators, provGraph,
-                                                    definedIndexVars);
-
-          // Precondition 2: No coiteration of modes (i.e., merge lattice has
-          // only one iterator)
-          if (lattice.iterators().size() != 1) {
-            reason = "Precondition failed: The loop must not merge tensor "
-                     "dimensions, that is, it must be a for loop;";
-            return;
-          }
-
-          vector<IndexVar> underivedAncestors = provGraph.getUnderivedAncestors(i);
-          IndexVar underivedAncestor = underivedAncestors.back();
-
-          // Get lattice that corresponds to underived ancestor.
This is - // bottom-most loop that shares underived ancestor - Forall underivedForall = foralli; - match(foralli.getStmt(), - function([&](const ForallNode* node) { - const auto nodeUnderivedAncestors = - provGraph.getUnderivedAncestors(node->indexVar); - definedIndexVars.insert(node->indexVar); - if (underivedAncestor == nodeUnderivedAncestors.back()) { - underivedForall = Forall(node); - } - }) - ); - MergeLattice underivedLattice = MergeLattice::make(underivedForall, - iterators, provGraph, - definedIndexVars); - - // Precondition 3: Every result iterator must have insert capability - for (Iterator iterator : underivedLattice.results()) { - if (util::contains(assembledByUngroupedInsert, iterator.getTensor())) { - for (Iterator it = iterator; !it.isRoot(); it = it.getParent()) { - if (it.hasInsertCoord() || !it.isYieldPosPure()) { - reason = "Precondition failed: The output tensor does not " - "support parallelized inserts"; - return; - } - } - } else { - while (true) { - if (!iterator.hasInsert()) { - reason = "Precondition failed: The output tensor must support " - "inserts"; - return; - } - if (iterator.isLeaf()) { - break; - } - iterator = iterator.getChild(); - } - } - } +OutputRaceStrategy Parallelize::getOutputRaceStrategy() const { + return content->output_race_strategy; +} - if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::Temporary && - util::contains(reductionIndexVars, underivedForall.getIndexVar())) { - // Need to precompute reduction - - // Find all occurrences of reduction in expression - vector precomputeAssignments; - match(foralli.getStmt(), - function([&](const AssignmentNode* node) { - for (auto underivedVar : underivedAncestors) { - vector reductionVars = Assignment(node).getReductionVars(); - bool reducedByI = - find(reductionVars.begin(), reductionVars.end(), underivedVar) != reductionVars.end(); - if (reducedByI) { - precomputeAssignments.push_back(node); - break; - } - } - }) - ); - taco_iassert(!precomputeAssignments.empty()); - - IndexStmt precomputed_stmt = forall(i, foralli.getStmt(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); - for (auto assignment : precomputeAssignments) { - // Construct temporary of correct type and size of outer loop - TensorVar w(string("w_") + ParallelUnit_NAMES[(int) parallelize.getParallelUnit()], Type(assignment->lhs.getDataType(), {Dimension(i)}), taco::dense); - - // rewrite producer to write to temporary, mark producer as parallel - IndexStmt producer = ReplaceReductionExpr(map({{assignment->lhs, w(i)}})).rewrite(precomputed_stmt); - taco_iassert(isa(producer)); - Forall producer_forall = to(producer); - producer = forall(producer_forall.getIndexVar(), producer_forall.getStmt(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); - - // build consumer that writes from temporary to output, mark consumer as parallel reduction - ParallelUnit reductionUnit = ParallelUnit::CPUThreadGroupReduction; - if (should_use_CUDA_codegen()) { - if (parentParallelUnits.count(ParallelUnit::GPUWarp)) { - reductionUnit = ParallelUnit::GPUWarpReduction; - } - else { - reductionUnit = ParallelUnit::GPUBlockReduction; - } - } - IndexStmt consumer = forall(i, Assignment(assignment->lhs, w(i), assignment->op), reductionUnit, OutputRaceStrategy::ParallelReduction); - precomputed_stmt = where(consumer, producer); - } - stmt = precomputed_stmt; - return; - } +IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const { + 
INIT_REASON(reason); - if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { - // want to avoid extra atomics by accumulating variable and then - // reducing at end - IndexStmt body = scalarPromote(foralli.getStmt(), provGraph, - false, true); - stmt = forall(i, body, parallelize.getParallelUnit(), - parallelize.getOutputRaceStrategy(), - foralli.getUnrollFactor()); - return; - } + struct ParallelizeRewriter : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + Parallelize parallelize; + ProvenanceGraph provGraph; + map tensorVars; + vector assembledByUngroupedInsert; + set definedIndexVars; + set reductionIndexVars; + set parentParallelUnits; + std::string reason = ""; - stmt = forall(i, foralli.getStmt(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); - return; - } + IndexStmt rewriteParallel(IndexStmt stmt) { + provGraph = ProvenanceGraph(stmt); - if (foralli.getParallelUnit() != ParallelUnit::NotParallel) { - parentParallelUnits.insert(foralli.getParallelUnit()); - } - IndexNotationRewriter::visit(node); - } - }; + const auto reductionVars = getReductionVars(stmt); - ParallelizeRewriter rewriter; - rewriter.parallelize = *this; - IndexStmt rewritten = rewriter.rewriteParallel(stmt); - if (!rewriter.reason.empty()) { - *reason = rewriter.reason; - return IndexStmt(); + reductionIndexVars.clear(); + for (const auto& iv : stmt.getIndexVars()) { + if (util::contains(reductionVars, iv)) { + for (const auto& rv : provGraph.getFullyDerivedDescendants(iv)) { + reductionIndexVars.insert(rv); + } } - return rewritten; - } - - - void Parallelize::print(std::ostream& os) const { - os << "parallelize(" << geti() << ")"; - } - - - std::ostream& operator<<(std::ostream& os, const Parallelize& parallelize) { - parallelize.print(os); - return os; - } + } + tensorVars = createIRTensorVars(stmt); -// class SetAssembleStrategy + assembledByUngroupedInsert.clear(); + for (const auto& result : getAssembledByUngroupedInsertion(stmt)) { + assembledByUngroupedInsert.push_back(tensorVars[result]); + } - struct SetAssembleStrategy::Content { - TensorVar result; - AssembleStrategy strategy; - bool separatelySchedulable; - }; - - SetAssembleStrategy::SetAssembleStrategy(TensorVar result, - AssembleStrategy strategy, - bool separatelySchedulable) : - content(new Content) { - content->result = result; - content->strategy = strategy; - content->separatelySchedulable = separatelySchedulable; + return rewrite(stmt); } - TensorVar SetAssembleStrategy::getResult() const { - return content->result; - } + void visit(const ForallNode* node) { + Forall foralli(node); + IndexVar i = parallelize.geti(); - AssembleStrategy SetAssembleStrategy::getAssembleStrategy() const { - return content->strategy; - } + definedIndexVars.insert(foralli.getIndexVar()); - bool SetAssembleStrategy::getSeparatelySchedulable() const { - return content->separatelySchedulable; - } + if (foralli.getIndexVar() == i) { + // Precondition 1: No parallelization of reduction variables + if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::NoRaces && + util::contains(reductionIndexVars, i)) { + reason = "Precondition failed: Cannot parallelize reduction loops " + "without synchronization"; + return; + } - IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const { - INIT_REASON(reason); + Iterators iterators(foralli, tensorVars); + MergeLattice lattice = MergeLattice::make(foralli, iterators, provGraph, + definedIndexVars); - if 
(getAssembleStrategy() == AssembleStrategy::Append) { - return stmt; + // Precondition 2: No coiteration of modes (i.e., merge lattice has + // only one iterator) + if (lattice.iterators().size() != 1) { + reason = "Precondition failed: The loop must not merge tensor " + "dimensions, that is, it must be a for loop;"; + return; } - bool hasSeqInsertEdge = false; - bool hasInsertCoord = false; - bool hasNonpureYieldPos = false; - for (const auto& modeFormat : getResult().getFormat().getModeFormats()) { - if (hasSeqInsertEdge) { - if (modeFormat.hasSeqInsertEdge()) { - *reason = "Precondition failed: The output tensor does not support " - "ungrouped insertion (cannot have multiple modes requiring " - "non-trivial edge insertion)"; - return IndexStmt(); - } - } else { - hasSeqInsertEdge = (hasSeqInsertEdge || modeFormat.hasSeqInsertEdge()); - if (modeFormat.hasSeqInsertEdge()) { - if (hasInsertCoord) { - *reason = "Precondition failed: The output tensor does not support " - "ungrouped insertion (cannot have mode requiring " - "non-trivial coordinate insertion above mode requiring " - "non-trivial edge insertion)"; - return IndexStmt(); - } - hasSeqInsertEdge = true; + vector underivedAncestors = provGraph.getUnderivedAncestors(i); + IndexVar underivedAncestor = underivedAncestors.back(); + + // Get the lattice that corresponds to the underived ancestor. This is + // the bottom-most loop that shares the underived ancestor + Forall underivedForall = foralli; + match(foralli.getStmt(), + function([&](const ForallNode* node) { + const auto nodeUnderivedAncestors = + provGraph.getUnderivedAncestors(node->indexVar); + definedIndexVars.insert(node->indexVar); + if (underivedAncestor == nodeUnderivedAncestors.back()) { + underivedForall = Forall(node); } - hasInsertCoord = (hasInsertCoord || modeFormat.hasInsertCoord()); + }) + ); + MergeLattice underivedLattice = MergeLattice::make(underivedForall, + iterators, provGraph, + definedIndexVars); + + // Precondition 3: Every result iterator must have insert capability + for (Iterator iterator : underivedLattice.results()) { + if (util::contains(assembledByUngroupedInsert, iterator.getTensor())) { + for (Iterator it = iterator; !it.isRoot(); it = it.getParent()) { + if (it.hasInsertCoord() || !it.isYieldPosPure()) { + reason = "Precondition failed: The output tensor does not " + "support parallelized inserts"; + return; + } } - if (hasNonpureYieldPos && !modeFormat.isBranchless()) { - *reason = "Precondition failed: The output tensor does not support " - "ungrouped insertion (a mode that has a non-pure " - "implementation of yield_pos cannot be followed by a " - "non-branchless mode)"; - return IndexStmt(); - } else if (!modeFormat.isYieldPosPure()) { - hasNonpureYieldPos = true; + } else { + while (true) { + if (!iterator.hasInsert()) { + reason = "Precondition failed: The output tensor must support " + "inserts"; + return; + } + if (iterator.isLeaf()) { + break; + } + iterator = iterator.getChild(); } + } } - IndexStmt loweredQueries = stmt; - - // If attribute query computation should be independently schedulable, then - // need to use fresh index variables - if (getSeparatelySchedulable()) { - std::map ivReplacements; - for (const auto& indexVar : getIndexVars(stmt)) { - ivReplacements[indexVar] = IndexVar("q" + indexVar.getName()); + if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::Temporary && + util::contains(reductionIndexVars, underivedForall.getIndexVar())) { + // Need to precompute reduction + + // Find all occurrences of reduction in expression
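For intuition about the hunk that follows: with OutputRaceStrategy::Temporary, a parallelized reduction is split into a race-free producer and a reducing consumer joined by a where. A rough sketch of the shape produced, with an invented input expression (w stands for the dense w_<unit> TensorVar the code constructs over Dimension(i)):

// Conceptual input:   forall(i, a += B(i) * c(i))    -- 'a' is reduced over i
// After the rewrite:
//   where(forall(i, a += w(i)),             // consumer: thread-group/warp/block reduction
//         forall(i, w(i) += B(i) * c(i)))   // producer: each iteration owns its own w(i)
// ReplaceReductionExpr redirects the reduction from 'a' into w(i), so the
// producer keeps the original reduction operator but no longer races.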
+ vector precomputeAssignments; + match(foralli.getStmt(), + function([&](const AssignmentNode* node) { + for (auto underivedVar : underivedAncestors) { + vector reductionVars = Assignment(node).getReductionVars(); + bool reducedByI = + find(reductionVars.begin(), reductionVars.end(), underivedVar) != reductionVars.end(); + if (reducedByI) { + precomputeAssignments.push_back(node); + break; + } + } + }) + ); + taco_iassert(!precomputeAssignments.empty()); + + IndexStmt precomputed_stmt = forall(i, foralli.getStmt(), foralli.getMergeStrategy(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); + for (auto assignment : precomputeAssignments) { + // Construct temporary of correct type and size of outer loop + TensorVar w(string("w_") + ParallelUnit_NAMES[(int) parallelize.getParallelUnit()], Type(assignment->lhs.getDataType(), {Dimension(i)}), taco::dense); + + // rewrite producer to write to temporary, mark producer as parallel + IndexStmt producer = ReplaceReductionExpr(map({{assignment->lhs, w(i)}})).rewrite(precomputed_stmt); + taco_iassert(isa(producer)); + Forall producer_forall = to(producer); + producer = forall(producer_forall.getIndexVar(), producer_forall.getStmt(), foralli.getMergeStrategy(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); + + // build consumer that writes from temporary to output, mark consumer as parallel reduction + ParallelUnit reductionUnit = ParallelUnit::CPUThreadGroupReduction; + if (should_use_CUDA_codegen()) { + if (parentParallelUnits.count(ParallelUnit::GPUWarp)) { + reductionUnit = ParallelUnit::GPUWarpReduction; + } + else { + reductionUnit = ParallelUnit::GPUBlockReduction; + } } - loweredQueries = replace(loweredQueries, ivReplacements); + IndexStmt consumer = forall(i, Assignment(assignment->lhs, w(i), assignment->op), foralli.getMergeStrategy(), reductionUnit, OutputRaceStrategy::ParallelReduction); + precomputed_stmt = where(consumer, producer); + } + stmt = precomputed_stmt; + return; } - // FIXME: Unneeded if scalar promotion is made default when concretizing - loweredQueries = scalarPromote(loweredQueries); - - // Tracks all tensors that correspond to attribute query results or that are - // used to compute attribute queries - std::set insertedResults; - - // Lower attribute queries to canonical forms using same schedule as - // actual computation - Assemble::AttrQueryResults queryResults; - struct LowerAttrQuery : public IndexNotationRewriter { - using IndexNotationRewriter::visit; - - TensorVar result; - Assemble::AttrQueryResults& queryResults; - std::set& insertedResults; - std::vector arguments; - std::vector temps; - std::map tempReplacements; - IndexStmt epilog; - std::string reason = ""; - - LowerAttrQuery(TensorVar result, Assemble::AttrQueryResults& queryResults, - std::set& insertedResults) : - result(result), queryResults(queryResults), - insertedResults(insertedResults) {} - - IndexStmt lower(IndexStmt stmt) { - arguments = getArguments(stmt); - temps = getTemporaries(stmt); - for (const auto& tmp : temps) { - tempReplacements[tmp] = TensorVar("q" + tmp.getName(), - Type(Bool, tmp.getType().getShape()), - tmp.getFormat()); - } + if (parallelize.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { + // want to avoid extra atomics by accumulating variable and then + // reducing at end + IndexStmt body = scalarPromote(foralli.getStmt(), provGraph, + false, true); + stmt = forall(i, body, foralli.getMergeStrategy(), 
parallelize.getParallelUnit(), + parallelize.getOutputRaceStrategy(), + foralli.getUnrollFactor()); + return; + } - queryResults = Assemble::AttrQueryResults(); - epilog = IndexStmt(); - stmt = IndexNotationRewriter::rewrite(stmt); - if (epilog.defined()) { - stmt = Where(epilog, stmt); - } - return stmt; - } - void visit(const ForallNode* op) { - IndexStmt s = rewrite(op->stmt); - if (s == op->stmt) { - stmt = op; - } else if (s.defined()) { - stmt = Forall(op->indexVar, s, op->parallel_unit, - op->output_race_strategy, op->unrollFactor); - } else { - stmt = IndexStmt(); - } - } + stmt = forall(i, foralli.getStmt(), foralli.getMergeStrategy(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); + return; + } - void visit(const WhereNode* op) { - IndexStmt producer = rewrite(op->producer); - IndexStmt consumer = rewrite(op->consumer); - if (producer == op->producer && consumer == op->consumer) { - stmt = op; - } else if (consumer.defined()) { - stmt = producer.defined() ? Where(consumer, producer) : consumer; - } else { - stmt = IndexStmt(); - } - } + if (foralli.getParallelUnit() != ParallelUnit::NotParallel) { + parentParallelUnits.insert(foralli.getParallelUnit()); + } + IndexNotationRewriter::visit(node); + } + }; - void visit(const AssignmentNode* op) { - IndexExpr rhs = rewrite(op->rhs); + ParallelizeRewriter rewriter; + rewriter.parallelize = *this; + IndexStmt rewritten = rewriter.rewriteParallel(stmt); + if (!rewriter.reason.empty()) { + *reason = rewriter.reason; + return IndexStmt(); + } + return rewritten; +} - const auto resultAccess = op->lhs; - const auto resultTensor = resultAccess.getTensorVar(); - if (resultTensor != result) { - // TODO: Should check that annihilator of original reduction op equals - // fill value of original result - Access lhs = to(rewrite(op->lhs)); - IndexExpr reduceOp = op->op.defined() ? Add() : IndexExpr(); - stmt = (rhs != op->rhs) ? Assignment(lhs, rhs, reduceOp) : op; - return; - } - /// GENGHAN: Why do we need defined? 
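The recurring edit in these hunks is mechanical but easy to omit in future rewriters: every site that reconstructs a loop must now forward the loop's merge strategy. A minimal sketch of the pattern (newBody and the enclosing rewriter are placeholders; the field and accessor names are the ones used in this file):

// Any IndexNotationRewriter that rebuilds a Forall must copy merge_strategy
// along with the other loop annotations:
//   stmt = Forall(op->indexVar, newBody, op->merge_strategy,
//                 op->parallel_unit, op->output_race_strategy,
//                 op->unrollFactor);
// Dropping the new argument would silently reset a loop scheduled with the
// gallop strategy back to the default two-finger merge.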
- if (op->op.defined()) { - reason = "Precondition failed: Ungrouped insertion not support for " - "output tensors that are scattered into"; - return; - } +void Parallelize::print(std::ostream& os) const { + os << "parallelize(" << geti() << ")"; +} - queryResults[resultTensor] = - std::vector>(resultTensor.getOrder()); - const auto indices = resultAccess.getIndexVars(); - const auto modeFormats = resultTensor.getFormat().getModeFormats(); - const auto modeOrdering = resultTensor.getFormat().getModeOrdering(); +std::ostream& operator<<(std::ostream& os, const Parallelize& parallelize) { + parallelize.print(os); + return os; +} - std::vector parentCoords; - std::vector childCoords; - for (size_t i = 0; i < indices.size(); ++i) { - childCoords.push_back(indices[modeOrdering[i]]); - } - for (size_t i = 0; i < indices.size(); ++i) { - const auto modeName = resultTensor.getName() + std::to_string(i + 1); - - parentCoords.push_back(indices[modeOrdering[i]]); - childCoords.erase(childCoords.begin()); - - for (const auto& query: - modeFormats[i].getAttrQueries(parentCoords, childCoords)) { - const auto& groupBy = query.getGroupBy(); - - // TODO: support multiple aggregations in single query - taco_iassert(query.getAttrs().size() == 1); - - std::vector queryDims; - for (const auto& coord : groupBy) { - const auto pos = std::find(groupBy.begin(), groupBy.end(), coord) - - groupBy.begin(); - const auto dim = resultTensor.getType().getShape().getDimension(pos); - queryDims.push_back(dim); - } - - for (const auto& attr : query.getAttrs()) { - switch (attr.aggr) { - case AttrQuery::COUNT: - { - std::vector dedupCoords = groupBy; - dedupCoords.insert(dedupCoords.end(), attr.params.begin(), - attr.params.end()); - std::vector dedupDims(dedupCoords.size()); - TensorVar dedupTmp(modeName + "_dedup", Type(Bool, dedupDims)); - stmt = Assignment(dedupTmp(dedupCoords), rhs, Add()); - insertedResults.insert(dedupTmp); - - const auto resultName = modeName + "_" + attr.label; - TensorVar queryResult(resultName, Type(Int32, queryDims)); - epilog = Assignment(queryResult(groupBy), - Cast(dedupTmp(dedupCoords), Int()), Add()); - for (const auto& coord : util::reverse(dedupCoords)) { - epilog = forall(coord, epilog); - } - insertedResults.insert(queryResult); - - queryResults[resultTensor][i] = {queryResult}; - return; - } - case AttrQuery::IDENTITY: - case AttrQuery::MIN: - case AttrQuery::MAX: - default: - taco_not_supported_yet; - break; - } - } - } - } +// class SetAssembleStrategy - stmt = IndexStmt(); - } +struct SetAssembleStrategy::Content { + TensorVar result; + AssembleStrategy strategy; + bool separatelySchedulable; +}; - void visit(const AccessNode* op) { - if (util::contains(arguments, op->tensorVar)) { - expr = Access(op->tensorVar, op->indexVars, op->packageModifiers(), - true); - return; - } else if (util::contains(temps, op->tensorVar)) { - expr = Access(tempReplacements[op->tensorVar], op->indexVars, - op->packageModifiers()); - return; - } +SetAssembleStrategy::SetAssembleStrategy(TensorVar result, + AssembleStrategy strategy, + bool separatelySchedulable) : + content(new Content) { + content->result = result; + content->strategy = strategy; + content->separatelySchedulable = separatelySchedulable; +} - expr = op; - } +TensorVar SetAssembleStrategy::getResult() const { + return content->result; +} - void visit(const CallNode* op) { - std::vector args; - bool rewritten = false; - for(auto& arg : op->args) { - IndexExpr rewrittenArg = rewrite(arg); - args.push_back(rewrittenArg); - if (arg != 
rewrittenArg) { - rewritten = true; - } - } +AssembleStrategy SetAssembleStrategy::getAssembleStrategy() const { + return content->strategy; +} - if (rewritten) { - const std::map subs = util::zipToMap(op->args, args); - IterationAlgebra newAlg = replaceAlgIndexExprs(op->iterAlg, subs); - - struct InferSymbolic : public IterationAlgebraVisitorStrict { - IndexExpr ret; - - IndexExpr infer(IterationAlgebra alg) { - ret = IndexExpr(); - alg.accept(this); - return ret; - } - virtual void visit(const RegionNode* op) { - ret = op->expr(); - } - - virtual void visit(const ComplementNode* op) { - taco_not_supported_yet; - } - - virtual void visit(const IntersectNode* op) { - IndexExpr lhs = infer(op->a); - IndexExpr rhs = infer(op->b); - ret = lhs * rhs; - } - - virtual void visit(const UnionNode* op) { - IndexExpr lhs = infer(op->a); - IndexExpr rhs = infer(op->b); - ret = lhs + rhs; - } - }; - expr = InferSymbolic().infer(newAlg); - } - else { - expr = op; - } - } - }; - LowerAttrQuery queryLowerer(getResult(), queryResults, insertedResults); - loweredQueries = queryLowerer.lower(loweredQueries); +bool SetAssembleStrategy::getSeparatelySchedulable() const { + return content->separatelySchedulable; +} - if (!queryLowerer.reason.empty()) { - *reason = queryLowerer.reason; - return IndexStmt(); +IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const { + INIT_REASON(reason); + + if (getAssembleStrategy() == AssembleStrategy::Append) { + return stmt; + } + + bool hasSeqInsertEdge = false; + bool hasInsertCoord = false; + bool hasNonpureYieldPos = false; + for (const auto& modeFormat : getResult().getFormat().getModeFormats()) { + if (hasSeqInsertEdge) { + if (modeFormat.hasSeqInsertEdge()) { + *reason = "Precondition failed: The output tensor does not support " + "ungrouped insertion (cannot have multiple modes requiring " + "non-trivial edge insertion)"; + return IndexStmt(); + } + } else { + hasSeqInsertEdge = (hasSeqInsertEdge || modeFormat.hasSeqInsertEdge()); + if (modeFormat.hasSeqInsertEdge()) { + if (hasInsertCoord) { + *reason = "Precondition failed: The output tensor does not support " + "ungrouped insertion (cannot have mode requiring " + "non-trivial coordinate insertion above mode requiring " + "non-trivial edge insertion)"; + return IndexStmt(); } - - // Convert redundant reductions to assignments - loweredQueries = eliminateRedundantReductions(loweredQueries, - &insertedResults); - - // Inline definitions of temporaries into their corresponding uses, as long - // as the temporaries are not the results of reductions - std::set inlinedResults; - struct InlineTemporaries : public IndexNotationRewriter { - using IndexNotationRewriter::visit; - - const std::set& insertedResults; - std::set& inlinedResults; - std::map> tmpUse; - - InlineTemporaries(const std::set& insertedResults, - std::set& inlinedResults) : - insertedResults(insertedResults), inlinedResults(inlinedResults) {} - - void visit(const WhereNode* op) { - IndexStmt consumer = rewrite(op->consumer); - IndexStmt producer = rewrite(op->producer); - if (producer == op->producer && consumer == op->consumer) { - stmt = op; - } else { - stmt = new WhereNode(consumer, producer); - } + hasSeqInsertEdge = true; + } + hasInsertCoord = (hasInsertCoord || modeFormat.hasInsertCoord()); + } + if (hasNonpureYieldPos && !modeFormat.isBranchless()) { + *reason = "Precondition failed: The output tensor does not support " + "ungrouped insertion (a mode that has a non-pure " + "implementation of yield_pos cannot be followed by a 
" + "non-branchless mode)"; + return IndexStmt(); + } else if (!modeFormat.isYieldPosPure()) { + hasNonpureYieldPos = true; + } + } + + IndexStmt loweredQueries = stmt; + + // If attribute query computation should be independently schedulable, then + // need to use fresh index variables + if (getSeparatelySchedulable()) { + std::map ivReplacements; + for (const auto& indexVar : getIndexVars(stmt)) { + ivReplacements[indexVar] = IndexVar("q" + indexVar.getName()); + } + loweredQueries = replace(loweredQueries, ivReplacements); + } + + // FIXME: Unneeded if scalar promotion is made default when concretizing + loweredQueries = scalarPromote(loweredQueries); + + // Tracks all tensors that correspond to attribute query results or that are + // used to compute attribute queries + std::set insertedResults; + + // Lower attribute queries to canonical forms using same schedule as + // actual computation + Assemble::AttrQueryResults queryResults; + struct LowerAttrQuery : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + + TensorVar result; + Assemble::AttrQueryResults& queryResults; + std::set& insertedResults; + std::vector arguments; + std::vector temps; + std::map tempReplacements; + IndexStmt epilog; + std::string reason = ""; + + LowerAttrQuery(TensorVar result, Assemble::AttrQueryResults& queryResults, + std::set& insertedResults) : + result(result), queryResults(queryResults), + insertedResults(insertedResults) {} + + IndexStmt lower(IndexStmt stmt) { + arguments = getArguments(stmt); + temps = getTemporaries(stmt); + for (const auto& tmp : temps) { + tempReplacements[tmp] = TensorVar("q" + tmp.getName(), + Type(Bool, tmp.getType().getShape()), + tmp.getFormat()); + } + + queryResults = Assemble::AttrQueryResults(); + epilog = IndexStmt(); + stmt = IndexNotationRewriter::rewrite(stmt); + if (epilog.defined()) { + stmt = Where(epilog, stmt); + } + return stmt; + } + + void visit(const ForallNode* op) { + IndexStmt s = rewrite(op->stmt); + if (s == op->stmt) { + stmt = op; + } else if (s.defined()) { + stmt = Forall(op->indexVar, s, op->merge_strategy, op->parallel_unit, + op->output_race_strategy, op->unrollFactor); + } else { + stmt = IndexStmt(); + } + } + + void visit(const WhereNode* op) { + IndexStmt producer = rewrite(op->producer); + IndexStmt consumer = rewrite(op->consumer); + if (producer == op->producer && consumer == op->consumer) { + stmt = op; + } else if (consumer.defined()) { + stmt = producer.defined() ? Where(consumer, producer) : consumer; + } else { + stmt = IndexStmt(); + } + } + + void visit(const AssignmentNode* op) { + IndexExpr rhs = rewrite(op->rhs); + + const auto resultAccess = op->lhs; + const auto resultTensor = resultAccess.getTensorVar(); + + if (resultTensor != result) { + // TODO: Should check that annihilator of original reduction op equals + // fill value of original result + Access lhs = to(rewrite(op->lhs)); + IndexExpr reduceOp = op->op.defined() ? Add() : IndexExpr(); + stmt = (rhs != op->rhs) ? 
Assignment(lhs, rhs, reduceOp) : op; + return; + } + + if (op->op.defined()) { + reason = "Precondition failed: Ungrouped insertion not supported for " + "output tensors that are scattered into"; + return; + } + + queryResults[resultTensor] = + std::vector>(resultTensor.getOrder()); + + const auto indices = resultAccess.getIndexVars(); + const auto modeFormats = resultTensor.getFormat().getModeFormats(); + const auto modeOrdering = resultTensor.getFormat().getModeOrdering(); + + std::vector parentCoords; + std::vector childCoords; + for (size_t i = 0; i < indices.size(); ++i) { + childCoords.push_back(indices[modeOrdering[i]]); + } + + for (size_t i = 0; i < indices.size(); ++i) { + const auto modeName = resultTensor.getName() + std::to_string(i + 1); + + parentCoords.push_back(indices[modeOrdering[i]]); + childCoords.erase(childCoords.begin()); + + for (const auto& query: + modeFormats[i].getAttrQueries(parentCoords, childCoords)) { + const auto& groupBy = query.getGroupBy(); + + // TODO: support multiple aggregations in single query + taco_iassert(query.getAttrs().size() == 1); + + std::vector queryDims; + for (const auto& coord : groupBy) { + const auto pos = std::find(groupBy.begin(), groupBy.end(), coord) + - groupBy.begin(); + const auto dim = resultTensor.getType().getShape().getDimension(pos); + queryDims.push_back(dim); + } + + for (const auto& attr : query.getAttrs()) { + switch (attr.aggr) { + case AttrQuery::COUNT: + { + std::vector dedupCoords = groupBy; + dedupCoords.insert(dedupCoords.end(), attr.params.begin(), + attr.params.end()); + std::vector dedupDims(dedupCoords.size()); + TensorVar dedupTmp(modeName + "_dedup", Type(Bool, dedupDims)); + stmt = Assignment(dedupTmp(dedupCoords), rhs, Add()); + insertedResults.insert(dedupTmp); + + const auto resultName = modeName + "_" + attr.label; + TensorVar queryResult(resultName, Type(Int32, queryDims)); + epilog = Assignment(queryResult(groupBy), + Cast(dedupTmp(dedupCoords), Int()), Add()); + for (const auto& coord : util::reverse(dedupCoords)) { + epilog = forall(coord, epilog); + } + insertedResults.insert(queryResult); + + queryResults[resultTensor][i] = {queryResult}; + return; + } + case AttrQuery::IDENTITY: + case AttrQuery::MIN: + case AttrQuery::MAX: + default: + taco_not_supported_yet; + break; } + } + } + } - void visit(const AssignmentNode* op) { - const auto lhsTensor = op->lhs.getTensorVar(); - if (util::contains(tmpUse, lhsTensor) && !op->op.defined()) { - std::map indexMap; - const auto& oldIndices = - to(tmpUse[lhsTensor].first).getIndexVars(); - const auto& newIndices = op->lhs.getIndexVars(); - for (const auto& mapping : util::zip(oldIndices, newIndices)) { - indexMap[mapping.first] = mapping.second; - } + stmt = IndexStmt(); + } - std::vector newCoords; - const auto& oldCoords = - tmpUse[lhsTensor].second.getLhs().getIndexVars(); - for (const auto& oldCoord : oldCoords) { - newCoords.push_back(indexMap.at(oldCoord)); - } + void visit(const AccessNode* op) { + if (util::contains(arguments, op->tensorVar)) { + expr = Access(op->tensorVar, op->indexVars, op->packageModifiers(), + true); + return; + } else if (util::contains(temps, op->tensorVar)) { + expr = Access(tempReplacements[op->tensorVar], op->indexVars, + op->packageModifiers()); + return; + } - IndexExpr reduceOp = tmpUse[lhsTensor].second.getOperator(); - TensorVar queryResult = - tmpUse[lhsTensor].second.getLhs().getTensorVar(); - IndexExpr rhs = op->rhs; - if (rhs.getDataType() != queryResult.getType().getDataType()) { - rhs = Cast(rhs,
queryResult.getType().getDataType()); - } - stmt = Assignment(queryResult(newCoords), rhs, reduceOp); - inlinedResults.insert(queryResult); - return; - } + expr = op; + } - const Access rhsAccess = isa(op->rhs) ? to(op->rhs) - : (isa(op->rhs) && isa(to(op->rhs).getA())) - ? to(to(op->rhs).getA()) : Access(); - if (rhsAccess.defined()) { - const auto rhsTensor = rhsAccess.getTensorVar(); - if (util::contains(insertedResults, rhsTensor)) { - tmpUse[rhsTensor] = std::make_pair(rhsAccess, Assignment(op)); - } - } - stmt = op; - } + void visit(const CallNode* op) { + std::vector args; + bool rewritten = false; + for(auto& arg : op->args) { + IndexExpr rewrittenArg = rewrite(arg); + args.push_back(rewrittenArg); + if (arg != rewrittenArg) { + rewritten = true; + } + } + + if (rewritten) { + const std::map subs = util::zipToMap(op->args, args); + IterationAlgebra newAlg = replaceAlgIndexExprs(op->iterAlg, subs); + + struct InferSymbolic : public IterationAlgebraVisitorStrict { + IndexExpr ret; + + IndexExpr infer(IterationAlgebra alg) { + ret = IndexExpr(); + alg.accept(this); + return ret; + } + virtual void visit(const RegionNode* op) { + ret = op->expr(); + } + + virtual void visit(const ComplementNode* op) { + taco_not_supported_yet; + } + + virtual void visit(const IntersectNode* op) { + IndexExpr lhs = infer(op->a); + IndexExpr rhs = infer(op->b); + ret = lhs * rhs; + } + + virtual void visit(const UnionNode* op) { + IndexExpr lhs = infer(op->a); + IndexExpr rhs = infer(op->b); + ret = lhs + rhs; + } }; - loweredQueries = InlineTemporaries(insertedResults, - inlinedResults).rewrite(loweredQueries); - - // Eliminate computation of redundant temporaries - struct EliminateRedundantTemps : public IndexNotationRewriter { - using IndexNotationRewriter::visit; - - const std::set& inlinedResults; - - EliminateRedundantTemps(const std::set& inlinedResults) : - inlinedResults(inlinedResults) {} - - void visit(const ForallNode* op) { - IndexStmt s = rewrite(op->stmt); - if (s == op->stmt) { - stmt = op; - } else if (s.defined()) { - stmt = new ForallNode(op->indexVar, s, op->parallel_unit, - op->output_race_strategy, op->unrollFactor); - } else { - stmt = IndexStmt(); - } - } - - void visit(const WhereNode* op) { - IndexStmt consumer = rewrite(op->consumer); - if (consumer == op->consumer) { - stmt = op; - } else if (consumer.defined()) { - stmt = new WhereNode(consumer, op->producer); - } else { - stmt = op->producer; - } - } + expr = InferSymbolic().infer(newAlg); + } + else { + expr = op; + } + } + }; + LowerAttrQuery queryLowerer(getResult(), queryResults, insertedResults); + loweredQueries = queryLowerer.lower(loweredQueries); + + if (!queryLowerer.reason.empty()) { + *reason = queryLowerer.reason; + return IndexStmt(); + } + + // Convert redundant reductions to assignments + loweredQueries = eliminateRedundantReductions(loweredQueries, + &insertedResults); + + // Inline definitions of temporaries into their corresponding uses, as long + // as the temporaries are not the results of reductions + std::set inlinedResults; + struct InlineTemporaries : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + + const std::set& insertedResults; + std::set& inlinedResults; + std::map> tmpUse; + + InlineTemporaries(const std::set& insertedResults, + std::set& inlinedResults) : + insertedResults(insertedResults), inlinedResults(inlinedResults) {} + + void visit(const WhereNode* op) { + IndexStmt consumer = rewrite(op->consumer); + IndexStmt producer = rewrite(op->producer); + if (producer == 
op->producer && consumer == op->consumer) { + stmt = op; + } else { + stmt = new WhereNode(consumer, producer); + } + } + + void visit(const AssignmentNode* op) { + const auto lhsTensor = op->lhs.getTensorVar(); + if (util::contains(tmpUse, lhsTensor) && !op->op.defined()) { + std::map indexMap; + const auto& oldIndices = + to(tmpUse[lhsTensor].first).getIndexVars(); + const auto& newIndices = op->lhs.getIndexVars(); + for (const auto& mapping : util::zip(oldIndices, newIndices)) { + indexMap[mapping.first] = mapping.second; + } - void visit(const AssignmentNode* op) { - const auto lhsTensor = op->lhs.getTensorVar(); - if (util::contains(inlinedResults, lhsTensor)) { - stmt = IndexStmt(); - } else { - stmt = op; - } - } - }; - loweredQueries = - EliminateRedundantTemps(inlinedResults).rewrite(loweredQueries); + std::vector newCoords; + const auto& oldCoords = + tmpUse[lhsTensor].second.getLhs().getIndexVars(); + for (const auto& oldCoord : oldCoords) { + newCoords.push_back(indexMap.at(oldCoord)); + } - return Assemble(loweredQueries, stmt, queryResults); - } + IndexExpr reduceOp = tmpUse[lhsTensor].second.getOperator(); + TensorVar queryResult = + tmpUse[lhsTensor].second.getLhs().getTensorVar(); + IndexExpr rhs = op->rhs; + if (rhs.getDataType() != queryResult.getType().getDataType()) { + rhs = Cast(rhs, queryResult.getType().getDataType()); + } + stmt = Assignment(queryResult(newCoords), rhs, reduceOp); + inlinedResults.insert(queryResult); + return; + } + + const Access rhsAccess = isa(op->rhs) ? to(op->rhs) + : (isa(op->rhs) && isa(to(op->rhs).getA())) + ? to(to(op->rhs).getA()) : Access(); + if (rhsAccess.defined()) { + const auto rhsTensor = rhsAccess.getTensorVar(); + if (util::contains(insertedResults, rhsTensor)) { + tmpUse[rhsTensor] = std::make_pair(rhsAccess, Assignment(op)); + } + } + stmt = op; + } + }; + loweredQueries = InlineTemporaries(insertedResults, + inlinedResults).rewrite(loweredQueries); + + // Eliminate computation of redundant temporaries + struct EliminateRedundantTemps : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + + const std::set& inlinedResults; + + EliminateRedundantTemps(const std::set& inlinedResults) : + inlinedResults(inlinedResults) {} + + void visit(const ForallNode* op) { + IndexStmt s = rewrite(op->stmt); + if (s == op->stmt) { + stmt = op; + } else if (s.defined()) { + stmt = new ForallNode(op->indexVar, s, op->merge_strategy, op->parallel_unit, + op->output_race_strategy, op->unrollFactor); + } else { + stmt = IndexStmt(); + } + } + + void visit(const WhereNode* op) { + IndexStmt consumer = rewrite(op->consumer); + if (consumer == op->consumer) { + stmt = op; + } else if (consumer.defined()) { + stmt = new WhereNode(consumer, op->producer); + } else { + stmt = op->producer; + } + } + + void visit(const AssignmentNode* op) { + const auto lhsTensor = op->lhs.getTensorVar(); + if (util::contains(inlinedResults, lhsTensor)) { + stmt = IndexStmt(); + } else { + stmt = op; + } + } + }; + loweredQueries = + EliminateRedundantTemps(inlinedResults).rewrite(loweredQueries); + + return Assemble(loweredQueries, stmt, queryResults); +} - void SetAssembleStrategy::print(std::ostream& os) const { - os << "assemble(" << getResult() << ", " - << AssembleStrategy_NAMES[(int)getAssembleStrategy()] << ")"; - } +void SetAssembleStrategy::print(std::ostream& os) const { + os << "assemble(" << getResult() << ", " + << AssembleStrategy_NAMES[(int)getAssembleStrategy()] << ")"; +} - std::ostream& operator<<(std::ostream& os, - const 
SetAssembleStrategy& assemble) { - assemble.print(os); - return os; - } +std::ostream& operator<<(std::ostream& os, + const SetAssembleStrategy& assemble) { + assemble.print(os); + return os; +} // Autoscheduling functions - IndexStmt parallelizeOuterLoop(IndexStmt stmt) { - // get outer ForAll - Forall forall; - bool matched = false; - match(stmt, - function([&forall, &matched]( - const ForallNode* node, Matcher* ctx) { - if (!matched) forall = node; - matched = true; - }) - ); - - if (!matched) return stmt; - string reason; - - if (should_use_CUDA_codegen()) { - for (const auto& temp : getTemporaries(stmt)) { - // Don't parallelize computations that use non-scalar temporaries. - if (temp.getOrder() > 0) { - return stmt; - } - } +IndexStmt parallelizeOuterLoop(IndexStmt stmt) { + // get outer ForAll + Forall forall; + bool matched = false; + match(stmt, + function([&forall, &matched]( + const ForallNode* node, Matcher* ctx) { + if (!matched) forall = node; + matched = true; + }) + ); + + if (!matched) return stmt; + string reason; + + if (should_use_CUDA_codegen()) { + for (const auto& temp : getTemporaries(stmt)) { + // Don't parallelize computations that use non-scalar temporaries. + if (temp.getOrder() > 0) { + return stmt; + } + } - IndexVar i1, i2; - IndexStmt parallelized256 = stmt.split(forall.getIndexVar(), i1, i2, 256); - parallelized256 = Parallelize(i1, ParallelUnit::GPUBlock, OutputRaceStrategy::NoRaces).apply(parallelized256, &reason); - if (parallelized256 == IndexStmt()) { - return stmt; - } + IndexVar i1, i2; + IndexStmt parallelized256 = stmt.split(forall.getIndexVar(), i1, i2, 256); + parallelized256 = Parallelize(i1, ParallelUnit::GPUBlock, OutputRaceStrategy::NoRaces).apply(parallelized256, &reason); + if (parallelized256 == IndexStmt()) { + return stmt; + } - parallelized256 = Parallelize(i2, ParallelUnit::GPUThread, OutputRaceStrategy::NoRaces).apply(parallelized256, &reason); - if (parallelized256 == IndexStmt()) { - return stmt; - } - return parallelized256; - } - else { - IndexStmt parallelized = Parallelize(forall.getIndexVar(), ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces).apply(stmt, &reason); - if (parallelized == IndexStmt()) { - // can't parallelize - return stmt; - } - return parallelized; - } + parallelized256 = Parallelize(i2, ParallelUnit::GPUThread, OutputRaceStrategy::NoRaces).apply(parallelized256, &reason); + if (parallelized256 == IndexStmt()) { + return stmt; } + return parallelized256; + } + else { + IndexStmt parallelized = Parallelize(forall.getIndexVar(), ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces).apply(stmt, &reason); + if (parallelized == IndexStmt()) { + // can't parallelize + return stmt; + } + return parallelized; + } +} // Takes in a set of pairs of IndexVar and level for a given tensor and orders // the IndexVars by tensor level - static vector> - varOrderFromTensorLevels(set>> tensorLevelVars) { - vector>> sortedPairs(tensorLevelVars.begin(), - tensorLevelVars.end()); - auto comparator = [](const pair> &left, - const pair> &right) { - return left.second.first < right.second.first; - }; - std::sort(sortedPairs.begin(), sortedPairs.end(), comparator); - - vector> varOrder; - std::transform(sortedPairs.begin(), - sortedPairs.end(), - std::back_inserter(varOrder), - [](const std::pair>& p) { - return pair(p.first, p.second.second); - }); - return varOrder; - } +static vector> +varOrderFromTensorLevels(set>> tensorLevelVars) { + vector>> sortedPairs(tensorLevelVars.begin(), + tensorLevelVars.end()); + auto comparator = 
[](const pair> &left, + const pair> &right) { + return left.second.first < right.second.first; + }; + std::sort(sortedPairs.begin(), sortedPairs.end(), comparator); + + vector> varOrder; + std::transform(sortedPairs.begin(), + sortedPairs.end(), + std::back_inserter(varOrder), + [](const std::pair>& p) { + return pair(p.first, p.second.second); + }); + return varOrder; +} // Takes in varOrders from many tensors and creates a map of dependencies between IndexVars - static map> - depsFromVarOrders(map>> varOrders) { - map> deps; - for (const auto& varOrderPair : varOrders) { - const auto& varOrder = varOrderPair.second; - for (auto firstit = varOrder.begin(); firstit != varOrder.end(); ++firstit) { - for (auto secondit = firstit + 1; secondit != varOrder.end(); ++secondit) { - if (firstit->second || secondit->second) { // one of the dimensions must enforce constraints - if (deps.count(secondit->first)) { - deps[secondit->first].insert(firstit->first); - } else { - deps[secondit->first] = {firstit->first}; - } - } - } - } - } - return deps; - } - - - static vector - topologicallySort(map> hardDeps, - map> softDeps, - vector originalOrder) { - vector sortedVars; - unsigned long countVars = originalOrder.size(); - while (sortedVars.size() < countVars) { - // Scan for variable with no dependencies - IndexVar freeVar; - size_t freeVarPos = std::numeric_limits::max(); - size_t minSoftDepsRemaining = std::numeric_limits::max(); - for (size_t i = 0; i < originalOrder.size(); ++i) { - IndexVar var = originalOrder[i]; - if (!hardDeps.count(var) || hardDeps[var].empty()) { - const size_t softDepsRemaining = softDeps.count(var) ? - softDeps[var].size() : 0; - if (softDepsRemaining < minSoftDepsRemaining) { - freeVar = var; - freeVarPos = i; - minSoftDepsRemaining = softDepsRemaining; - } - } - } - - // No free var found there is a cycle - taco_iassert(freeVarPos != std::numeric_limits::max()) - << "Cycles in iteration graphs must be resolved, through transpose, " - << "before the expression is passed to the topological sorting " - << "routine."; - - sortedVars.push_back(freeVar); - - // remove dependencies on variable - for (auto& varTensorDepsPair : hardDeps) { - auto& varTensorDeps = varTensorDepsPair.second; - if (varTensorDeps.count(freeVar)) { - varTensorDeps.erase(freeVar); - } - } - for (auto& varTensorDepsPair : softDeps) { - auto& varTensorDeps = varTensorDepsPair.second; - if (varTensorDeps.count(freeVar)) { - varTensorDeps.erase(freeVar); - } - } - originalOrder.erase(originalOrder.begin() + freeVarPos); +static map> +depsFromVarOrders(map>> varOrders) { + map> deps; + for (const auto& varOrderPair : varOrders) { + const auto& varOrder = varOrderPair.second; + for (auto firstit = varOrder.begin(); firstit != varOrder.end(); ++firstit) { + for (auto secondit = firstit + 1; secondit != varOrder.end(); ++secondit) { + if (firstit->second || secondit->second) { // one of the dimensions must enforce constraints + if (deps.count(secondit->first)) { + deps[secondit->first].insert(firstit->first); + } else { + deps[secondit->first] = {firstit->first}; + } } - return sortedVars; + } } + } + return deps; +} - IndexStmt reorderLoopsTopologically(IndexStmt stmt) { - // Collect tensorLevelVars which stores the pairs of IndexVar and tensor - // level that each tensor is accessed at - struct DAGBuilder : public IndexNotationVisitor { - using IndexNotationVisitor::visit; - // int is level, bool is if level enforces constraints (ie not dense) - map>>> tensorLevelVars; - IndexStmt innerBody; - map 
forallParallelUnit; - map forallOutputRaceStrategy; - vector indexVarOriginalOrder; - Iterators iterators; - - DAGBuilder(Iterators iterators) : iterators(iterators) {}; - - void visit(const ForallNode* node) { - Forall foralli(node); - IndexVar i = foralli.getIndexVar(); - - MergeLattice lattice = MergeLattice::make(foralli, iterators, ProvenanceGraph(), {}); // TODO - indexVarOriginalOrder.push_back(i); - forallParallelUnit[i] = foralli.getParallelUnit(); - forallOutputRaceStrategy[i] = foralli.getOutputRaceStrategy(); - - // Iterator and if Iterator enforces constraints - vector> depIterators; - for (Iterator iterator : lattice.points()[0].iterators()) { - if (!iterator.isDimensionIterator()) { - depIterators.push_back({iterator, true}); - } - } - - for (Iterator iterator : lattice.points()[0].locators()) { - depIterators.push_back({iterator, false}); - } - - // add result iterators that append - for (Iterator iterator : lattice.results()) { - depIterators.push_back({iterator, !iterator.hasInsert()}); - } - - for (const auto& iteratorPair : depIterators) { - int level = iteratorPair.first.getMode().getLevel(); - string tensor = to(iteratorPair.first.getTensor())->name; - if (tensorLevelVars.count(tensor)) { - tensorLevelVars[tensor].insert({{i, {level, iteratorPair.second}}}); - } - else { - tensorLevelVars[tensor] = {{{i, {level, iteratorPair.second}}}}; - } - } - - if (!isa(foralli.getStmt())) { - innerBody = foralli.getStmt(); - return; // Only reorder first contiguous section of ForAlls - } - IndexNotationVisitor::visit(node); - } - }; - - Iterators iterators(stmt); - DAGBuilder dagBuilder(iterators); - stmt.accept(&dagBuilder); - - // Construct tensor dependencies (sorted list of IndexVars) from tensorLevelVars - map>> tensorVarOrders; - for (const auto& tensorLevelVar : dagBuilder.tensorLevelVars) { - tensorVarOrders[tensorLevelVar.first] = - varOrderFromTensorLevels(tensorLevelVar.second); +static vector +topologicallySort(map> hardDeps, + map> softDeps, + vector originalOrder) { + vector sortedVars; + unsigned long countVars = originalOrder.size(); + while (sortedVars.size() < countVars) { + // Scan for variable with no dependencies + IndexVar freeVar; + size_t freeVarPos = std::numeric_limits::max(); + size_t minSoftDepsRemaining = std::numeric_limits::max(); + for (size_t i = 0; i < originalOrder.size(); ++i) { + IndexVar var = originalOrder[i]; + if (!hardDeps.count(var) || hardDeps[var].empty()) { + const size_t softDepsRemaining = softDeps.count(var) ? 
+ softDeps[var].size() : 0; + if (softDepsRemaining < minSoftDepsRemaining) { + freeVar = var; + freeVarPos = i; + minSoftDepsRemaining = softDepsRemaining; } - const auto hardDeps = depsFromVarOrders(tensorVarOrders); - - struct CollectSoftDependencies : public IndexNotationVisitor { - using IndexNotationVisitor::visit; - - map> softDeps; - - void visit(const AssignmentNode* op) { - op->lhs.accept(this); - op->rhs.accept(this); - } - - void visit(const AccessNode* node) { - const auto& modeOrdering = node->tensorVar.getFormat().getModeOrdering(); - for (size_t i = 1; i < (size_t)node->tensorVar.getOrder(); ++i) { - const auto srcVar = node->indexVars[modeOrdering[i - 1]]; - const auto dstVar = node->indexVars[modeOrdering[i]]; - softDeps[dstVar].insert(srcVar); - } - } - }; - CollectSoftDependencies collectSoftDeps; - stmt.accept(&collectSoftDeps); - - const auto sortedVars = topologicallySort(hardDeps, collectSoftDeps.softDeps, - dagBuilder.indexVarOriginalOrder); - - // Reorder Foralls use a rewriter in case new nodes introduced outside of Forall - struct TopoReorderRewriter : public IndexNotationRewriter { - using IndexNotationRewriter::visit; - - const vector& sortedVars; - IndexStmt innerBody; - const map forallParallelUnit; - const map forallOutputRaceStrategy; - - TopoReorderRewriter(const vector& sortedVars, IndexStmt innerBody, - const map forallParallelUnit, - const map forallOutputRaceStrategy) - : sortedVars(sortedVars), innerBody(innerBody), - forallParallelUnit(forallParallelUnit), forallOutputRaceStrategy(forallOutputRaceStrategy) { - } - - void visit(const ForallNode* node) { - Forall foralli(node); - IndexVar i = foralli.getIndexVar(); - - // first forall must be in collected variables - taco_iassert(util::contains(sortedVars, i)); - stmt = innerBody; - for (auto it = sortedVars.rbegin(); it != sortedVars.rend(); ++it) { - stmt = forall(*it, stmt, forallParallelUnit.at(*it), forallOutputRaceStrategy.at(*it), foralli.getUnrollFactor()); - } - return; - } - - }; - TopoReorderRewriter rewriter(sortedVars, dagBuilder.innerBody, - dagBuilder.forallParallelUnit, dagBuilder.forallOutputRaceStrategy); - return rewriter.rewrite(stmt); + } } - IndexStmt scalarPromote(IndexStmt stmt, ProvenanceGraph provGraph, - bool isWholeStmt, bool promoteScalar) { - std::map hoistLevel; - std::map reduceOp; - struct FindHoistLevel : public IndexNotationVisitor { - using IndexNotationVisitor::visit; - - std::map& hoistLevel; - std::map& reduceOp; - std::map> hoistIndices; - std::set derivedIndices; - std::set indices; - const ProvenanceGraph& provGraph; - const bool isWholeStmt; - const bool promoteScalar; - - FindHoistLevel(std::map& hoistLevel, - std::map& reduceOp, - const ProvenanceGraph& provGraph, - bool isWholeStmt, bool promoteScalar) : - hoistLevel(hoistLevel), reduceOp(reduceOp), provGraph(provGraph), - isWholeStmt(isWholeStmt), promoteScalar(promoteScalar) {} - - void visit(const ForallNode* node) { - Forall foralli(node); - IndexVar i = foralli.getIndexVar(); - - // Don't allow hoisting out of forall's for GPU warp and block reduction - if (foralli.getParallelUnit() == ParallelUnit::GPUWarpReduction || - foralli.getParallelUnit() == ParallelUnit::GPUBlockReduction) { - FindHoistLevel findHoistLevel(hoistLevel, reduceOp, provGraph, false, - promoteScalar); - foralli.getStmt().accept(&findHoistLevel); - return; - } - - std::vector resultAccesses; - std::tie(resultAccesses, std::ignore) = getResultAccesses(foralli); - for (const auto& resultAccess : resultAccesses) { - if (!promoteScalar 
&& resultAccess.getIndexVars().empty()) { - continue; - } - - std::set resultIndices(resultAccess.getIndexVars().begin(), - resultAccess.getIndexVars().end()); - if (std::includes(indices.begin(), indices.end(), - resultIndices.begin(), resultIndices.end()) && - !util::contains(hoistLevel, resultAccess)) { - hoistLevel[resultAccess] = node; - hoistIndices[resultAccess] = indices; - - auto resultDerivedIndices = resultIndices; - for (const auto& iv : resultIndices) { - for (const auto& div : provGraph.getFullyDerivedDescendants(iv)) { - resultDerivedIndices.insert(div); - } - } - if (!isWholeStmt || resultDerivedIndices != derivedIndices) { - reduceOp[resultAccess] = IndexExpr(); - } - } - } - - auto newIndices = provGraph.newlyRecoverableParents(i, derivedIndices); - newIndices.push_back(i); - derivedIndices.insert(newIndices.begin(), newIndices.end()); - - const auto underivedIndices = getIndexVars(foralli); - for (const auto& newIndex : newIndices) { - if (util::contains(underivedIndices, newIndex)) { - indices.insert(newIndex); - } - } - - IndexNotationVisitor::visit(node); - - for (const auto& newIndex : newIndices) { - indices.erase(newIndex); - derivedIndices.erase(newIndex); - } - } + // No free var found there is a cycle + taco_iassert(freeVarPos != std::numeric_limits::max()) + << "Cycles in iteration graphs must be resolved, through transpose, " + << "before the expression is passed to the topological sorting " + << "routine."; - void visit(const AssignmentNode* op) { - if (util::contains(hoistLevel, op->lhs) && - hoistIndices[op->lhs] == indices) { - hoistLevel.erase(op->lhs); - } - if (util::contains(reduceOp, op->lhs)) { - reduceOp[op->lhs] = op->op; - } - } - }; - FindHoistLevel findHoistLevel(hoistLevel, reduceOp, provGraph, isWholeStmt, - promoteScalar); - stmt.accept(&findHoistLevel); - - struct HoistWrites : public IndexNotationRewriter { - using IndexNotationRewriter::visit; - - const std::map& hoistLevel; - const std::map& reduceOp; - - HoistWrites(const std::map& hoistLevel, - const std::map& reduceOp) : - hoistLevel(hoistLevel), reduceOp(reduceOp) {} - - void visit(const ForallNode* node) { - Forall foralli(node); - IndexVar i = foralli.getIndexVar(); - IndexStmt body = rewrite(foralli.getStmt()); - - std::vector consumers; - for (const auto& resultAccess : hoistLevel) { - if (resultAccess.second == node) { - // This assumes the index expression yields at most one result tensor; - // will not work correctly if there are multiple results. - TensorVar resultVar = resultAccess.first.getTensorVar(); - TensorVar val("t" + i.getName() + resultVar.getName(), - Type(resultVar.getType().getDataType(), {})); - body = ReplaceReductionExpr( - map({{resultAccess.first, val()}})).rewrite(body); - - IndexExpr op = util::contains(reduceOp, resultAccess.first) - ? 
reduceOp.at(resultAccess.first) : IndexExpr(); - IndexStmt consumer = Assignment(Access(resultAccess.first), val(), op); - consumers.push_back(consumer); - } - } + sortedVars.push_back(freeVar); - if (body == foralli.getStmt()) { - taco_iassert(consumers.empty()); - stmt = node; - return; - } - - stmt = forall(i, body, foralli.getParallelUnit(), - foralli.getOutputRaceStrategy(), foralli.getUnrollFactor()); - for (const auto& consumer : consumers) { - stmt = where(consumer, stmt); - } - } - }; - HoistWrites hoistWrites(hoistLevel, reduceOp); - return hoistWrites.rewrite(stmt); + // Remove dependencies on the freed variable + for (auto& varTensorDepsPair : hardDeps) { + auto& varTensorDeps = varTensorDepsPair.second; + if (varTensorDeps.count(freeVar)) { + varTensorDeps.erase(freeVar); + } } - - IndexStmt scalarPromote(IndexStmt stmt) { - return scalarPromote(stmt, ProvenanceGraph(stmt), true, false); - } - - static bool compare(std::vector vars1, std::vector vars2) { - return vars1 == vars2; + for (auto& varTensorDepsPair : softDeps) { + auto& varTensorDeps = varTensorDepsPair.second; + if (varTensorDeps.count(freeVar)) { + varTensorDeps.erase(freeVar); + } } + originalOrder.erase(originalOrder.begin() + freeVarPos); + } + return sortedVars; +} -// TODO Temporary function to insert workspaces into SpMM kernels - static IndexStmt optimizeSpMM(IndexStmt stmt) { - if (!isa(stmt)) { - return stmt; - } - Forall foralli = to(stmt); - IndexVar i = foralli.getIndexVar(); - if (!isa(foralli.getStmt())) { - return stmt; +IndexStmt reorderLoopsTopologically(IndexStmt stmt) { + // Collect tensorLevelVars, which stores the pairs of IndexVar and tensor + // level at which each tensor is accessed + struct DAGBuilder : public IndexNotationVisitor { + using IndexNotationVisitor::visit; + // int is the level; bool records whether the level enforces constraints (i.e., is not dense) + map>>> tensorLevelVars; + IndexStmt innerBody; + map forallParallelUnit; + map forallOutputRaceStrategy; + vector indexVarOriginalOrder; + Iterators iterators; + + DAGBuilder(Iterators iterators) : iterators(iterators) {}; + + void visit(const ForallNode* node) { + Forall foralli(node); + IndexVar i = foralli.getIndexVar(); + + MergeLattice lattice = MergeLattice::make(foralli, iterators, ProvenanceGraph(), {}); // TODO + indexVarOriginalOrder.push_back(i); + forallParallelUnit[i] = foralli.getParallelUnit(); + forallOutputRaceStrategy[i] = foralli.getOutputRaceStrategy(); + + // Each iterator, paired with whether it enforces constraints + vector> depIterators; + for (Iterator iterator : lattice.points()[0].iterators()) { + if (!iterator.isDimensionIterator()) { + depIterators.push_back({iterator, true}); } - Forall forallk = to(foralli.getStmt()); - IndexVar k = forallk.getIndexVar(); - - if (!isa(forallk.getStmt())) { - return stmt; + } + + for (Iterator iterator : lattice.points()[0].locators()) { + depIterators.push_back({iterator, false}); + } + + // Add result iterators that append + for (Iterator iterator : lattice.results()) { + depIterators.push_back({iterator, !iterator.hasInsert()}); + } + + for (const auto& iteratorPair : depIterators) { + int level = iteratorPair.first.getMode().getLevel(); + string tensor = to(iteratorPair.first.getTensor())->name; + if (tensorLevelVars.count(tensor)) { + tensorLevelVars[tensor].insert({{i, {level, iteratorPair.second}}}); } - Forall forallj = to(forallk.getStmt()); - IndexVar j = forallj.getIndexVar(); - - if (!isa(forallj.getStmt())) { - return stmt; + else { + tensorLevelVars[tensor] = {{{i, {level, iteratorPair.second}}}}; }
- Assignment assignment = to(forallj.getStmt()); + } + + if (!isa(foralli.getStmt())) { + innerBody = foralli.getStmt(); + return; // Only reorder first contiguous section of ForAlls + } + IndexNotationVisitor::visit(node); + } + }; + + Iterators iterators(stmt); + DAGBuilder dagBuilder(iterators); + stmt.accept(&dagBuilder); + + // Construct tensor dependencies (sorted list of IndexVars) from tensorLevelVars + map>> tensorVarOrders; + for (const auto& tensorLevelVar : dagBuilder.tensorLevelVars) { + tensorVarOrders[tensorLevelVar.first] = + varOrderFromTensorLevels(tensorLevelVar.second); + } + const auto hardDeps = depsFromVarOrders(tensorVarOrders); + + struct CollectSoftDependencies : public IndexNotationVisitor { + using IndexNotationVisitor::visit; + + map> softDeps; + + void visit(const AssignmentNode* op) { + op->lhs.accept(this); + op->rhs.accept(this); + } + + void visit(const AccessNode* node) { + const auto& modeOrdering = node->tensorVar.getFormat().getModeOrdering(); + for (size_t i = 1; i < (size_t)node->tensorVar.getOrder(); ++i) { + const auto srcVar = node->indexVars[modeOrdering[i - 1]]; + const auto dstVar = node->indexVars[modeOrdering[i]]; + softDeps[dstVar].insert(srcVar); + } + } + }; + CollectSoftDependencies collectSoftDeps; + stmt.accept(&collectSoftDeps); + + const auto sortedVars = topologicallySort(hardDeps, collectSoftDeps.softDeps, + dagBuilder.indexVarOriginalOrder); + + // Reorder the Foralls using a rewriter, in case new nodes were introduced outside the Foralls + struct TopoReorderRewriter : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + + const vector& sortedVars; + IndexStmt innerBody; + const map forallParallelUnit; + const map forallOutputRaceStrategy; + + TopoReorderRewriter(const vector& sortedVars, IndexStmt innerBody, + const map forallParallelUnit, + const map forallOutputRaceStrategy) + : sortedVars(sortedVars), innerBody(innerBody), + forallParallelUnit(forallParallelUnit), forallOutputRaceStrategy(forallOutputRaceStrategy) { + } + + void visit(const ForallNode* node) { + Forall foralli(node); + IndexVar i = foralli.getIndexVar(); + + // The first forall must be among the collected variables + taco_iassert(util::contains(sortedVars, i)); + stmt = innerBody; + for (auto it = sortedVars.rbegin(); it != sortedVars.rend(); ++it) { + stmt = forall(*it, stmt, foralli.getMergeStrategy(), forallParallelUnit.at(*it), forallOutputRaceStrategy.at(*it), foralli.getUnrollFactor()); + } + return; + } + + }; + TopoReorderRewriter rewriter(sortedVars, dagBuilder.innerBody, + dagBuilder.forallParallelUnit, dagBuilder.forallOutputRaceStrategy); + return rewriter.rewrite(stmt); +} - if (!isa(assignment.getRhs())) { - return stmt; +IndexStmt scalarPromote(IndexStmt stmt, ProvenanceGraph provGraph, + bool isWholeStmt, bool promoteScalar) { + std::map hoistLevel; + std::map reduceOp; + struct FindHoistLevel : public IndexNotationVisitor { + using IndexNotationVisitor::visit; + + std::map& hoistLevel; + std::map& reduceOp; + std::map> hoistIndices; + std::set derivedIndices; + std::set indices; + const ProvenanceGraph& provGraph; + const bool isWholeStmt; + const bool promoteScalar; + + FindHoistLevel(std::map& hoistLevel, + std::map& reduceOp, + const ProvenanceGraph& provGraph, + bool isWholeStmt, bool promoteScalar) : + hoistLevel(hoistLevel), reduceOp(reduceOp), provGraph(provGraph), + isWholeStmt(isWholeStmt), promoteScalar(promoteScalar) {} + + void visit(const ForallNode* node) { + Forall foralli(node); + IndexVar i = foralli.getIndexVar(); + + // Don't allow
hoisting out of foralls for GPU warp and block reduction + if (foralli.getParallelUnit() == ParallelUnit::GPUWarpReduction || + foralli.getParallelUnit() == ParallelUnit::GPUBlockReduction) { + FindHoistLevel findHoistLevel(hoistLevel, reduceOp, provGraph, false, + promoteScalar); + foralli.getStmt().accept(&findHoistLevel); + return; + } + + std::vector resultAccesses; + std::tie(resultAccesses, std::ignore) = getResultAccesses(foralli); + for (const auto& resultAccess : resultAccesses) { + if (!promoteScalar && resultAccess.getIndexVars().empty()) { + continue; } - Mul mul = to(assignment.getRhs()); - taco_iassert(isa(assignment.getLhs())); - if (!isa(mul.getA())) { - return stmt; - } - if (!isa(mul.getB())) { - return stmt; + std::set resultIndices(resultAccess.getIndexVars().begin(), + resultAccess.getIndexVars().end()); + if (std::includes(indices.begin(), indices.end(), + resultIndices.begin(), resultIndices.end()) && + !util::contains(hoistLevel, resultAccess)) { + hoistLevel[resultAccess] = node; + hoistIndices[resultAccess] = indices; + + auto resultDerivedIndices = resultIndices; + for (const auto& iv : resultIndices) { + for (const auto& div : provGraph.getFullyDerivedDescendants(iv)) { + resultDerivedIndices.insert(div); + } + } + if (!isWholeStmt || resultDerivedIndices != derivedIndices) { + reduceOp[resultAccess] = IndexExpr(); + } } + } - Access Aaccess = to(assignment.getLhs()); - Access Baccess = to(mul.getA()); - Access Caccess = to(mul.getB()); + auto newIndices = provGraph.newlyRecoverableParents(i, derivedIndices); + newIndices.push_back(i); + derivedIndices.insert(newIndices.begin(), newIndices.end()); - if (Aaccess.getIndexVars().size() != 2 || - Baccess.getIndexVars().size() != 2 || - Caccess.getIndexVars().size() != 2) { - return stmt; + const auto underivedIndices = getIndexVars(foralli); + for (const auto& newIndex : newIndices) { + if (util::contains(underivedIndices, newIndex)) { + indices.insert(newIndex); } - - if (!compare(Aaccess.getIndexVars(), {i,j}) || - !compare(Baccess.getIndexVars(), {i,k}) || - !compare(Caccess.getIndexVars(), {k,j})) { - return stmt; - } - - TensorVar A = Aaccess.getTensorVar(); - if (A.getFormat().getModeFormats()[0].getName() != "dense" || - A.getFormat().getModeFormats()[1].getName() != "compressed" || - A.getFormat().getModeOrdering()[0] != 0 || - A.getFormat().getModeOrdering()[1] != 1) { - return stmt; - } - - // I think we can to linear combination of rows as long as there are no permutations in the format and the - // level formats are ordered. The i -> k -> j loops should iterate over the data structures without issue.
- TensorVar B = Baccess.getTensorVar(); - if (!B.getFormat().getModeFormats()[0].isOrdered() || - !B.getFormat().getModeFormats()[1].isOrdered() || - B.getFormat().getModeOrdering()[0] != 0 || - B.getFormat().getModeOrdering()[1] != 1) { - return stmt; + } + + IndexNotationVisitor::visit(node); + + for (const auto& newIndex : newIndices) { + indices.erase(newIndex); + derivedIndices.erase(newIndex); + } + } + + void visit(const AssignmentNode* op) { + if (util::contains(hoistLevel, op->lhs) && + hoistIndices[op->lhs] == indices) { + hoistLevel.erase(op->lhs); + } + if (util::contains(reduceOp, op->lhs)) { + reduceOp[op->lhs] = op->op; + } + } + }; + FindHoistLevel findHoistLevel(hoistLevel, reduceOp, provGraph, isWholeStmt, + promoteScalar); + stmt.accept(&findHoistLevel); + + struct HoistWrites : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + + const std::map& hoistLevel; + const std::map& reduceOp; + + HoistWrites(const std::map& hoistLevel, + const std::map& reduceOp) : + hoistLevel(hoistLevel), reduceOp(reduceOp) {} + + void visit(const ForallNode* node) { + Forall foralli(node); + IndexVar i = foralli.getIndexVar(); + IndexStmt body = rewrite(foralli.getStmt()); + + std::vector consumers; + for (const auto& resultAccess : hoistLevel) { + if (resultAccess.second == node) { + // This assumes the index expression yields at most one result tensor; + // will not work correctly if there are multiple results. + TensorVar resultVar = resultAccess.first.getTensorVar(); + TensorVar val("t" + i.getName() + resultVar.getName(), + Type(resultVar.getType().getDataType(), {})); + body = ReplaceReductionExpr( + map({{resultAccess.first, val()}})).rewrite(body); + + IndexExpr op = util::contains(reduceOp, resultAccess.first) + ? reduceOp.at(resultAccess.first) : IndexExpr(); + IndexStmt consumer = Assignment(Access(resultAccess.first), val(), op); + consumers.push_back(consumer); } + } + + if (body == foralli.getStmt()) { + taco_iassert(consumers.empty()); + stmt = node; + return; + } + + stmt = forall(i, body, foralli.getMergeStrategy(), foralli.getParallelUnit(), + foralli.getOutputRaceStrategy(), foralli.getUnrollFactor()); + for (const auto& consumer : consumers) { + stmt = where(consumer, stmt); + } + } + }; + HoistWrites hoistWrites(hoistLevel, reduceOp); + return hoistWrites.rewrite(stmt); +} - TensorVar C = Caccess.getTensorVar(); - if (!C.getFormat().getModeFormats()[0].isOrdered() || - !C.getFormat().getModeFormats()[1].isOrdered() || - C.getFormat().getModeOrdering()[0] != 0 || - C.getFormat().getModeOrdering()[1] != 1) { - return stmt; - } +IndexStmt scalarPromote(IndexStmt stmt) { + return scalarPromote(stmt, ProvenanceGraph(stmt), true, false); +} - // It's an SpMM statement so return an optimized SpMM statement - TensorVar w("w", - Type(A.getType().getDataType(), - {A.getType().getShape().getDimension(1)}), - taco::dense); - return forall(i, - where(forall(j, - A(i,j) = w(j)), - forall(k, - forall(j, - w(j) += B(i,k) * C(k,j))))); - } +static bool compare(std::vector vars1, std::vector vars2) { + return vars1 == vars2; +} - IndexStmt insertTemporaries(IndexStmt stmt) - { - IndexStmt spmm = optimizeSpMM(stmt); - if (spmm != stmt) { - return spmm; - } +// TODO Temporary function to insert workspaces into SpMM kernels +static IndexStmt optimizeSpMM(IndexStmt stmt) { + if (!isa(stmt)) { + return stmt; + } + Forall foralli = to(stmt); + IndexVar i = foralli.getIndexVar(); + + if (!isa(foralli.getStmt())) { + return stmt; + } + Forall forallk = to(foralli.getStmt()); + 
IndexVar k = forallk.getIndexVar(); + + if (!isa(forallk.getStmt())) { + return stmt; + } + Forall forallj = to(forallk.getStmt()); + IndexVar j = forallj.getIndexVar(); + + if (!isa(forallj.getStmt())) { + return stmt; + } + Assignment assignment = to(forallj.getStmt()); + + if (!isa(assignment.getRhs())) { + return stmt; + } + Mul mul = to(assignment.getRhs()); + + taco_iassert(isa(assignment.getLhs())); + if (!isa(mul.getA())) { + return stmt; + } + if (!isa(mul.getB())) { + return stmt; + } + + Access Aaccess = to(assignment.getLhs()); + Access Baccess = to(mul.getA()); + Access Caccess = to(mul.getB()); + + if (Aaccess.getIndexVars().size() != 2 || + Baccess.getIndexVars().size() != 2 || + Caccess.getIndexVars().size() != 2) { + return stmt; + } + + if (!compare(Aaccess.getIndexVars(), {i,j}) || + !compare(Baccess.getIndexVars(), {i,k}) || + !compare(Caccess.getIndexVars(), {k,j})) { + return stmt; + } + + TensorVar A = Aaccess.getTensorVar(); + if (A.getFormat().getModeFormats()[0].getName() != "dense" || + A.getFormat().getModeFormats()[1].getName() != "compressed" || + A.getFormat().getModeOrdering()[0] != 0 || + A.getFormat().getModeOrdering()[1] != 1) { + return stmt; + } + + // I think we can do a linear combination of rows as long as there are no permutations in the format and the + // level formats are ordered. The i -> k -> j loops should iterate over the data structures without issue. + TensorVar B = Baccess.getTensorVar(); + if (!B.getFormat().getModeFormats()[0].isOrdered() || + !B.getFormat().getModeFormats()[1].isOrdered() || + B.getFormat().getModeOrdering()[0] != 0 || + B.getFormat().getModeOrdering()[1] != 1) { + return stmt; + } + + TensorVar C = Caccess.getTensorVar(); + if (!C.getFormat().getModeFormats()[0].isOrdered() || + !C.getFormat().getModeFormats()[1].isOrdered() || + C.getFormat().getModeOrdering()[0] != 0 || + C.getFormat().getModeOrdering()[1] != 1) { + return stmt; + } + + // It's an SpMM statement, so return an optimized SpMM statement + TensorVar w("w", + Type(A.getType().getDataType(), + {A.getType().getShape().getDimension(1)}), + taco::dense); + return forall(i, + where(forall(j, + A(i,j) = w(j)), + forall(k, + forall(j, + w(j) += B(i,k) * C(k,j))))); +} - // TODO Implement general workspacing when scattering into sparse result modes +IndexStmt insertTemporaries(IndexStmt stmt) +{ + IndexStmt spmm = optimizeSpMM(stmt); + if (spmm != stmt) { + return spmm; + } - // Result dimensions that are indexed by free variables dominated by a - // reduction variable are scattered into. If any of these are compressed - // then we introduce a dense workspace to scatter into instead. The where - // statement must push the reduction loop into the producer side, leaving - // only the free variable loops on the consumer side. + // TODO Implement general workspacing when scattering into sparse result modes - //vector reductionVars = getReductionVars(stmt); - //... + // Result dimensions that are indexed by free variables dominated by a + // reduction variable are scattered into. If any of these are compressed + // then we introduce a dense workspace to scatter into instead. The where + // statement must push the reduction loop into the producer side, leaving + // only the free variable loops on the consumer side. - return stmt; - } + //vector reductionVars = getReductionVars(stmt); + //...
+ return stmt; } + +} \ No newline at end of file diff --git a/src/ir/ir_rewriter.cpp b/src/ir/ir_rewriter.cpp index 2a1c2b723..e8d154181 100644 --- a/src/ir/ir_rewriter.cpp +++ b/src/ir/ir_rewriter.cpp @@ -435,7 +435,7 @@ void IRRewriter::visit(const Allocate* op) { stmt = op; } else { - stmt = Allocate::make(var, num_elements, op->is_realloc, op->old_elements); + stmt = Allocate::make(var, num_elements, op->is_realloc, op->old_elements, op->clear); } } diff --git a/src/ir_tags.cpp b/src/ir_tags.cpp index af3dbd775..4afe9cef8 100644 --- a/src/ir_tags.cpp +++ b/src/ir_tags.cpp @@ -6,5 +6,6 @@ const char *ParallelUnit_NAMES[] = {"NotParallel", "DefaultUnit", "GPUBlock", "G const char *OutputRaceStrategy_NAMES[] = {"IgnoreRaces", "NoRaces", "Atomics", "Temporary", "ParallelReduction"}; const char *BoundType_NAMES[] = {"MinExact", "MinConstraint", "MaxExact", "MaxConstraint"}; const char *AssembleStrategy_NAMES[] = {"Append", "Insert"}; +const char *MergeStrategy_NAMES[] = {"TwoFinger", "Gallop"}; } diff --git a/src/lower/lowerer_impl_imperative.cpp b/src/lower/lowerer_impl_imperative.cpp index f370242db..ff4c10b21 100644 --- a/src/lower/lowerer_impl_imperative.cpp +++ b/src/lower/lowerer_impl_imperative.cpp @@ -872,7 +872,7 @@ Stmt LowererImplImperative::lowerForall(Forall forall) std::vector underivedAncestors = provGraph.getUnderivedAncestors(forall.getIndexVar()); taco_iassert(underivedAncestors.size() == 1); // TODO: add support for fused coordinate of pos loop loops = lowerMergeLattice(caseLattice, underivedAncestors[0], - forall.getStmt(), reducedAccesses); + forall.getStmt(), reducedAccesses, forall.getMergeStrategy()); } // taco_iassert(loops.defined()); @@ -1203,7 +1203,7 @@ Stmt LowererImplImperative::lowerForallDimension(Forall forall, } Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, - appenders, caseLattice, reducedAccesses); + appenders, caseLattice, reducedAccesses, forall.getMergeStrategy()); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -1264,7 +1264,7 @@ Stmt LowererImplImperative::lowerForallDimension(Forall forall, } Stmt declareVar = VarDecl::make(coordinate, Load::make(indexList, loopVar)); - Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses); + Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses, forall.getMergeStrategy()); Stmt resetGuard = ir::Store::make(bitGuard, coordinate, ir::Literal::make(false), markAssignsAtomicDepth > 0, atomicParallelUnit); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { @@ -1340,7 +1340,7 @@ Stmt LowererImplImperative::lowerForallPosition(Forall forall, Iterator iterator markAssignsAtomicDepth++; } - Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses); + Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses, forall.getMergeStrategy()); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -1391,24 +1391,28 @@ Stmt LowererImplImperative::lowerForallPosition(Forall forall, Iterator iterator endBound = endBounds[1]; } - LoopKind kind = LoopKind::Serial; - if 
(forall.getParallelUnit() == ParallelUnit::CPUVector && !ignoreVectorize) { - kind = LoopKind::Vectorized; - } - else if (forall.getParallelUnit() != ParallelUnit::NotParallel - && forall.getOutputRaceStrategy() != OutputRaceStrategy::ParallelReduction && !ignoreVectorize) { - kind = LoopKind::Runtime; + Stmt loop = Block::make(strideGuard, declareCoordinate, boundsGuard, body); + if (iterator.isBranchless() && iterator.isCompact() && + (iterator.getParent().isRoot() || iterator.getParent().isUnique())) { + loop = Block::make(VarDecl::make(iterator.getPosVar(), startBound), loop); + } else { + LoopKind kind = LoopKind::Serial; + if (forall.getParallelUnit() == ParallelUnit::CPUVector && !ignoreVectorize) { + kind = LoopKind::Vectorized; + } + else if (forall.getParallelUnit() != ParallelUnit::NotParallel && + forall.getOutputRaceStrategy() != OutputRaceStrategy::ParallelReduction && + !ignoreVectorize) { + kind = LoopKind::Runtime; + } + + loop = For::make(iterator.getPosVar(), startBound, endBound, 1, loop, kind, + ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit(), + ignoreVectorize ? 0 : forall.getUnrollFactor()); } // Loop with preamble and postamble - return Block::blanks( - boundsCompute, - For::make(iterator.getPosVar(), startBound, endBound, 1, - Block::make(strideGuard, declareCoordinate, boundsGuard, body), - kind, - ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit(), ignoreVectorize ? 0 : forall.getUnrollFactor()), - posAppend); - + return Block::blanks(boundsCompute, loop, posAppend); } Stmt LowererImplImperative::lowerForallFusedPosition(Forall forall, Iterator iterator, @@ -1510,7 +1514,7 @@ Stmt LowererImplImperative::lowerForallFusedPosition(Forall forall, Iterator ite } Stmt body = lowerForallBody(coordinate, forall.getStmt(), - locators, inserters, appenders, caseLattice, reducedAccesses); + locators, inserters, appenders, caseLattice, reducedAccesses, forall.getMergeStrategy()); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -1578,7 +1582,8 @@ Stmt LowererImplImperative::lowerForallFusedPosition(Forall forall, Iterator ite Stmt LowererImplImperative::lowerMergeLattice(MergeLattice caseLattice, IndexVar coordinateVar, IndexStmt statement, - const std::set& reducedAccesses) + const std::set& reducedAccesses, + MergeStrategy mergestrategy) { // Lower merge lattice always gets called from lowerForAll. So we want loop lattice MergeLattice loopLattice = caseLattice.getLoopLattice(); @@ -1604,7 +1609,7 @@ Stmt LowererImplImperative::lowerMergeLattice(MergeLattice caseLattice, IndexVar // points in the merge lattice. 
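+      // The chosen merge strategy is forwarded to each merge-point loop below so
+      // that coordinate resolution and iterator increments match the strategy.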
IndexStmt zeroedStmt = zero(statement, getExhaustedAccesses(point, caseLattice)); MergeLattice sublattice = caseLattice.subLattice(point); - Stmt mergeLoop = lowerMergePoint(sublattice, coordinate, coordinateVar, zeroedStmt, reducedAccesses, resolvedCoordDeclared); + Stmt mergeLoop = lowerMergePoint(sublattice, coordinate, coordinateVar, zeroedStmt, reducedAccesses, resolvedCoordDeclared, mergestrategy); mergeLoopsVec.push_back(mergeLoop); } Stmt mergeLoops = Block::make(mergeLoopsVec); @@ -1619,7 +1624,8 @@ Stmt LowererImplImperative::lowerMergeLattice(MergeLattice caseLattice, IndexVar Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, ir::Expr coordinate, IndexVar coordinateVar, IndexStmt statement, - const std::set& reducedAccesses, bool resolvedCoordDeclared) + const std::set& reducedAccesses, bool resolvedCoordDeclared, + MergeStrategy mergeStrategy) { MergePoint point = pointLattice.points().front(); @@ -1640,7 +1646,9 @@ Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, std::vector indexSetStmts; for (auto& iter : filter(iterators, [](Iterator it) { return it.hasIndexSet(); })) { // For each iterator A with an index set B, emit the following code: - // setMatch = min(A, B); // Check whether A matches its index set at this point. + // // Check whether A matches its index set at this point. + // // Using max instead of min because we will be merging the iterators by galloping. + // setMatch = max(A, B); // if (A == setMatch && B == setMatch) { // // If there was a match, project down the values of the iterators // // to be the position variable of the index set iterator. This has the @@ -1648,17 +1656,17 @@ Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, // A_coord = B_pos; // B_coord = B_pos; // } else { - // // Advance the iterator and it's index set iterator accordingly if + // // Advance the iterator and its index set iterator accordingly if // // there wasn't a match. - // A_pos += (A == setMatch); - // B_pos += (B == setMatch); + // A_pos = taco_gallop(int *A_array, int A_pos, int A_arrayEnd, int setMatch); + // B_pos = taco_gallop(int *B_array, int B_pos, int B_arrayEnd, int setMatch); // // We must continue so that we only proceed to the rest of the cases in // // the merge if there actually is a point present for A. // continue; // } auto setMatch = ir::Var::make("setMatch", Int()); auto indexSetIter = iter.getIndexSetIterator(); - indexSetStmts.push_back(ir::VarDecl::make(setMatch, ir::Min::make(this->coordinates({iter, indexSetIter})))); + indexSetStmts.push_back(ir::VarDecl::make(setMatch, ir::Max::make(this->coordinates({iter, indexSetIter})))); // Equality checks for each iterator. auto iterEq = ir::Eq::make(iter.getCoordVar(), setMatch); auto setEq = ir::Eq::make(indexSetIter.getCoordVar(), setMatch); @@ -1668,9 +1676,25 @@ Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, ir::Assign::make(indexSetIter.getCoordVar(), indexSetIter.getPosVar()) ); // Code to increment both iterator variables. 
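+    // The calls below assume that the runtime helper taco_gallop(array, pos,
+    // end, target) returns the first position in [pos, end) whose coordinate is
+    // at least target, that posBounds()[1] gives the end of each iterator's
+    // position range, and that array 1 of the mode pack is the coordinate array
+    // to be searched.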
+ auto ivar = iter.getIteratorVar(); + Expr iteratorParentPos = iter.getParent().getPosVar(); + ModeFunction iterBounds = iter.posBounds(iteratorParentPos); + vector iterGallopArgs = { + iter.getMode().getModePack().getArray(1), + ivar, iterBounds[1], + setMatch + }; + auto indexVar = indexSetIter.getIteratorVar(); + Expr indexIterParentPos = indexSetIter.getParent().getPosVar(); + ModeFunction indexIterBounds = indexSetIter.posBounds(indexIterParentPos); + vector indexGallopArgs = { + indexSetIter.getMode().getModePack().getArray(1), + indexVar, indexIterBounds[1], + setMatch + }; auto incr = ir::Block::make( - compoundAssign(iter.getIteratorVar(), ir::Cast::make(Eq::make(iter.getCoordVar(), setMatch), iter.getIteratorVar().type())), - compoundAssign(indexSetIter.getIteratorVar(), ir::Cast::make(Eq::make(indexSetIter.getCoordVar(), setMatch), indexSetIter.getIteratorVar().type())), + ir::Assign::make(ivar, ir::Call::make("taco_gallop", iterGallopArgs, ivar.type())), + ir::Assign::make(indexVar, ir::Call::make("taco_gallop", indexGallopArgs, indexVar.type())), ir::Continue::make() ); // Code that uses the defined parts together in the if-then-else. @@ -1678,7 +1702,13 @@ } // Merge iterator coordinate variables - Stmt resolvedCoordinate = resolveCoordinate(mergers, coordinate, !resolvedCoordDeclared); + bool mergeWithMax; + if (mergeStrategy == MergeStrategy::Gallop) { + mergeWithMax = true; + } else { + mergeWithMax = false; + } + Stmt resolvedCoordinate = resolveCoordinate(mergers, coordinate, !resolvedCoordDeclared, mergeWithMax); // Locate positions Stmt loadLocatorPosVars = declLocatePosVars(locators); @@ -1692,10 +1722,10 @@ // One case for each child lattice point lp Stmt caseStmts = lowerMergeCases(coordinate, coordinateVar, statement, pointLattice, - reducedAccesses); + reducedAccesses, mergeStrategy); // Increment iterator position variables - Stmt incIteratorVarStmts = codeToIncIteratorVars(coordinate, coordinateVar, iterators, mergers); + Stmt incIteratorVarStmts = codeToIncIteratorVars(coordinate, coordinateVar, iterators, mergers, mergeStrategy); /// While loop over rangers return While::make(checkThatNoneAreExhausted(rangers), @@ -1708,7 +1738,7 @@ incIteratorVarStmts)); } -Stmt LowererImplImperative::resolveCoordinate(std::vector mergers, ir::Expr coordinate, bool emitVarDecl) { +Stmt LowererImplImperative::resolveCoordinate(std::vector mergers, ir::Expr coordinate, bool emitVarDecl, bool mergeWithMax) { if (mergers.size() == 1) { Iterator merger = mergers[0]; if (merger.hasPosIter()) { @@ -1763,24 +1793,30 @@ } else { - // Multiple position iterators so the smallest is the resolved coordinate + // Multiple position iterators, so the smallest (or, when merging with Max for gallop, the largest) coordinate is the resolved coordinate - if (emitVarDecl) { - return VarDecl::make(coordinate, Min::make(coordinates(mergers))); + Expr merged; + if (mergeWithMax) { + merged = Max::make(coordinates(mergers)); + } else { + merged = Min::make(coordinates(mergers)); } - else { - return Assign::make(coordinate, Min::make(coordinates(mergers))); + if (emitVarDecl) { + return VarDecl::make(coordinate, merged); + } else { + return Assign::make(coordinate, merged); } } } Stmt LowererImplImperative::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, MergeLattice caseLattice, - const std::set& reducedAccesses) + const std::set&
reducedAccesses, + MergeStrategy mergeStrategy) { vector result; if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.needExplicitZeroChecks()) { // Can check value array of some tensor - Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, caseLattice, reducedAccesses); + Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, caseLattice, reducedAccesses, mergeStrategy); result.push_back(body); return Block::make(result); } @@ -1798,7 +1834,7 @@ Stmt LowererImplImperative::lowerMergeCases(ir::Expr coordinate, IndexVar coordi // Just one iterator so no conditional taco_iassert(!loopLattice.points()[0].isOmitter()); Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, loopLattice, reducedAccesses); + appenders, loopLattice, reducedAccesses, mergeStrategy); result.push_back(body); } else if (!loopLattice.points().empty()) { @@ -1823,10 +1859,10 @@ Stmt LowererImplImperative::lowerMergeCases(ir::Expr coordinate, IndexVar coordi // Construct case body IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, loopLattice)); Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, - inserters, appenders, MergeLattice({point}), reducedAccesses); + inserters, appenders, MergeLattice({point}), reducedAccesses, mergeStrategy); if (coordComparisons.empty()) { Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, MergeLattice({point}), reducedAccesses); + appenders, MergeLattice({point}), reducedAccesses, mergeStrategy); result.push_back(body); break; } @@ -1915,7 +1951,8 @@ std::vector LowererImplImperative::constructInnerLoopCasePreamble(ir:: vector LowererImplImperative::lowerCasesFromMap(map iteratorToCondition, ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice, - const std::set& reducedAccesses) { + const std::set& reducedAccesses, + MergeStrategy mergeStrategy) { vector appenders; vector inserters; @@ -1949,10 +1986,10 @@ vector LowererImplImperative::lowerCasesFromMap(map iterat // Construct case body IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, lattice)); Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, - inserters, appenders, MergeLattice({point}), reducedAccesses); + inserters, appenders, MergeLattice({point}), reducedAccesses, mergeStrategy); if (isNonZeroComparisions.empty()) { Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, MergeLattice({point}), reducedAccesses); + appenders, MergeLattice({point}), reducedAccesses, mergeStrategy); result.push_back(body); break; } @@ -1972,7 +2009,7 @@ vector LowererImplImperative::lowerCasesFromMap(map iterat Access access = iterators.modeAccess(it).getAccess(); IndexStmt initStmt = Assignment(access, Literal::zero(access.getDataType())); Stmt initialization = lowerForallBody(coordinate, initStmt, {}, inserters, - appenders, MergeLattice({}), reducedAccesses); + appenders, MergeLattice({}), reducedAccesses, mergeStrategy); stmts.push_back(initialization); } } @@ -1987,7 +2024,8 @@ vector LowererImplImperative::lowerCasesFromMap(map iterat /// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt. 
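+/// The merge strategy is threaded through to the lowered case bodies so that
+/// the emitted iterator increments match the chosen strategy.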
Stmt LowererImplImperative::lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, - MergeLattice lattice, const std::set& reducedAccesses) { + MergeLattice lattice, const std::set& reducedAccesses, + MergeStrategy mergeStrategy) { vector result; if (lattice.points().size() == 1 && lattice.iterators().size() == 1 @@ -1999,14 +2037,14 @@ Stmt LowererImplImperative::lowerMergeCasesWithExplicitZeroChecks(ir::Expr coord tie(appenders, inserters) = splitAppenderAndInserters(lattice.results()); taco_iassert(!lattice.points()[0].isOmitter()); Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, lattice, reducedAccesses); + appenders, lattice, reducedAccesses, mergeStrategy); result.push_back(body); } else if (!lattice.points().empty()) { map iteratorToConditionMap; vector preamble = constructInnerLoopCasePreamble(coordinate, coordinateVar, lattice, iteratorToConditionMap); util::append(result, preamble); - vector cases = lowerCasesFromMap(iteratorToConditionMap, coordinate, stmt, lattice, reducedAccesses); + vector cases = lowerCasesFromMap(iteratorToConditionMap, coordinate, stmt, lattice, reducedAccesses, mergeStrategy); util::append(result, cases); } @@ -2018,7 +2056,8 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt, vector inserters, vector appenders, MergeLattice caseLattice, - const set& reducedAccesses) { + const set& reducedAccesses, + MergeStrategy mergeStrategy) { // Inserter positions Stmt declInserterPosVars = declLocatePosVars(inserters); @@ -2051,7 +2090,7 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt, } // This will lower the body for each case to actually compute. Therefore, we don't need to resize assembly arrays - std::vector loweredCases = lowerCasesFromMap(caseMap, coordinate, stmt, caseLattice, reducedAccesses); + std::vector loweredCases = lowerCasesFromMap(caseMap, coordinate, stmt, caseLattice, reducedAccesses, mergeStrategy); append(stmts, loweredCases); Stmt body = Block::make(stmts); @@ -2067,13 +2106,26 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt, // Code to append coordinates Stmt appendCoords = appendCoordinate(appenders, coordinate); + std::vector stmts; + + // Code to increment iterators when merging by galloping. 
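+  // Gallop only applies when the merge lattice has a single point (a pure
+  // intersection), so when this body runs every merged iterator already sits
+  // on the matched coordinate and can be advanced by one unconditionally.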
+ if (mergeStrategy == MergeStrategy::Gallop && caseLattice.iterators().size() > 1) { + for (auto it : caseLattice.iterators()) { + Expr ivar = it.getIteratorVar(); + stmts.push_back(compoundAssign(ivar, 1)); + } + } + + Stmt incr = Block::make(stmts); + // TODO: Emit code to insert coordinates return Block::make(initVals, declInserterPosVars, declLocatorPosVars, body, - appendCoords); + appendCoords, + incr); } Expr LowererImplImperative::getTemporarySize(Where where) { @@ -3489,7 +3541,7 @@ Stmt LowererImplImperative::codeToInitializeIteratorVar(Iterator iterator, vecto else { result.push_back(codeToLoadCoordinatesFromPosIterators(iterators, true)); - Stmt stmt = resolveCoordinate(mergers, coordinate, true); + Stmt stmt = resolveCoordinate(mergers, coordinate, true, false); taco_iassert(stmt != Stmt()); result.push_back(stmt); result.push_back(codeToRecoverDerivedIndexVar(coordinateVar, iterator.getIndexVar(), true)); @@ -3541,7 +3593,7 @@ Stmt LowererImplImperative::codeToRecoverDerivedIndexVar(IndexVar underived, Ind return Stmt(); } -Stmt LowererImplImperative::codeToIncIteratorVars(Expr coordinate, IndexVar coordinateVar, vector iterators, vector mergers) { +Stmt LowererImplImperative::codeToIncIteratorVars(Expr coordinate, IndexVar coordinateVar, vector iterators, vector mergers, MergeStrategy strategy) { if (iterators.size() == 1) { Expr ivar = iterators[0].getIteratorVar(); @@ -3568,12 +3620,23 @@ Stmt LowererImplImperative::codeToIncIteratorVars(Expr coordinate, IndexVar coor for (auto& iterator : levelIterators) { Expr ivar = iterator.getIteratorVar(); if (iterator.isUnique()) { - Expr increment = iterator.isFull() - ? 1 - : ir::Cast::make(Eq::make(iterator.getCoordVar(), - coordinate), - ivar.type()); - result.push_back(compoundAssign(ivar, increment)); + if (iterator.isFull()) { + Expr increment = 1; + result.push_back(compoundAssign(ivar, increment)); + } else if (strategy == MergeStrategy::Gallop) { + Expr iteratorParentPos = iterator.getParent().getPosVar(); + ModeFunction iterBounds = iterator.posBounds(iteratorParentPos); + result.push_back(iterBounds.compute()); + vector gallopArgs = { + iterator.getMode().getModePack().getArray(1), + ivar, iterBounds[1], + coordinate, + }; + result.push_back(ir::Assign::make(ivar, ir::Call::make("taco_gallop", gallopArgs, ivar.type()))); + } else { // strategy == MergeStrategy::TwoFinger + Expr increment = ir::Cast::make(Eq::make(iterator.getCoordVar(), coordinate), ivar.type()); + result.push_back(compoundAssign(ivar, increment)); + } } else if (!iterator.isLeaf()) { result.push_back(Assign::make(ivar, iterator.getSegendVar())); } @@ -3589,7 +3652,7 @@ Stmt LowererImplImperative::codeToIncIteratorVars(Expr coordinate, IndexVar coor } else { result.push_back(codeToLoadCoordinatesFromPosIterators(iterators, false)); - Stmt stmt = resolveCoordinate(mergers, coordinate, false); + Stmt stmt = resolveCoordinate(mergers, coordinate, false, false); taco_iassert(stmt != Stmt()); result.push_back(stmt); result.push_back(codeToRecoverDerivedIndexVar(coordinateVar, iterator.getIndexVar(), false)); diff --git a/src/lower/mode_format_compressed.cpp b/src/lower/mode_format_compressed.cpp index 11d8c5fe9..1b751515a 100644 --- a/src/lower/mode_format_compressed.cpp +++ b/src/lower/mode_format_compressed.cpp @@ -17,8 +17,8 @@ CompressedModeFormat::CompressedModeFormat(bool isFull, bool isOrdered, bool isUnique, bool isZeroless, long long allocSize) : ModeFormatImpl("compressed", isFull, isOrdered, isUnique, false, true, - isZeroless, false, true, false, 
false, true, true, true, - false), + isZeroless, false, false, true, false, false, true, true, + true, false), allocSize(allocSize) { } diff --git a/src/lower/mode_format_dense.cpp b/src/lower/mode_format_dense.cpp index ff8eed4f4..7bf1991f1 100644 --- a/src/lower/mode_format_dense.cpp +++ b/src/lower/mode_format_dense.cpp @@ -11,7 +11,7 @@ DenseModeFormat::DenseModeFormat() : DenseModeFormat(true, true, false) { DenseModeFormat::DenseModeFormat(const bool isOrdered, const bool isUnique, const bool isZeroless) : ModeFormatImpl("dense", true, isOrdered, isUnique, false, true, isZeroless, - false, false, true, true, false, false, false, true) { + true, false, false, true, true, false, false, false, true) { } ModeFormat DenseModeFormat::copy( diff --git a/src/lower/mode_format_impl.cpp b/src/lower/mode_format_impl.cpp index fada2c67d..32c389063 100644 --- a/src/lower/mode_format_impl.cpp +++ b/src/lower/mode_format_impl.cpp @@ -147,15 +147,16 @@ std::ostream& operator<<(std::ostream& os, const ModeFunction& modeFunction) { // class ModeTypeImpl ModeFormatImpl::ModeFormatImpl(const std::string name, bool isFull, bool isOrdered, bool isUnique, bool isBranchless, - bool isCompact, bool isZeroless, + bool isCompact, bool isZeroless, bool isPadded, bool hasCoordValIter, bool hasCoordPosIter, bool hasLocate, bool hasInsert, bool hasAppend, bool hasSeqInsertEdge, bool hasInsertCoord, bool isYieldPosPure) : name(name), isFull(isFull), isOrdered(isOrdered), isUnique(isUnique), isBranchless(isBranchless), isCompact(isCompact), isZeroless(isZeroless), - hasCoordValIter(hasCoordValIter), hasCoordPosIter(hasCoordPosIter), - hasLocate(hasLocate), hasInsert(hasInsert), hasAppend(hasAppend), + isPadded(isPadded), hasCoordValIter(hasCoordValIter), + hasCoordPosIter(hasCoordPosIter), hasLocate(hasLocate), + hasInsert(hasInsert), hasAppend(hasAppend), hasSeqInsertEdge(hasSeqInsertEdge), hasInsertCoord(hasInsertCoord), isYieldPosPure(isYieldPosPure) { } diff --git a/src/lower/mode_format_singleton.cpp b/src/lower/mode_format_singleton.cpp index 82eaf2bf6..86ae66b4f 100644 --- a/src/lower/mode_format_singleton.cpp +++ b/src/lower/mode_format_singleton.cpp @@ -10,15 +10,15 @@ using namespace taco::ir; namespace taco { SingletonModeFormat::SingletonModeFormat() : - SingletonModeFormat(false, true, true, false) { + SingletonModeFormat(false, true, true, false, false) { } SingletonModeFormat::SingletonModeFormat(bool isFull, bool isOrdered, bool isUnique, bool isZeroless, - long long allocSize) : + bool isPadded, long long allocSize) : ModeFormatImpl("singleton", isFull, isOrdered, isUnique, true, true, - isZeroless, false, true, false, false, true, false, true, - true), + isZeroless, isPadded, false, true, false, false, true, + false, true, true), allocSize(allocSize) { } @@ -28,6 +28,7 @@ ModeFormat SingletonModeFormat::copy( bool isOrdered = this->isOrdered; bool isUnique = this->isUnique; bool isZeroless = this->isZeroless; + bool isPadded = this->isPadded; for (const auto property : properties) { switch (property) { case ModeFormat::FULL: @@ -54,13 +55,19 @@ ModeFormat SingletonModeFormat::copy( case ModeFormat::NOT_ZEROLESS: isZeroless = false; break; + case ModeFormat::PADDED: + isPadded = true; + break; + case ModeFormat::NOT_PADDED: + isPadded = false; + break; default: break; } } const auto singletonVariant = std::make_shared(isFull, isOrdered, isUnique, - isZeroless); + isZeroless, isPadded); return ModeFormat(singletonVariant); } @@ -128,7 +135,7 @@ Expr SingletonModeFormat::getAssembledSize(Expr 
prevSize, Mode mode) const { Stmt SingletonModeFormat::getInitCoords(Expr prevSize, std::vector queries, Mode mode) const { Expr crdArray = getCoordArray(mode.getModePack()); - return Allocate::make(crdArray, prevSize, false, Expr()); + return Allocate::make(crdArray, prevSize, false, Expr(), isPadded); } ModeFunction SingletonModeFormat::getYieldPos(Expr parentPos, diff --git a/src/tensor.cpp b/src/tensor.cpp index 1a89a30ec..257c396c3 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -965,6 +965,7 @@ TensorBase::getHelperFunctions(const Format& format, Datatype ctype, TensorVar packedTensor(Type(ctype, Shape(dims)), format); // Define packing and iterator routines in index notation. + // TODO: Use `generatePackCOOStmt` function to generate pack routine. std::vector indexVars(format.getOrder()); IndexStmt packStmt = (packedTensor(indexVars) = bufferTensor(indexVars)); IndexStmt iterateStmt = Yield(indexVars, packedTensor(indexVars)); @@ -974,6 +975,21 @@ TensorBase::getHelperFunctions(const Format& format, Datatype ctype, iterateStmt = forall(indexVars[mode], iterateStmt); } + bool doAppend = true; + for (int i = format.getOrder() - 1; i >= 0; --i) { + const auto modeFormat = format.getModeFormats()[i]; + if (modeFormat.isBranchless() && i != 0) { + const auto parentModeFormat = format.getModeFormats()[i - 1]; + if (parentModeFormat.isUnique() || !parentModeFormat.hasAppend()) { + doAppend = false; + break; + } + } + } + if (!doAppend) { + packStmt = packStmt.assemble(packedTensor, AssembleStrategy::Insert); + } + // Lower packing and iterator code. helperModule->addFunction(lower(packStmt, "pack", true, true)); helperModule->addFunction(lower(iterateStmt, "iterate", false, true)); diff --git a/test/test_tensors.cpp b/test/test_tensors.cpp index 03f86aa09..2acdb1dae 100644 --- a/test/test_tensors.cpp +++ b/test/test_tensors.cpp @@ -137,6 +137,16 @@ TensorData d5d_data() { }); } +TensorData d5e_data() { + return TensorData({5}, { + {{0}, 1}, + {{1}, 2}, + {{2}, 3}, + {{3}, 4}, + {{4}, 5} + }); +} + TensorData d8a_data() { return TensorData({8}, { {{0}, 1}, @@ -328,6 +338,23 @@ TensorData d333a_data() { }); } +TensorData d355a_data() { + return TensorData({3,5,5}, { + {{0,0,0}, 1}, + {{0,1,1}, 2}, + {{0,2,1}, 3}, + {{0,3,1}, 4}, + {{0,4,1}, 5}, + {{1,0,1}, 6}, + {{1,1,0}, 7}, + {{1,2,0}, 8}, + {{1,4,2}, 9}, + {{2,1,2}, 10}, + {{2,2,3}, 11}, + {{2,4,4}, 12}, + }); +} + TensorData d32b_data() { return TensorData({3,2}, { {{0,0}, 10}, @@ -406,6 +433,10 @@ Tensor d5d(std::string name, Format format) { return d5d_data().makeTensor(name, format); } +Tensor d5e(std::string name, Format format) { + return d5e_data().makeTensor(name, format); +} + Tensor d8a(std::string name, Format format) { return d8a_data().makeTensor(name, format); } @@ -486,6 +517,10 @@ Tensor d333a(std::string name, Format format) { return d333a_data().makeTensor(name, format); } +Tensor d355a(std::string name, Format format) { + return d355a_data().makeTensor(name, format); +} + Tensor d32b(std::string name, Format format) { return d32b_data().makeTensor(name, format); } diff --git a/test/test_tensors.h b/test/test_tensors.h index bb9197650..dba4b2cd1 100644 --- a/test/test_tensors.h +++ b/test/test_tensors.h @@ -101,6 +101,7 @@ TensorData d5a_data(); TensorData d5b_data(); TensorData d5c_data(); TensorData d5d_data(); +TensorData d5e_data(); TensorData d8a_data(); TensorData d8b_data(); @@ -127,6 +128,8 @@ TensorData d233c_data(); TensorData d333a_data(); +TensorData d355a_data(); + TensorData d32b_data(); TensorData 
d3322a_data(); @@ -146,6 +149,7 @@ Tensor d5a(std::string name, Format format); Tensor d5b(std::string name, Format format); Tensor d5c(std::string name, Format format); Tensor d5d(std::string name, Format format); +Tensor d5e(std::string name, Format format); Tensor d8a(std::string name, Format format); Tensor d8b(std::string name, Format format); @@ -175,6 +179,8 @@ Tensor d233c(std::string name, Format format); Tensor d333a(std::string name, Format format); +Tensor d355a(std::string name, Format format); + Tensor d32b(std::string name, Format format); Tensor d3322a(std::string name, Format format); diff --git a/test/tests-expr_storage.cpp b/test/tests-expr_storage.cpp index 28924760f..0bc8d6909 100644 --- a/test/tests-expr_storage.cpp +++ b/test/tests-expr_storage.cpp @@ -957,6 +957,25 @@ INSTANTIATE_TEST_CASE_P(bspmv, expr, ) ); +Format ell({Dense, Dense, Singleton({ModeFormat::UNIQUE, ModeFormat::PADDED})}); + +INSTANTIATE_TEST_CASE_P(espmv, expr, + Values( + TestData(Tensor("a",{5},Format({Dense})), + {i}, + d355a("B", ell)(j,i,k) * + d5e("c",Format({Dense}))(k), + { + { + // Dense index + {5} + }, + }, + {13,41,58,8,97} + ) + ) +); + INSTANTIATE_TEST_CASE_P(matrix_sum, expr, Values( TestData(Tensor("a",{},Format()), diff --git a/test/tests-index_notation.cpp b/test/tests-index_notation.cpp index 8090e71de..df6cbf938 100644 --- a/test/tests-index_notation.cpp +++ b/test/tests-index_notation.cpp @@ -149,7 +149,7 @@ TEST(notation, isomorphic) { ASSERT_FALSE(isomorphic(forall(i, forall(j, A(i,j) = B(i,j) + C(i,j))), forall(i, forall(j, A(j,i) = B(j,i) + C(j,i))))); ASSERT_FALSE(isomorphic(forall(i, forall(j, A(i,j) = B(i,j) + C(i,j), - ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)), + MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)), forall(j, forall(i, A(j,i) = B(j,i) + C(j,i))))); ASSERT_TRUE(isomorphic(sum(j, B(i,j) + C(i,j)), sum(i, B(j,i) + C(j,i)))); ASSERT_FALSE(isomorphic(sum(j, B(i,j) + C(i,j)), sum(j, B(j,i) + C(j,i)))); diff --git a/test/tests-merge_lattice.cpp b/test/tests-merge_lattice.cpp index c156597b5..36adf41a4 100644 --- a/test/tests-merge_lattice.cpp +++ b/test/tests-merge_lattice.cpp @@ -26,8 +26,8 @@ namespace tests { class HashedModeFormat : public ModeFormatImpl { public: HashedModeFormat() : ModeFormatImpl("hashed", false, false, true, false, - false, false, false, true, true, true, - false, true, true, false) {} + false, false, false, false, true, true, + true, false, true, true, false) {} ModeFormat copy(std::vector properties) const { return ModeFormat(std::make_shared()); diff --git a/test/tests-scheduling.cpp b/test/tests-scheduling.cpp index ca827421b..ee564577b 100644 --- a/test/tests-scheduling.cpp +++ b/test/tests-scheduling.cpp @@ -1053,3 +1053,85 @@ TEST(scheduling, divide) { return stmt.fuse(i, j, f).pos(f, fpos, A(i, j)).divide(fpos, f0, f1, 4).split(f1, i1, i2, 16).split(i2, i3, i4, 8); }); } + +TEST(scheduling, mergeby) { + auto dim = 256; + float sparsity = 0.1; + Tensor A("A", {dim, dim}, {Sparse, Sparse}); + Tensor B("B", {dim, dim}, {Dense, Sparse}); + Tensor x("x", {dim}, Sparse); + IndexVar i("i"), i1("i1"), i2("i2"), ipos("ipos"), j("j"), f("f"), fpos("fpos"), f0("f0"), f1("f1"); + + srand(59393); + for (int i = 0; i < dim; i++) { + for (int j = 0; j < dim; j++) { + auto rand_float = (float)rand()/(float)(RAND_MAX); + if (rand_float < sparsity) { + A.insert({i, j},((int)(rand_float * 10 / sparsity))); + B.insert({i, j},((int)(rand_float * 10 / sparsity))); + } + } + } + + for (int j = 0; j < dim; 
j++) { + float rand_float = (float)rand()/(float)(RAND_MAX); + x.insert({j}, ((int)(rand_float*10))); + } + + x.pack(); A.pack(); B.pack(); + + auto test = [&](std::function f) { + Tensor y("y", {dim}, Dense); + y(i) = A(i, j) * B(i, j) * x(j); + auto stmt = f(y.getAssignment().concretize()); + y.compile(stmt); + y.evaluate(); + Tensor expected("expected", {dim}, Dense); + expected(i) = A(i, j) * B(i, j) * x(j); + expected.evaluate(); + ASSERT_TRUE(equals(expected, y)) << expected << endl << y << endl; + }; + + // Test that a simple mergeby works. + test([&](IndexStmt stmt) { + return stmt.mergeby(j, MergeStrategy::Gallop); + }); + + // Test the two-finger merge strategy. + test([&](IndexStmt stmt) { + return stmt.mergeby(j, MergeStrategy::TwoFinger); + }); + + // Merging a dimension that involves a dense iterator with Gallop should be a no-op. + test([&](IndexStmt stmt) { + return stmt.mergeby(i, MergeStrategy::Gallop); + }); + + // Test the interaction between mergeby and other directives. + test([&](IndexStmt stmt) { + return stmt.mergeby(i, MergeStrategy::Gallop).split(i, i1, i2, 16); + }); + + test([&](IndexStmt stmt) { + return stmt.mergeby(i, MergeStrategy::Gallop).split(i, i1, i2, 32).unroll(i1, 4); + }); + + test([&](IndexStmt stmt) { + return stmt.mergeby(i, MergeStrategy::Gallop).pos(i, ipos, A(i,j)); + }); + + test([&](IndexStmt stmt) { + return stmt.mergeby(i, MergeStrategy::Gallop).parallelize(i, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces); + }); +} + +TEST(scheduling, mergeby_gallop_error) { + Tensor x("x", {8}, Format({Sparse})); + Tensor y("y", {8}, Format({Dense})); + Tensor z("z", {8}, Format({Sparse})); + IndexVar i("i"), ipos("ipos"); + y(i) = x(i) + z(i); + + IndexStmt stmt = y.getAssignment().concretize(); + ASSERT_THROW(stmt.mergeby(i, MergeStrategy::Gallop), taco::TacoException); +} \ No newline at end of file diff --git a/test/tests-transformation.cpp b/test/tests-transformation.cpp index abfec3d45..83ff16510 100644 --- a/test/tests-transformation.cpp +++ b/test/tests-transformation.cpp @@ -239,15 +239,15 @@ INSTANTIATE_TEST_CASE_P(parallelize, apply, Values( TransformationTest(Parallelize(i), forall(i, w(i) = b(i)), - forall(i, w(i) = b(i), ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces) + forall(i, w(i) = b(i), MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces) ), TransformationTest(Parallelize(i), forall(i, forall(j, W(i,j) = A(i,j))), - forall(i, forall(j, W(i,j) = A(i,j)), ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces) + forall(i, forall(j, W(i,j) = A(i,j)), MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces) ), TransformationTest(Parallelize(j), forall(i, forall(j, W(i,j) = A(i,j))), - forall(i, forall(j, W(i,j) = A(i,j), ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)) + forall(i, forall(j, W(i,j) = A(i,j), MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)) ) ) ); @@ -264,8 +264,8 @@ INSTANTIATE_TEST_CASE_P(misc, reorderLoopsTopologically, Values( NotationTest(forall(i, w(i) = b(i)), forall(i, w(i) = b(i))), - NotationTest(forall(i, w(i) = b(i), ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces), - forall(i, w(i) = b(i), ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)), + NotationTest(forall(i, w(i) = b(i), MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces), + forall(i, w(i) = b(i), MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)), NotationTest(forall(i, forall(j, W(i,j) =
A(i,j))), forall(i, forall(j, W(i,j) = A(i,j)))), diff --git a/tools/taco.cpp b/tools/taco.cpp index 449b09918..78023e7f6 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -102,8 +102,8 @@ static void printUsageInfo() { printFlag("f=:", "Specify the format of a tensor in the expression. Formats are " "specified per dimension using d (dense), s (sparse), " - "u (sparse, not unique), q (singleton), or c (singleton, not unique). " - "All formats default to dense. " + "u (sparse, not unique), q (singleton), c (singleton, not unique), " + "or p (singleton, padded). All formats default to dense. " "The ordering of modes can also be optionally specified as a " "comma-delimited list of modes in the order they should be stored. " "Examples: A:ds (i.e., CSR), B:ds:1,0 (i.e., CSC), c:d (i.e., " @@ -497,6 +497,24 @@ static bool setSchedulingCommands(vector> scheduleCommands, parse stmt = stmt.reorder(reorderedVars); + } else if (command == "mergeby") { + taco_uassert(scheduleCommand.size() == 2) << "'mergeby' scheduling directive takes 2 parameters: mergeby(i, strategy)"; + string i, strat; + MergeStrategy strategy; + + i = scheduleCommand[0]; + strat = scheduleCommand[1]; + if (strat == "TwoFinger") { + strategy = MergeStrategy::TwoFinger; + } else if (strat == "Gallop") { + strategy = MergeStrategy::Gallop; + } else { + taco_uerror << "Merge strategy not defined."; + goto end; + } + + stmt = stmt.mergeby(findVar(i), strategy); + } else if (command == "bound") { taco_uassert(scheduleCommand.size() == 4) << "'bound' scheduling directive takes 4 parameters: bound(i, i1, bound, type)"; string i, i1, type; @@ -751,6 +769,9 @@ int main(int argc, char* argv[]) { case 'q': modeTypes.push_back(ModeFormat::Singleton); break; + case 'p': + modeTypes.push_back(ModeFormat::Singleton(ModeFormat::PADDED)); + break; default: return reportError("Incorrect format descriptor", 3); break;
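Note on taco_gallop: the lowering above emits calls to a runtime helper taco_gallop(array, pos, end, target), which lives in taco's C runtime headers and is not modified by this diff. The assumed contract is that it returns the first position in [pos, end) whose coordinate is at least target, found by exponential (galloping) search. A minimal sketch of that contract follows; the parameter names and bounds handling are assumptions for illustration, not the actual runtime code:

    // Sketch of the assumed taco_gallop contract: return the first position p
    // in [pos, end) with array[p] >= target, or end if no such position exists.
    int taco_gallop(const int* array, int pos, int end, int target) {
      if (pos >= end || array[pos] >= target) {
        return pos;
      }
      // Exponential probe: double the step until the target is overshot.
      int step = 1;
      while (pos + step < end && array[pos + step] < target) {
        pos += step;
        step *= 2;
      }
      // Binary search within the last probed window (pos, min(pos + step, end)].
      int lo = pos + 1;
      int hi = (pos + step < end) ? pos + step : end;
      while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        if (array[mid] < target) {
          lo = mid + 1;
        } else {
          hi = mid;
        }
      }
      return lo;
    }

Given the command-line parsing added to tools/taco.cpp above, a plausible (illustrative) invocation of the new directive is: taco "y(i)=A(i,j)*x(j)" -f=A:ds -f=x:s -s="mergeby(j,Gallop)".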