From cf4509589b72c54f5f86f58efe0edce1739a1339 Mon Sep 17 00:00:00 2001 From: Stephen Chou Date: Thu, 16 Sep 2021 15:43:19 -0400 Subject: [PATCH 1/4] Fixed support for ELL format --- src/index_notation/index_notation.cpp | 16 ++++++++++++ src/ir/ir_rewriter.cpp | 2 +- src/lower/lowerer_impl_imperative.cpp | 34 ++++++++++++++------------ src/lower/mode_format_singleton.cpp | 2 +- src/tensor.cpp | 16 ++++++++++++ test/test_tensors.cpp | 35 +++++++++++++++++++++++++++ test/test_tensors.h | 6 +++++ test/tests-expr_storage.cpp | 17 +++++++++++++ 8 files changed, 111 insertions(+), 17 deletions(-) diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 1e462e47a..06f7c85d4 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -3538,6 +3538,22 @@ IndexStmt generatePackStmt(TensorVar tensor, packStmt = forall(indexVars[mode], packStmt); } + bool doAppend = true; + const Format lhsFormat = otherIsOnRight ? format : otherFormat; + for (int i = lhsFormat.getOrder() - 1; i >= 0; --i) { + const auto modeFormat = lhsFormat.getModeFormats()[i]; + if (modeFormat.isBranchless() && i != 0) { + const auto parentModeFormat = lhsFormat.getModeFormats()[i - 1]; + if (parentModeFormat.isUnique() || !parentModeFormat.hasAppend()) { + doAppend = false; + break; + } + } + } + if (!doAppend) { + packStmt = packStmt.assemble(otherIsOnRight ? tensor : other, AssembleStrategy::Insert); + } + return packStmt; } diff --git a/src/ir/ir_rewriter.cpp b/src/ir/ir_rewriter.cpp index eed6f2bab..9309358bc 100644 --- a/src/ir/ir_rewriter.cpp +++ b/src/ir/ir_rewriter.cpp @@ -425,7 +425,7 @@ void IRRewriter::visit(const Allocate* op) { stmt = op; } else { - stmt = Allocate::make(var, num_elements, op->is_realloc, op->old_elements); + stmt = Allocate::make(var, num_elements, op->is_realloc, op->old_elements, op->clear); } } diff --git a/src/lower/lowerer_impl_imperative.cpp b/src/lower/lowerer_impl_imperative.cpp index c289be13e..978e61113 100644 --- a/src/lower/lowerer_impl_imperative.cpp +++ b/src/lower/lowerer_impl_imperative.cpp @@ -1335,24 +1335,28 @@ Stmt LowererImplImperative::lowerForallPosition(Forall forall, Iterator iterator endBound = endBounds[1]; } - LoopKind kind = LoopKind::Serial; - if (forall.getParallelUnit() == ParallelUnit::CPUVector && !ignoreVectorize) { - kind = LoopKind::Vectorized; - } - else if (forall.getParallelUnit() != ParallelUnit::NotParallel - && forall.getOutputRaceStrategy() != OutputRaceStrategy::ParallelReduction && !ignoreVectorize) { - kind = LoopKind::Runtime; + Stmt loop = Block::make(strideGuard, declareCoordinate, boundsGuard, body); + if (iterator.isBranchless() && iterator.isCompact() && + (iterator.getParent().isRoot() || iterator.getParent().isUnique())) { + loop = Block::make(VarDecl::make(iterator.getPosVar(), startBound), loop); + } else { + LoopKind kind = LoopKind::Serial; + if (forall.getParallelUnit() == ParallelUnit::CPUVector && !ignoreVectorize) { + kind = LoopKind::Vectorized; + } + else if (forall.getParallelUnit() != ParallelUnit::NotParallel && + forall.getOutputRaceStrategy() != OutputRaceStrategy::ParallelReduction && + !ignoreVectorize) { + kind = LoopKind::Runtime; + } + + loop = For::make(iterator.getPosVar(), startBound, endBound, 1, loop, kind, + ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit(), + ignoreVectorize ? 
0 : forall.getUnrollFactor()); } // Loop with preamble and postamble - return Block::blanks( - boundsCompute, - For::make(iterator.getPosVar(), startBound, endBound, 1, - Block::make(strideGuard, declareCoordinate, boundsGuard, body), - kind, - ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit(), ignoreVectorize ? 0 : forall.getUnrollFactor()), - posAppend); - + return Block::blanks(boundsCompute, loop, posAppend); } Stmt LowererImplImperative::lowerForallFusedPosition(Forall forall, Iterator iterator, diff --git a/src/lower/mode_format_singleton.cpp b/src/lower/mode_format_singleton.cpp index 237a76160..578b7e1b5 100644 --- a/src/lower/mode_format_singleton.cpp +++ b/src/lower/mode_format_singleton.cpp @@ -128,7 +128,7 @@ Expr SingletonModeFormat::getAssembledSize(Expr prevSize, Mode mode) const { Stmt SingletonModeFormat::getInitCoords(Expr prevSize, std::vector queries, Mode mode) const { Expr crdArray = getCoordArray(mode.getModePack()); - return Allocate::make(crdArray, prevSize, false, Expr()); + return Allocate::make(crdArray, prevSize, false, Expr(), true); } ModeFunction SingletonModeFormat::getYieldPos(Expr parentPos, diff --git a/src/tensor.cpp b/src/tensor.cpp index fab437ff1..e6f16ae85 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -941,6 +941,7 @@ TensorBase::getHelperFunctions(const Format& format, Datatype ctype, TensorVar packedTensor(Type(ctype, Shape(dims)), format); // Define packing and iterator routines in index notation. + // TODO: Use `generatePackCOOStmt` function to generate pack routine. std::vector indexVars(format.getOrder()); IndexStmt packStmt = (packedTensor(indexVars) = bufferTensor(indexVars)); IndexStmt iterateStmt = Yield(indexVars, packedTensor(indexVars)); @@ -950,6 +951,21 @@ TensorBase::getHelperFunctions(const Format& format, Datatype ctype, iterateStmt = forall(indexVars[mode], iterateStmt); } + bool doAppend = true; + for (int i = format.getOrder() - 1; i >= 0; --i) { + const auto modeFormat = format.getModeFormats()[i]; + if (modeFormat.isBranchless() && i != 0) { + const auto parentModeFormat = format.getModeFormats()[i - 1]; + if (parentModeFormat.isUnique() || !parentModeFormat.hasAppend()) { + doAppend = false; + break; + } + } + } + if (!doAppend) { + packStmt = packStmt.assemble(packedTensor, AssembleStrategy::Insert); + } + // Lower packing and iterator code. 
helperModule->addFunction(lower(packStmt, "pack", true, true)); helperModule->addFunction(lower(iterateStmt, "iterate", false, true)); diff --git a/test/test_tensors.cpp b/test/test_tensors.cpp index 03f86aa09..2acdb1dae 100644 --- a/test/test_tensors.cpp +++ b/test/test_tensors.cpp @@ -137,6 +137,16 @@ TensorData d5d_data() { }); } +TensorData d5e_data() { + return TensorData({5}, { + {{0}, 1}, + {{1}, 2}, + {{2}, 3}, + {{3}, 4}, + {{4}, 5} + }); +} + TensorData d8a_data() { return TensorData({8}, { {{0}, 1}, @@ -328,6 +338,23 @@ TensorData d333a_data() { }); } +TensorData d355a_data() { + return TensorData({3,5,5}, { + {{0,0,0}, 1}, + {{0,1,1}, 2}, + {{0,2,1}, 3}, + {{0,3,1}, 4}, + {{0,4,1}, 5}, + {{1,0,1}, 6}, + {{1,1,0}, 7}, + {{1,2,0}, 8}, + {{1,4,2}, 9}, + {{2,1,2}, 10}, + {{2,2,3}, 11}, + {{2,4,4}, 12}, + }); +} + TensorData d32b_data() { return TensorData({3,2}, { {{0,0}, 10}, @@ -406,6 +433,10 @@ Tensor d5d(std::string name, Format format) { return d5d_data().makeTensor(name, format); } +Tensor d5e(std::string name, Format format) { + return d5e_data().makeTensor(name, format); +} + Tensor d8a(std::string name, Format format) { return d8a_data().makeTensor(name, format); } @@ -486,6 +517,10 @@ Tensor d333a(std::string name, Format format) { return d333a_data().makeTensor(name, format); } +Tensor d355a(std::string name, Format format) { + return d355a_data().makeTensor(name, format); +} + Tensor d32b(std::string name, Format format) { return d32b_data().makeTensor(name, format); } diff --git a/test/test_tensors.h b/test/test_tensors.h index bb9197650..dba4b2cd1 100644 --- a/test/test_tensors.h +++ b/test/test_tensors.h @@ -101,6 +101,7 @@ TensorData d5a_data(); TensorData d5b_data(); TensorData d5c_data(); TensorData d5d_data(); +TensorData d5e_data(); TensorData d8a_data(); TensorData d8b_data(); @@ -127,6 +128,8 @@ TensorData d233c_data(); TensorData d333a_data(); +TensorData d355a_data(); + TensorData d32b_data(); TensorData d3322a_data(); @@ -146,6 +149,7 @@ Tensor d5a(std::string name, Format format); Tensor d5b(std::string name, Format format); Tensor d5c(std::string name, Format format); Tensor d5d(std::string name, Format format); +Tensor d5e(std::string name, Format format); Tensor d8a(std::string name, Format format); Tensor d8b(std::string name, Format format); @@ -175,6 +179,8 @@ Tensor d233c(std::string name, Format format); Tensor d333a(std::string name, Format format); +Tensor d355a(std::string name, Format format); + Tensor d32b(std::string name, Format format); Tensor d3322a(std::string name, Format format); diff --git a/test/tests-expr_storage.cpp b/test/tests-expr_storage.cpp index 28924760f..04ba25733 100644 --- a/test/tests-expr_storage.cpp +++ b/test/tests-expr_storage.cpp @@ -957,6 +957,23 @@ INSTANTIATE_TEST_CASE_P(bspmv, expr, ) ); +INSTANTIATE_TEST_CASE_P(espmv, expr, + Values( + TestData(Tensor("a",{5},Format({Dense})), + {i}, + d355a("B",Format({Dense, Dense, Singleton(ModeFormat::UNIQUE)}))(j,i,k) * + d5e("c",Format({Dense}))(k), + { + { + // Dense index + {5} + }, + }, + {13,41,58,8,97} + ) + ) +); + INSTANTIATE_TEST_CASE_P(matrix_sum, expr, Values( TestData(Tensor("a",{},Format()), From 3744a821122f3202f480cfa1341eabf6d4fc99a0 Mon Sep 17 00:00:00 2001 From: Stephen Chou Date: Thu, 16 Sep 2021 18:07:04 -0400 Subject: [PATCH 2/4] Added padded property --- include/taco/format.h | 4 +++- include/taco/lower/mode_format_impl.h | 8 +++++--- include/taco/lower/mode_format_singleton.h | 5 +++-- src/format.cpp | 15 +++++++++++++++ 
src/lower/mode_format_compressed.cpp | 4 ++-- src/lower/mode_format_dense.cpp | 2 +- src/lower/mode_format_impl.cpp | 7 ++++--- src/lower/mode_format_singleton.cpp | 19 +++++++++++++------ test/tests-expr_storage.cpp | 4 +++- test/tests-merge_lattice.cpp | 4 ++-- tools/taco.cpp | 3 +++ 11 files changed, 54 insertions(+), 21 deletions(-) diff --git a/include/taco/format.h b/include/taco/format.h index 7a46f2410..81bdadda4 100644 --- a/include/taco/format.h +++ b/include/taco/format.h @@ -97,7 +97,8 @@ class ModeFormat { /// Properties of a mode format enum Property { FULL, NOT_FULL, ORDERED, NOT_ORDERED, UNIQUE, NOT_UNIQUE, BRANCHLESS, - NOT_BRANCHLESS, COMPACT, NOT_COMPACT, ZEROLESS, NOT_ZEROLESS + NOT_BRANCHLESS, COMPACT, NOT_COMPACT, ZEROLESS, NOT_ZEROLESS, PADDED, + NOT_PADDED }; /// Instantiates an undefined mode format @@ -129,6 +130,7 @@ class ModeFormat { bool isBranchless() const; bool isCompact() const; bool isZeroless() const; + bool isPadded() const; /// Returns true if a mode format has a specific capability, false otherwise bool hasCoordValIter() const; diff --git a/include/taco/lower/mode_format_impl.h b/include/taco/lower/mode_format_impl.h index 3e2cbad66..824802286 100644 --- a/include/taco/lower/mode_format_impl.h +++ b/include/taco/lower/mode_format_impl.h @@ -106,9 +106,10 @@ class ModeFormatImpl { public: ModeFormatImpl(std::string name, bool isFull, bool isOrdered, bool isUnique, bool isBranchless, bool isCompact, bool isZeroless, - bool hasCoordValIter, bool hasCoordPosIter, bool hasLocate, - bool hasInsert, bool hasAppend, bool hasSeqInsertEdge, - bool hasInsertCoord, bool isYieldPosPure); + bool isPadded, bool hasCoordValIter, bool hasCoordPosIter, + bool hasLocate, bool hasInsert, bool hasAppend, + bool hasSeqInsertEdge, bool hasInsertCoord, + bool isYieldPosPure); virtual ~ModeFormatImpl(); @@ -246,6 +247,7 @@ class ModeFormatImpl { const bool isBranchless; const bool isCompact; const bool isZeroless; + const bool isPadded; const bool hasCoordValIter; const bool hasCoordPosIter; diff --git a/include/taco/lower/mode_format_singleton.h b/include/taco/lower/mode_format_singleton.h index 65bfac2c0..1be6ce53a 100644 --- a/include/taco/lower/mode_format_singleton.h +++ b/include/taco/lower/mode_format_singleton.h @@ -10,8 +10,9 @@ class SingletonModeFormat : public ModeFormatImpl { using ModeFormatImpl::getInsertCoord; SingletonModeFormat(); - SingletonModeFormat(bool isFull, bool isOrdered, - bool isUnique, bool isZeroless, long long allocSize = DEFAULT_ALLOC_SIZE); + SingletonModeFormat(bool isFull, bool isOrdered, bool isUnique, + bool isZeroless, bool isPadded, + long long allocSize = DEFAULT_ALLOC_SIZE); ~SingletonModeFormat() override {} diff --git a/src/format.cpp b/src/format.cpp index a5144ae24..0424a1a52 100644 --- a/src/format.cpp +++ b/src/format.cpp @@ -187,6 +187,11 @@ bool ModeFormat::hasProperties(const std::vector& properties) const { return false; } break; + case PADDED: + if (!isPadded()) { + return false; + } + break; case NOT_FULL: if (isFull()) { return false; @@ -217,6 +222,11 @@ bool ModeFormat::hasProperties(const std::vector& properties) const { return false; } break; + case NOT_PADDED: + if (isPadded()) { + return false; + } + break; } } return true; @@ -252,6 +262,11 @@ bool ModeFormat::isZeroless() const { return impl->isZeroless; } +bool ModeFormat::isPadded() const { + taco_iassert(defined()); + return impl->isPadded; +} + bool ModeFormat::hasCoordValIter() const { taco_iassert(defined()); return impl->hasCoordValIter; diff --git 
a/src/lower/mode_format_compressed.cpp b/src/lower/mode_format_compressed.cpp index d5b341ba7..41b0fe992 100644 --- a/src/lower/mode_format_compressed.cpp +++ b/src/lower/mode_format_compressed.cpp @@ -17,8 +17,8 @@ CompressedModeFormat::CompressedModeFormat(bool isFull, bool isOrdered, bool isUnique, bool isZeroless, long long allocSize) : ModeFormatImpl("compressed", isFull, isOrdered, isUnique, false, true, - isZeroless, false, true, false, false, true, true, true, - false), + isZeroless, false, false, true, false, false, true, true, + true, false), allocSize(allocSize) { } diff --git a/src/lower/mode_format_dense.cpp b/src/lower/mode_format_dense.cpp index ff8eed4f4..7bf1991f1 100644 --- a/src/lower/mode_format_dense.cpp +++ b/src/lower/mode_format_dense.cpp @@ -11,7 +11,7 @@ DenseModeFormat::DenseModeFormat() : DenseModeFormat(true, true, false) { DenseModeFormat::DenseModeFormat(const bool isOrdered, const bool isUnique, const bool isZeroless) : ModeFormatImpl("dense", true, isOrdered, isUnique, false, true, isZeroless, - false, false, true, true, false, false, false, true) { + true, false, false, true, true, false, false, false, true) { } ModeFormat DenseModeFormat::copy( diff --git a/src/lower/mode_format_impl.cpp b/src/lower/mode_format_impl.cpp index fada2c67d..32c389063 100644 --- a/src/lower/mode_format_impl.cpp +++ b/src/lower/mode_format_impl.cpp @@ -147,15 +147,16 @@ std::ostream& operator<<(std::ostream& os, const ModeFunction& modeFunction) { // class ModeTypeImpl ModeFormatImpl::ModeFormatImpl(const std::string name, bool isFull, bool isOrdered, bool isUnique, bool isBranchless, - bool isCompact, bool isZeroless, + bool isCompact, bool isZeroless, bool isPadded, bool hasCoordValIter, bool hasCoordPosIter, bool hasLocate, bool hasInsert, bool hasAppend, bool hasSeqInsertEdge, bool hasInsertCoord, bool isYieldPosPure) : name(name), isFull(isFull), isOrdered(isOrdered), isUnique(isUnique), isBranchless(isBranchless), isCompact(isCompact), isZeroless(isZeroless), - hasCoordValIter(hasCoordValIter), hasCoordPosIter(hasCoordPosIter), - hasLocate(hasLocate), hasInsert(hasInsert), hasAppend(hasAppend), + isPadded(isPadded), hasCoordValIter(hasCoordValIter), + hasCoordPosIter(hasCoordPosIter), hasLocate(hasLocate), + hasInsert(hasInsert), hasAppend(hasAppend), hasSeqInsertEdge(hasSeqInsertEdge), hasInsertCoord(hasInsertCoord), isYieldPosPure(isYieldPosPure) { } diff --git a/src/lower/mode_format_singleton.cpp b/src/lower/mode_format_singleton.cpp index 578b7e1b5..0b5946d03 100644 --- a/src/lower/mode_format_singleton.cpp +++ b/src/lower/mode_format_singleton.cpp @@ -10,15 +10,15 @@ using namespace taco::ir; namespace taco { SingletonModeFormat::SingletonModeFormat() : - SingletonModeFormat(false, true, true, false) { + SingletonModeFormat(false, true, true, false, false) { } SingletonModeFormat::SingletonModeFormat(bool isFull, bool isOrdered, bool isUnique, bool isZeroless, - long long allocSize) : + bool isPadded, long long allocSize) : ModeFormatImpl("singleton", isFull, isOrdered, isUnique, true, true, - isZeroless, false, true, false, false, true, false, true, - true), + isZeroless, isPadded, false, true, false, false, true, + false, true, true), allocSize(allocSize) { } @@ -28,6 +28,7 @@ ModeFormat SingletonModeFormat::copy( bool isOrdered = this->isOrdered; bool isUnique = this->isUnique; bool isZeroless = this->isZeroless; + bool isPadded = this->isPadded; for (const auto property : properties) { switch (property) { case ModeFormat::FULL: @@ -54,13 +55,19 @@ ModeFormat 
SingletonModeFormat::copy( case ModeFormat::NOT_ZEROLESS: isZeroless = false; break; + case ModeFormat::PADDED: + isPadded = true; + break; + case ModeFormat::NOT_PADDED: + isPadded = false; + break; default: break; } } const auto singletonVariant = std::make_shared(isFull, isOrdered, isUnique, - isZeroless); + isZeroless, isPadded); return ModeFormat(singletonVariant); } @@ -128,7 +135,7 @@ Expr SingletonModeFormat::getAssembledSize(Expr prevSize, Mode mode) const { Stmt SingletonModeFormat::getInitCoords(Expr prevSize, std::vector queries, Mode mode) const { Expr crdArray = getCoordArray(mode.getModePack()); - return Allocate::make(crdArray, prevSize, false, Expr(), true); + return Allocate::make(crdArray, prevSize, false, Expr(), isPadded); } ModeFunction SingletonModeFormat::getYieldPos(Expr parentPos, diff --git a/test/tests-expr_storage.cpp b/test/tests-expr_storage.cpp index 04ba25733..0bc8d6909 100644 --- a/test/tests-expr_storage.cpp +++ b/test/tests-expr_storage.cpp @@ -957,11 +957,13 @@ INSTANTIATE_TEST_CASE_P(bspmv, expr, ) ); +Format ell({Dense, Dense, Singleton({ModeFormat::UNIQUE, ModeFormat::PADDED})}); + INSTANTIATE_TEST_CASE_P(espmv, expr, Values( TestData(Tensor("a",{5},Format({Dense})), {i}, - d355a("B",Format({Dense, Dense, Singleton(ModeFormat::UNIQUE)}))(j,i,k) * + d355a("B", ell)(j,i,k) * d5e("c",Format({Dense}))(k), { { diff --git a/test/tests-merge_lattice.cpp b/test/tests-merge_lattice.cpp index e4c53bd9f..ca28448d3 100644 --- a/test/tests-merge_lattice.cpp +++ b/test/tests-merge_lattice.cpp @@ -24,8 +24,8 @@ namespace tests { class HashedModeFormat : public ModeFormatImpl { public: HashedModeFormat() : ModeFormatImpl("hashed", false, false, true, false, - false, false, false, true, true, true, - false, true, true, false) {} + false, false, false, false, true, true, + true, false, true, true, false) {} ModeFormat copy(std::vector properties) const { return ModeFormat(std::make_shared()); diff --git a/tools/taco.cpp b/tools/taco.cpp index 449b09918..b607b4ce8 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -751,6 +751,9 @@ int main(int argc, char* argv[]) { case 'q': modeTypes.push_back(ModeFormat::Singleton); break; + case 'p': + modeTypes.push_back(ModeFormat::Singleton(ModeFormat::PADDED)); + break; default: return reportError("Incorrect format descriptor", 3); break; From 51208a3d67ef6e8bc3db37e02d9f24d6b3c51e88 Mon Sep 17 00:00:00 2001 From: Stephen Chou Date: Thu, 12 May 2022 12:59:44 -0400 Subject: [PATCH 3/4] Updated CLI help --- tools/taco.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/taco.cpp b/tools/taco.cpp index b607b4ce8..fbdcfb984 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -102,8 +102,8 @@ static void printUsageInfo() { printFlag("f=:", "Specify the format of a tensor in the expression. Formats are " "specified per dimension using d (dense), s (sparse), " - "u (sparse, not unique), q (singleton), or c (singleton, not unique). " - "All formats default to dense. " + "u (sparse, not unique), q (singleton), c (singleton, not unique), " + "or p (singleton, padded). All formats default to dense. " "The ordering of modes can also be optionally specified as a " "comma-delimited list of modes in the order they should be stored. " "Examples: A:ds (i.e., CSR), B:ds:1,0 (i.e., CSC), c:d (i.e., " From 7a05d638eff301a07af4b5976ddd7c8e82f9d4df Mon Sep 17 00:00:00 2001 From: Remy Wang Date: Wed, 18 May 2022 17:33:29 -0700 Subject: [PATCH 4/4] Add mergeby scheduling directive. 
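Example usage of the new directive, as a minimal sketch (the tensor names, dimensions, and formats below are hypothetical illustrations; getAssignment().concretize() and compile() are the usual TACO C++ entry points, while IndexStmt::mergeby and MergeStrategy::Gallop are the additions in this patch):

  // Sparse-sparse vector product: the merge lattice for i has a single
  // (intersection) point and both iterators are ordered, so gallop applies.
  Tensor<double> a("a", {42}, Format({Sparse}));
  Tensor<double> b("b", {42}, Format({Sparse}));
  Tensor<double> c("c", {42}, Format({Sparse}));
  IndexVar i("i");
  a(i) = b(i) * c(i);

  IndexStmt stmt = a.getAssignment().concretize();
  stmt = stmt.mergeby(i, MergeStrategy::Gallop);  // default is MergeStrategy::TwoFinger
  a.compile(stmt);
  a.assemble();
  a.compute();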
This commit adds a new scheduling directive called mergeby. This directive specifies whether the iterators of a given variable are merged by Two Finger merge or Galloping. The default strategy is Two Finger merge, which matches the old behavior. Galloping merges the iterators with exponential search, which can be more efficient if the iterator sizes are skewed. --- include/taco/index_notation/index_notation.h | 22 ++- .../index_notation/index_notation_nodes.h | 5 +- include/taco/index_notation/transformations.h | 21 +++ include/taco/ir_tags.h | 7 + include/taco/lower/lowerer_impl_imperative.h | 26 ++- src/codegen/codegen_c.cpp | 22 +++ src/index_notation/index_notation.cpp | 27 +++- .../index_notation_rewriter.cpp | 4 +- src/index_notation/transformations.cpp | 119 ++++++++++++-- src/ir_tags.cpp | 1 + src/lower/lowerer_impl_imperative.cpp | 153 ++++++++++++------ test/tests-index_notation.cpp | 2 +- test/tests-scheduling.cpp | 82 ++++++++++ test/tests-transformation.cpp | 10 +- tools/taco.cpp | 18 +++ 15 files changed, 436 insertions(+), 83 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 43cd36d7a..d5fe90a97 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -634,6 +634,23 @@ class IndexStmt : public util::IntrusivePtr { /// reorder takes a new ordering for a set of index variables that are directly nested in the iteration order IndexStmt reorder(std::vector reorderedvars) const; + /// The mergeby transformation specifies how to merge iterators on + /// the given index variable. By default, if an iterator is used for windowing + /// it will be merged with the "gallop" strategy. + /// All other iterators are merged with the "two finger" strategy. + /// The two finger strategy merges by advancing each iterator one at a time, + /// while the gallop strategy implements the exponential search algorithm. + /// + /// Preconditions: + /// This command applies to variables involving sparse iterators only; + /// it is a no-op if the variable involves any dense iterators. + /// Any variable can be merged with the two finger strategy, whereas gallop + /// only applies to a variable if its merge lattice has a single point + /// (i.e. an intersection). For example, if a variable involves multiplications + /// only, it can be merged with gallop. + /// Furthermore, all iterators must be ordered for gallop to apply. + IndexStmt mergeby(IndexVar i, MergeStrategy strategy) const; + /// The parallelize /// transformation tags an index variable for parallel execution. The /// transformation takes as an argument the type of parallel hardware @@ -829,13 +846,14 @@ class Forall : public IndexStmt { Forall() = default; Forall(const ForallNode*); Forall(IndexVar indexVar, IndexStmt stmt); - Forall(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0); + Forall(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0); IndexVar getIndexVar() const; IndexStmt getStmt() const; ParallelUnit getParallelUnit() const; OutputRaceStrategy getOutputRaceStrategy() const; + MergeStrategy getMergeStrategy() const; size_t getUnrollFactor() const; typedef ForallNode Node; }; /// Create a forall index statement. 
Forall forall(IndexVar i, IndexStmt stmt); -Forall forall(IndexVar i, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0); +Forall forall(IndexVar i, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0); /// A where statment has a producer statement that binds a tensor variable in diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index 5289ff069..0feea404e 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -398,8 +398,8 @@ struct YieldNode : public IndexStmtNode { }; struct ForallNode : public IndexStmtNode { - ForallNode(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0) - : indexVar(indexVar), stmt(stmt), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {} + ForallNode(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0) + : indexVar(indexVar), stmt(stmt), merge_strategy(merge_strategy), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {} void accept(IndexStmtVisitorStrict* v) const { v->visit(this); @@ -407,6 +407,7 @@ struct ForallNode : public IndexStmtNode { IndexVar indexVar; IndexStmt stmt; + MergeStrategy merge_strategy; ParallelUnit parallel_unit; OutputRaceStrategy output_race_strategy; size_t unrollFactor = 0; diff --git a/include/taco/index_notation/transformations.h b/include/taco/index_notation/transformations.h index f898c92b9..b750e3961 100644 --- a/include/taco/index_notation/transformations.h +++ b/include/taco/index_notation/transformations.h @@ -22,6 +22,7 @@ class AddSuchThatPredicates; class Parallelize; class TopoReorder; class SetAssembleStrategy; +class SetMergeStrategy; /// A transformation is an optimization that transforms a statement in the /// concrete index notation into a new statement that computes the same result @@ -36,6 +37,7 @@ class Transformation { Transformation(TopoReorder); Transformation(AddSuchThatPredicates); Transformation(SetAssembleStrategy); + Transformation(SetMergeStrategy); IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const; @@ -206,6 +208,25 @@ class SetAssembleStrategy : public TransformationInterface { /// Print a SetAssembleStrategy command. std::ostream &operator<<(std::ostream &, const SetAssembleStrategy&); +class SetMergeStrategy : public TransformationInterface { +public: + SetMergeStrategy(IndexVar i, MergeStrategy strategy); + + IndexVar geti() const; + MergeStrategy getMergeStrategy() const; + + IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const; + + void print(std::ostream &os) const; + +private: + struct Content; + std::shared_ptr content; +}; + +/// Print a SetMergeStrategy command. 
+std::ostream &operator<<(std::ostream &, const SetMergeStrategy&); + // Autoscheduling functions /** diff --git a/include/taco/ir_tags.h b/include/taco/ir_tags.h index 5858a13e3..be5d8f6bb 100644 --- a/include/taco/ir_tags.h +++ b/include/taco/ir_tags.h @@ -33,6 +33,13 @@ enum class AssembleStrategy { }; extern const char *AssembleStrategy_NAMES[]; +/// MergeStrategy::TwoFinger merges iterators by incrementing one at a time +/// MergeStrategy::Galloping merges iterators by exponential search (galloping) +enum class MergeStrategy { + TwoFinger, Gallop +}; +extern const char *MergeStrategy_NAMES[]; + } #endif //TACO_IR_TAGS_H diff --git a/include/taco/lower/lowerer_impl_imperative.h b/include/taco/lower/lowerer_impl_imperative.h index 4498b37f0..fa97e3cd9 100644 --- a/include/taco/lower/lowerer_impl_imperative.h +++ b/include/taco/lower/lowerer_impl_imperative.h @@ -146,15 +146,18 @@ class LowererImplImperative : public LowererImpl { * \param statement * A concrete index notation statement to compute at the points in the * sparse iteration space described by the merge lattice. + * \param mergeStrategy + * A strategy for merging iterators. One of TwoFinger or Gallop. * * \return * IR code to compute the forall loop. */ virtual ir::Stmt lowerMergeLattice(MergeLattice lattice, IndexVar coordinateVar, IndexStmt statement, - const std::set& reducedAccesses); + const std::set& reducedAccesses, + MergeStrategy mergeStrategy); - virtual ir::Stmt resolveCoordinate(std::vector mergers, ir::Expr coordinate, bool emitVarDecl); + virtual ir::Stmt resolveCoordinate(std::vector mergers, ir::Expr coordinate, bool emitVarDecl, bool mergeWithMax); /** * Lower the merge point at the top of the given lattice to code that iterates @@ -169,15 +172,20 @@ class LowererImplImperative : public LowererImpl { * coordinate the merge point is at. * A concrete index notation statement to compute at the points in the * sparse iteration space region described by the merge point. + * \param mergeWithMax + * A boolean indicating whether coordinates should be combined with MAX instead of MIN. + * MAX is needed when the iterators are merged with the Gallop strategy. */ virtual ir::Stmt lowerMergePoint(MergeLattice pointLattice, ir::Expr coordinate, IndexVar coordinateVar, IndexStmt statement, - const std::set& reducedAccesses, bool resolvedCoordDeclared); + const std::set& reducedAccesses, bool resolvedCoordDeclared, + MergeStrategy mergestrategy); /// Lower a merge lattice to cases. virtual ir::Stmt lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, MergeLattice lattice, - const std::set& reducedAccesses); + const std::set& reducedAccesses, + MergeStrategy mergeStrategy); /// Lower a forall loop body. virtual ir::Stmt lowerForallBody(ir::Expr coordinate, IndexStmt stmt, @@ -185,7 +193,8 @@ class LowererImplImperative : public LowererImpl { std::vector inserters, std::vector appenders, MergeLattice caseLattice, - const std::set& reducedAccesses); + const std::set& reducedAccesses, + MergeStrategy mergeStrategy); /// Lower a where statement. @@ -375,7 +384,7 @@ class LowererImplImperative : public LowererImpl { /// Conditionally increment iterator position variables. 
ir::Stmt codeToIncIteratorVars(ir::Expr coordinate, IndexVar coordinateVar, - std::vector iterators, std::vector mergers); + std::vector iterators, std::vector mergers, MergeStrategy strategy); ir::Stmt codeToLoadCoordinatesFromPosIterators(std::vector iterators, bool declVars); @@ -410,7 +419,8 @@ class LowererImplImperative : public LowererImpl { /// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt. /// Will emit checks for explicit zeros for each mode iterator and each locator in the lattice. ir::Stmt lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, - MergeLattice lattice, const std::set& reducedAccesses); + MergeLattice lattice, const std::set& reducedAccesses, + MergeStrategy mergeStrategy); /// Constructs cases comparing the coordVar for each iterator to the resolved coordinate. /// Returns a vector where coordComparisons[i] corresponds to a case for iters[i] @@ -444,7 +454,7 @@ class LowererImplImperative : public LowererImpl { /// The map must be of iterators to exprs of boolean types std::vector lowerCasesFromMap(std::map iteratorToCondition, ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice, - const std::set& reducedAccesses); + const std::set& reducedAccesses, MergeStrategy mergeStrategy); /// Constructs an expression which checks if this access is "zero" ir::Expr constructCheckForAccessZero(Access); diff --git a/src/codegen/codegen_c.cpp b/src/codegen/codegen_c.cpp index b57eb51bd..d53e3b06c 100644 --- a/src/codegen/codegen_c.cpp +++ b/src/codegen/codegen_c.cpp @@ -62,6 +62,28 @@ const string cHeaders = "int cmp(const void *a, const void *b) {\n" " return *((const int*)a) - *((const int*)b);\n" "}\n" + // Increment arrayStart until array[arrayStart] >= target or arrayStart >= arrayEnd + // using an exponential search algorithm: https://en.wikipedia.org/wiki/Exponential_search. 
+ "int taco_gallop(int *array, int arrayStart, int arrayEnd, int target) {\n" + " if (array[arrayStart] >= target || arrayStart >= arrayEnd) {\n" + " return arrayStart;\n" + " }\n" + " int step = 1;\n" + " int curr = arrayStart;\n" + " while (curr + step < arrayEnd && array[curr + step] < target) {\n" + " curr += step;\n" + " step = step * 2;\n" + " }\n" + "\n" + " step = step / 2;\n" + " while (step > 0) {\n" + " if (curr + step < arrayEnd && array[curr + step] < target) {\n" + " curr += step;\n" + " }\n" + " step = step / 2;\n" + " }\n" + " return curr+1;\n" + "}\n" "int taco_binarySearchAfter(int *array, int arrayStart, int arrayEnd, int target) {\n" " if (array[arrayStart] >= target) {\n" " return arrayStart;\n" diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 428b67dcd..2f565ca62 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -1907,6 +1907,15 @@ IndexStmt IndexStmt::reorder(std::vector reorderedvars) const { return transformed; } +IndexStmt IndexStmt::mergeby(IndexVar i, MergeStrategy strategy) const { + string reason; + IndexStmt transformed = SetMergeStrategy(i, strategy).apply(*this, &reason); + if (!transformed.defined()) { + taco_uerror << reason; + } + return transformed; +} + IndexStmt IndexStmt::parallelize(IndexVar i, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy) const { string reason; IndexStmt transformed = Parallelize(i, parallel_unit, output_race_strategy).apply(*this, &reason); @@ -2017,7 +2026,7 @@ IndexStmt IndexStmt::unroll(IndexVar i, size_t unrollFactor) const { void visit(const ForallNode* node) { if (node->indexVar == i) { - stmt = Forall(i, rewrite(node->stmt), node->parallel_unit, node->output_race_strategy, unrollFactor); + stmt = Forall(i, rewrite(node->stmt), node->merge_strategy, node->parallel_unit, node->output_race_strategy, unrollFactor); } else { IndexNotationRewriter::visit(node); @@ -2125,11 +2134,11 @@ Forall::Forall(const ForallNode* n) : IndexStmt(n) { } Forall::Forall(IndexVar indexVar, IndexStmt stmt) - : Forall(indexVar, stmt, ParallelUnit::NotParallel, OutputRaceStrategy::IgnoreRaces) { + : Forall(indexVar, stmt, MergeStrategy::TwoFinger, ParallelUnit::NotParallel, OutputRaceStrategy::IgnoreRaces) { } -Forall::Forall(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) - : Forall(new ForallNode(indexVar, stmt, parallel_unit, output_race_strategy, unrollFactor)) { +Forall::Forall(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) + : Forall(new ForallNode(indexVar, stmt, merge_strategy, parallel_unit, output_race_strategy, unrollFactor)) { } IndexVar Forall::getIndexVar() const { @@ -2148,6 +2157,10 @@ OutputRaceStrategy Forall::getOutputRaceStrategy() const { return getNode(*this)->output_race_strategy; } +MergeStrategy Forall::getMergeStrategy() const { + return getNode(*this)->merge_strategy; +} + size_t Forall::getUnrollFactor() const { return getNode(*this)->unrollFactor; } @@ -2156,8 +2169,8 @@ Forall forall(IndexVar i, IndexStmt stmt) { return Forall(i, stmt); } -Forall forall(IndexVar i, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) { - return Forall(i, stmt, parallel_unit, output_race_strategy, unrollFactor); +Forall forall(IndexVar i, IndexStmt stmt, MergeStrategy merge_strategy, 
ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor) { + return Forall(i, stmt, merge_strategy, parallel_unit, output_race_strategy, unrollFactor); } template <> bool isa(IndexStmt s) { @@ -3938,7 +3951,7 @@ struct Zero : public IndexNotationRewriterStrict { stmt = op; } else { - stmt = new ForallNode(op->indexVar, body, op->parallel_unit, op->output_race_strategy, op->unrollFactor); + stmt = new ForallNode(op->indexVar, body, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor); } } diff --git a/src/index_notation/index_notation_rewriter.cpp b/src/index_notation/index_notation_rewriter.cpp index 5caa2da4b..e8c123af1 100644 --- a/src/index_notation/index_notation_rewriter.cpp +++ b/src/index_notation/index_notation_rewriter.cpp @@ -185,7 +185,7 @@ void IndexNotationRewriter::visit(const ForallNode* op) { stmt = op; } else { - stmt = new ForallNode(op->indexVar, s, op->parallel_unit, op->output_race_strategy, op->unrollFactor); + stmt = new ForallNode(op->indexVar, s, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor); } } @@ -406,7 +406,7 @@ struct ReplaceIndexVars : public IndexNotationRewriter { stmt = op; } else { - stmt = new ForallNode(iv, s, op->parallel_unit, op->output_race_strategy, + stmt = new ForallNode(iv, s, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor); } } diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index a2c4dcef4..4f94790c4 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -34,6 +34,10 @@ Transformation::Transformation(ForAllReplace forallreplace) : transformation(new ForAllReplace(forallreplace)) { } +Transformation::Transformation(SetMergeStrategy setmergestrategy) + : transformation(new SetMergeStrategy(setmergestrategy)) { +} + Transformation::Transformation(Parallelize parallelize) : transformation(new Parallelize(parallelize)) { } @@ -131,6 +135,103 @@ std::ostream& operator<<(std::ostream& os, const Reorder& reorder) { return os; } +struct SetMergeStrategy::Content { + IndexVar i_var; + MergeStrategy strategy; +}; + +SetMergeStrategy::SetMergeStrategy(IndexVar i, MergeStrategy strategy) : content(new Content) { + content->i_var = i; + content->strategy = strategy; +} + +IndexVar SetMergeStrategy::geti() const { + return content->i_var; +} + +MergeStrategy SetMergeStrategy::getMergeStrategy() const { + return content->strategy; +} + +IndexStmt SetMergeStrategy::apply(IndexStmt stmt, string* reason) const { + INIT_REASON(reason); + + string r; + if (!isConcreteNotation(stmt, &r)) { + *reason = "The index statement is not valid concrete index notation: " + r; + return IndexStmt(); + } + + struct SetMergeStrategyRewriter : public IndexNotationRewriter { + using IndexNotationRewriter::visit; + + ProvenanceGraph provGraph; + map tensorVars; + set definedIndexVars; + + SetMergeStrategy transformation; + string reason; + SetMergeStrategyRewriter(SetMergeStrategy transformation) + : transformation(transformation) {} + + IndexStmt setmergestrategy(IndexStmt stmt) { + provGraph = ProvenanceGraph(stmt); + tensorVars = createIRTensorVars(stmt); + return rewrite(stmt); + } + + void visit(const ForallNode* node) { + Forall foralli(node); + IndexVar i = transformation.geti(); + + definedIndexVars.insert(foralli.getIndexVar()); + + if (foralli.getIndexVar() == i) { + Iterators iterators(foralli, tensorVars); + MergeLattice lattice = 
MergeLattice::make(foralli, iterators, provGraph, + definedIndexVars); + for (auto iterator : lattice.iterators()) { + if (!iterator.isOrdered()) { + reason = "Precondition failed: Variable " + + i.getName() + + " is not ordered and cannot be galloped."; + return; + } + } + if (lattice.points().size() != 1) { + reason = "Precondition failed: The merge lattice of variable " + + i.getName() + + " has more than 1 point and cannot be merged by galloping"; + return; + } + + MergeStrategy strategy = transformation.getMergeStrategy(); + stmt = rewrite(foralli.getStmt()); + stmt = Forall(node->indexVar, stmt, strategy, node->parallel_unit, + node->output_race_strategy, node->unrollFactor); + return; + } + IndexNotationRewriter::visit(node); + } + }; + SetMergeStrategyRewriter rewriter = SetMergeStrategyRewriter(*this); + IndexStmt rewritten = rewriter.setmergestrategy(stmt); + if (!rewriter.reason.empty()) { + *reason = rewriter.reason; + return IndexStmt(); + } + return rewritten; +} + +void SetMergeStrategy::print(std::ostream& os) const { + os << "mergeby(" << geti() << ", " + << MergeStrategy_NAMES[(int)getMergeStrategy()] << ")"; +} + +std::ostream& operator<<(std::ostream& os, const SetMergeStrategy& setmergestrategy) { + setmergestrategy.print(os); + return os; +} // class Precompute struct Precompute::Content { @@ -819,7 +920,7 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const { ); taco_iassert(!precomputeAssignments.empty()); - IndexStmt precomputed_stmt = forall(i, foralli.getStmt(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); + IndexStmt precomputed_stmt = forall(i, foralli.getStmt(), foralli.getMergeStrategy(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); for (auto assignment : precomputeAssignments) { // Construct temporary of correct type and size of outer loop TensorVar w(string("w_") + ParallelUnit_NAMES[(int) parallelize.getParallelUnit()], Type(assignment->lhs.getDataType(), {Dimension(i)}), taco::dense); @@ -828,7 +929,7 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const { IndexStmt producer = ReplaceReductionExpr(map({{assignment->lhs, w(i)}})).rewrite(precomputed_stmt); taco_iassert(isa(producer)); Forall producer_forall = to(producer); - producer = forall(producer_forall.getIndexVar(), producer_forall.getStmt(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); + producer = forall(producer_forall.getIndexVar(), producer_forall.getStmt(), foralli.getMergeStrategy(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); // build consumer that writes from temporary to output, mark consumer as parallel reduction ParallelUnit reductionUnit = ParallelUnit::CPUThreadGroupReduction; @@ -840,7 +941,7 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const { reductionUnit = ParallelUnit::GPUBlockReduction; } } - IndexStmt consumer = forall(i, Assignment(assignment->lhs, w(i), assignment->op), reductionUnit, OutputRaceStrategy::ParallelReduction); + IndexStmt consumer = forall(i, Assignment(assignment->lhs, w(i), assignment->op), foralli.getMergeStrategy(), reductionUnit, OutputRaceStrategy::ParallelReduction); precomputed_stmt = where(consumer, producer); } stmt = precomputed_stmt; @@ -852,14 +953,14 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const { // reducing at end IndexStmt body = 
scalarPromote(foralli.getStmt(), provGraph, false, true); - stmt = forall(i, body, parallelize.getParallelUnit(), + stmt = forall(i, body, foralli.getMergeStrategy(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); return; } - stmt = forall(i, foralli.getStmt(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); + stmt = forall(i, foralli.getStmt(), foralli.getMergeStrategy(), parallelize.getParallelUnit(), parallelize.getOutputRaceStrategy(), foralli.getUnrollFactor()); return; } @@ -1026,7 +1127,7 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const { if (s == op->stmt) { stmt = op; } else if (s.defined()) { - stmt = Forall(op->indexVar, s, op->parallel_unit, + stmt = Forall(op->indexVar, s, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor); } else { stmt = IndexStmt(); @@ -1297,7 +1398,7 @@ IndexStmt SetAssembleStrategy::apply(IndexStmt stmt, string* reason) const { if (s == op->stmt) { stmt = op; } else if (s.defined()) { - stmt = new ForallNode(op->indexVar, s, op->parallel_unit, + stmt = new ForallNode(op->indexVar, s, op->merge_strategy, op->parallel_unit, op->output_race_strategy, op->unrollFactor); } else { stmt = IndexStmt(); @@ -1607,7 +1708,7 @@ IndexStmt reorderLoopsTopologically(IndexStmt stmt) { taco_iassert(util::contains(sortedVars, i)); stmt = innerBody; for (auto it = sortedVars.rbegin(); it != sortedVars.rend(); ++it) { - stmt = forall(*it, stmt, forallParallelUnit.at(*it), forallOutputRaceStrategy.at(*it), foralli.getUnrollFactor()); + stmt = forall(*it, stmt, foralli.getMergeStrategy(), forallParallelUnit.at(*it), forallOutputRaceStrategy.at(*it), foralli.getUnrollFactor()); } return; } @@ -1753,7 +1854,7 @@ IndexStmt scalarPromote(IndexStmt stmt, ProvenanceGraph provGraph, return; } - stmt = forall(i, body, foralli.getParallelUnit(), + stmt = forall(i, body, foralli.getMergeStrategy(), foralli.getParallelUnit(), foralli.getOutputRaceStrategy(), foralli.getUnrollFactor()); for (const auto& consumer : consumers) { stmt = where(consumer, stmt); diff --git a/src/ir_tags.cpp b/src/ir_tags.cpp index af3dbd775..4afe9cef8 100644 --- a/src/ir_tags.cpp +++ b/src/ir_tags.cpp @@ -6,5 +6,6 @@ const char *ParallelUnit_NAMES[] = {"NotParallel", "DefaultUnit", "GPUBlock", "G const char *OutputRaceStrategy_NAMES[] = {"IgnoreRaces", "NoRaces", "Atomics", "Temporary", "ParallelReduction"}; const char *BoundType_NAMES[] = {"MinExact", "MinConstraint", "MaxExact", "MaxConstraint"}; const char *AssembleStrategy_NAMES[] = {"Append", "Insert"}; +const char *MergeStrategy_NAMES[] = {"TwoFinger", "Gallop"}; } diff --git a/src/lower/lowerer_impl_imperative.cpp b/src/lower/lowerer_impl_imperative.cpp index b09dcc1fa..ff4c10b21 100644 --- a/src/lower/lowerer_impl_imperative.cpp +++ b/src/lower/lowerer_impl_imperative.cpp @@ -872,7 +872,7 @@ Stmt LowererImplImperative::lowerForall(Forall forall) std::vector underivedAncestors = provGraph.getUnderivedAncestors(forall.getIndexVar()); taco_iassert(underivedAncestors.size() == 1); // TODO: add support for fused coordinate of pos loop loops = lowerMergeLattice(caseLattice, underivedAncestors[0], - forall.getStmt(), reducedAccesses); + forall.getStmt(), reducedAccesses, forall.getMergeStrategy()); } // taco_iassert(loops.defined()); @@ -1203,7 +1203,7 @@ Stmt LowererImplImperative::lowerForallDimension(Forall forall, } Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, - 
appenders, caseLattice, reducedAccesses); + appenders, caseLattice, reducedAccesses, forall.getMergeStrategy()); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -1264,7 +1264,7 @@ Stmt LowererImplImperative::lowerForallDimension(Forall forall, } Stmt declareVar = VarDecl::make(coordinate, Load::make(indexList, loopVar)); - Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses); + Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses, forall.getMergeStrategy()); Stmt resetGuard = ir::Store::make(bitGuard, coordinate, ir::Literal::make(false), markAssignsAtomicDepth > 0, atomicParallelUnit); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { @@ -1340,7 +1340,7 @@ Stmt LowererImplImperative::lowerForallPosition(Forall forall, Iterator iterator markAssignsAtomicDepth++; } - Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses); + Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses, forall.getMergeStrategy()); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -1514,7 +1514,7 @@ Stmt LowererImplImperative::lowerForallFusedPosition(Forall forall, Iterator ite } Stmt body = lowerForallBody(coordinate, forall.getStmt(), - locators, inserters, appenders, caseLattice, reducedAccesses); + locators, inserters, appenders, caseLattice, reducedAccesses, forall.getMergeStrategy()); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -1582,7 +1582,8 @@ Stmt LowererImplImperative::lowerForallFusedPosition(Forall forall, Iterator ite Stmt LowererImplImperative::lowerMergeLattice(MergeLattice caseLattice, IndexVar coordinateVar, IndexStmt statement, - const std::set& reducedAccesses) + const std::set& reducedAccesses, + MergeStrategy mergestrategy) { // Lower merge lattice always gets called from lowerForAll. So we want loop lattice MergeLattice loopLattice = caseLattice.getLoopLattice(); @@ -1608,7 +1609,7 @@ Stmt LowererImplImperative::lowerMergeLattice(MergeLattice caseLattice, IndexVar // points in the merge lattice. 
IndexStmt zeroedStmt = zero(statement, getExhaustedAccesses(point, caseLattice)); MergeLattice sublattice = caseLattice.subLattice(point); - Stmt mergeLoop = lowerMergePoint(sublattice, coordinate, coordinateVar, zeroedStmt, reducedAccesses, resolvedCoordDeclared); + Stmt mergeLoop = lowerMergePoint(sublattice, coordinate, coordinateVar, zeroedStmt, reducedAccesses, resolvedCoordDeclared, mergestrategy); mergeLoopsVec.push_back(mergeLoop); } Stmt mergeLoops = Block::make(mergeLoopsVec); @@ -1623,7 +1624,8 @@ Stmt LowererImplImperative::lowerMergeLattice(MergeLattice caseLattice, IndexVar Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, ir::Expr coordinate, IndexVar coordinateVar, IndexStmt statement, - const std::set& reducedAccesses, bool resolvedCoordDeclared) + const std::set& reducedAccesses, bool resolvedCoordDeclared, + MergeStrategy mergeStrategy) { MergePoint point = pointLattice.points().front(); @@ -1644,7 +1646,9 @@ Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, std::vector indexSetStmts; for (auto& iter : filter(iterators, [](Iterator it) { return it.hasIndexSet(); })) { // For each iterator A with an index set B, emit the following code: - // setMatch = min(A, B); // Check whether A matches its index set at this point. + // // Check whether A matches its index set at this point. + // // Using max instead of min because we will be merging the iterators by galloping. + // setMatch = max(A, B); // if (A == setMatch && B == setMatch) { // // If there was a match, project down the values of the iterators // // to be the position variable of the index set iterator. This has the @@ -1652,17 +1656,17 @@ Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, // A_coord = B_pos; // B_coord = B_pos; // } else { - // // Advance the iterator and it's index set iterator accordingly if + // // Advance the iterator and its index set iterator accordingly if // // there wasn't a match. - // A_pos += (A == setMatch); - // B_pos += (B == setMatch); + // A_pos = taco_gallop(int *A_array, int A_pos, int A_arrayEnd, int setMatch); + // B_pos = taco_gallop(int *B_array, int B_pos, int B_arrayEnd, int setMatch); // // We must continue so that we only proceed to the rest of the cases in // // the merge if there actually is a point present for A. // continue; // } auto setMatch = ir::Var::make("setMatch", Int()); auto indexSetIter = iter.getIndexSetIterator(); - indexSetStmts.push_back(ir::VarDecl::make(setMatch, ir::Min::make(this->coordinates({iter, indexSetIter})))); + indexSetStmts.push_back(ir::VarDecl::make(setMatch, ir::Max::make(this->coordinates({iter, indexSetIter})))); // Equality checks for each iterator. auto iterEq = ir::Eq::make(iter.getCoordVar(), setMatch); auto setEq = ir::Eq::make(indexSetIter.getCoordVar(), setMatch); @@ -1672,9 +1676,25 @@ Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, ir::Assign::make(indexSetIter.getCoordVar(), indexSetIter.getPosVar()) ); // Code to increment both iterator variables. 
+ auto ivar = iter.getIteratorVar(); + Expr iteratorParentPos = iter.getParent().getPosVar(); + ModeFunction iterBounds = iter.posBounds(iteratorParentPos); + vector iterGallopArgs = { + iter.getMode().getModePack().getArray(1), + ivar, iterBounds[1], + setMatch + }; + auto indexVar = indexSetIter.getIteratorVar(); + Expr indexIterParentPos = indexSetIter.getParent().getPosVar(); + ModeFunction indexIterBounds = indexSetIter.posBounds(indexIterParentPos); + vector indexGallopArgs = { + indexSetIter.getMode().getModePack().getArray(1), + indexVar, indexIterBounds[1], + setMatch + }; auto incr = ir::Block::make( - compoundAssign(iter.getIteratorVar(), ir::Cast::make(Eq::make(iter.getCoordVar(), setMatch), iter.getIteratorVar().type())), - compoundAssign(indexSetIter.getIteratorVar(), ir::Cast::make(Eq::make(indexSetIter.getCoordVar(), setMatch), indexSetIter.getIteratorVar().type())), + ir::Assign::make(ivar, ir::Call::make("taco_gallop", iterGallopArgs, ivar.type())), + ir::Assign::make(indexVar, ir::Call::make("taco_gallop", indexGallopArgs, indexVar.type())), ir::Continue::make() ); // Code that uses the defined parts together in the if-then-else. @@ -1682,7 +1702,13 @@ Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, } // Merge iterator coordinate variables - Stmt resolvedCoordinate = resolveCoordinate(mergers, coordinate, !resolvedCoordDeclared); + bool mergeWithMax; + if (mergeStrategy == MergeStrategy::Gallop) { + mergeWithMax = true; + } else { + mergeWithMax = false; + } + Stmt resolvedCoordinate = resolveCoordinate(mergers, coordinate, !resolvedCoordDeclared, mergeWithMax); // Locate positions Stmt loadLocatorPosVars = declLocatePosVars(locators); @@ -1696,10 +1722,10 @@ Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, // One case for each child lattice point lp Stmt caseStmts = lowerMergeCases(coordinate, coordinateVar, statement, pointLattice, - reducedAccesses); + reducedAccesses, mergeStrategy); // Increment iterator position variables - Stmt incIteratorVarStmts = codeToIncIteratorVars(coordinate, coordinateVar, iterators, mergers); + Stmt incIteratorVarStmts = codeToIncIteratorVars(coordinate, coordinateVar, iterators, mergers, mergeStrategy); /// While loop over rangers return While::make(checkThatNoneAreExhausted(rangers), @@ -1712,7 +1738,7 @@ Stmt LowererImplImperative::lowerMergePoint(MergeLattice pointLattice, incIteratorVarStmts)); } -Stmt LowererImplImperative::resolveCoordinate(std::vector mergers, ir::Expr coordinate, bool emitVarDecl) { +Stmt LowererImplImperative::resolveCoordinate(std::vector mergers, ir::Expr coordinate, bool emitVarDecl, bool mergeWithMax) { if (mergers.size() == 1) { Iterator merger = mergers[0]; if (merger.hasPosIter()) { @@ -1767,24 +1793,30 @@ Stmt LowererImplImperative::resolveCoordinate(std::vector mergers, ir: } else { // Multiple position iterators so the smallest is the resolved coordinate - if (emitVarDecl) { - return VarDecl::make(coordinate, Min::make(coordinates(mergers))); + Expr merged; + if (mergeWithMax) { + merged = Max::make(coordinates(mergers)); + } else { + merged = Min::make(coordinates(mergers)); } - else { - return Assign::make(coordinate, Min::make(coordinates(mergers))); + if (emitVarDecl) { + return VarDecl::make(coordinate, merged); + } else { + return Assign::make(coordinate, merged); } } } Stmt LowererImplImperative::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, MergeLattice caseLattice, - const std::set& reducedAccesses) + const std::set& 
reducedAccesses, + MergeStrategy mergeStrategy) { vector result; if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.needExplicitZeroChecks()) { // Can check value array of some tensor - Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, caseLattice, reducedAccesses); + Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, caseLattice, reducedAccesses, mergeStrategy); result.push_back(body); return Block::make(result); } @@ -1802,7 +1834,7 @@ Stmt LowererImplImperative::lowerMergeCases(ir::Expr coordinate, IndexVar coordi // Just one iterator so no conditional taco_iassert(!loopLattice.points()[0].isOmitter()); Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, loopLattice, reducedAccesses); + appenders, loopLattice, reducedAccesses, mergeStrategy); result.push_back(body); } else if (!loopLattice.points().empty()) { @@ -1827,10 +1859,10 @@ Stmt LowererImplImperative::lowerMergeCases(ir::Expr coordinate, IndexVar coordi // Construct case body IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, loopLattice)); Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, - inserters, appenders, MergeLattice({point}), reducedAccesses); + inserters, appenders, MergeLattice({point}), reducedAccesses, mergeStrategy); if (coordComparisons.empty()) { Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, MergeLattice({point}), reducedAccesses); + appenders, MergeLattice({point}), reducedAccesses, mergeStrategy); result.push_back(body); break; } @@ -1919,7 +1951,8 @@ std::vector LowererImplImperative::constructInnerLoopCasePreamble(ir:: vector LowererImplImperative::lowerCasesFromMap(map iteratorToCondition, ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice, - const std::set& reducedAccesses) { + const std::set& reducedAccesses, + MergeStrategy mergeStrategy) { vector appenders; vector inserters; @@ -1953,10 +1986,10 @@ vector LowererImplImperative::lowerCasesFromMap(map iterat // Construct case body IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, lattice)); Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, - inserters, appenders, MergeLattice({point}), reducedAccesses); + inserters, appenders, MergeLattice({point}), reducedAccesses, mergeStrategy); if (isNonZeroComparisions.empty()) { Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, MergeLattice({point}), reducedAccesses); + appenders, MergeLattice({point}), reducedAccesses, mergeStrategy); result.push_back(body); break; } @@ -1976,7 +2009,7 @@ vector LowererImplImperative::lowerCasesFromMap(map iterat Access access = iterators.modeAccess(it).getAccess(); IndexStmt initStmt = Assignment(access, Literal::zero(access.getDataType())); Stmt initialization = lowerForallBody(coordinate, initStmt, {}, inserters, - appenders, MergeLattice({}), reducedAccesses); + appenders, MergeLattice({}), reducedAccesses, mergeStrategy); stmts.push_back(initialization); } } @@ -1991,7 +2024,8 @@ vector LowererImplImperative::lowerCasesFromMap(map iterat /// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt. 
 Stmt LowererImplImperative::lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt,
-                                                      MergeLattice lattice, const std::set<Access>& reducedAccesses) {
+                                                      MergeLattice lattice, const std::set<Access>& reducedAccesses,
+                                                      MergeStrategy mergeStrategy) {
   vector<Stmt> result;
 
   if (lattice.points().size() == 1 && lattice.iterators().size() == 1
@@ -2003,14 +2037,14 @@ Stmt LowererImplImperative::lowerMergeCasesWithExplicitZeroChecks(ir::Expr coord
     tie(appenders, inserters) = splitAppenderAndInserters(lattice.results());
     taco_iassert(!lattice.points()[0].isOmitter());
     Stmt body = lowerForallBody(coordinate, stmt, {}, inserters,
-                                appenders, lattice, reducedAccesses);
+                                appenders, lattice, reducedAccesses, mergeStrategy);
     result.push_back(body);
   }
   else if (!lattice.points().empty()) {
     map<Iterator, Expr> iteratorToConditionMap;
     vector<Stmt> preamble = constructInnerLoopCasePreamble(coordinate, coordinateVar, lattice, iteratorToConditionMap);
     util::append(result, preamble);
-    vector<Stmt> cases = lowerCasesFromMap(iteratorToConditionMap, coordinate, stmt, lattice, reducedAccesses);
+    vector<Stmt> cases = lowerCasesFromMap(iteratorToConditionMap, coordinate, stmt, lattice, reducedAccesses, mergeStrategy);
     util::append(result, cases);
   }
 
@@ -2022,7 +2056,8 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
                                             vector<Iterator> inserters,
                                             vector<Iterator> appenders,
                                             MergeLattice caseLattice,
-                                            const set<Access>& reducedAccesses) {
+                                            const set<Access>& reducedAccesses,
+                                            MergeStrategy mergeStrategy) {
   // Inserter positions
   Stmt declInserterPosVars = declLocatePosVars(inserters);
 
@@ -2055,7 +2090,7 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
     }
 
     // This will lower the body for each case to actually compute. Therefore, we don't need to resize assembly arrays
-    std::vector<Stmt> loweredCases = lowerCasesFromMap(caseMap, coordinate, stmt, caseLattice, reducedAccesses);
+    std::vector<Stmt> loweredCases = lowerCasesFromMap(caseMap, coordinate, stmt, caseLattice, reducedAccesses, mergeStrategy);
     append(stmts, loweredCases);
 
     Stmt body = Block::make(stmts);
@@ -2071,13 +2106,26 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
 
   // Code to append coordinates
   Stmt appendCoords = appendCoordinate(appenders, coordinate);
+  std::vector<Stmt> stmts;
+
+  // Code to increment iterators when merging by galloping.
+  if (mergeStrategy == MergeStrategy::Gallop && caseLattice.iterators().size() > 1) {
+    for (auto it : caseLattice.iterators()) {
+      Expr ivar = it.getIteratorVar();
+      stmts.push_back(compoundAssign(ivar, 1));
+    }
+  }
+
+  Stmt incr = Block::make(stmts);
+
   // TODO: Emit code to insert coordinates
 
   return Block::make(initVals,
                      declInserterPosVars,
                      declLocatorPosVars,
                      body,
-                     appendCoords);
+                     appendCoords,
+                     incr);
 }
 
 Expr LowererImplImperative::getTemporarySize(Where where) {
@@ -3493,7 +3541,7 @@ Stmt LowererImplImperative::codeToInitializeIteratorVar(Iterator iterator, vecto
   else {
     result.push_back(codeToLoadCoordinatesFromPosIterators(iterators, true));
 
-    Stmt stmt = resolveCoordinate(mergers, coordinate, true);
+    Stmt stmt = resolveCoordinate(mergers, coordinate, true, false);
     taco_iassert(stmt != Stmt());
     result.push_back(stmt);
     result.push_back(codeToRecoverDerivedIndexVar(coordinateVar, iterator.getIndexVar(), true));
@@ -3545,7 +3593,7 @@ Stmt LowererImplImperative::codeToRecoverDerivedIndexVar(IndexVar underived, Ind
   return Stmt();
 }
 
-Stmt LowererImplImperative::codeToIncIteratorVars(Expr coordinate, IndexVar coordinateVar, vector<Iterator> iterators, vector<Iterator> mergers) {
+Stmt LowererImplImperative::codeToIncIteratorVars(Expr coordinate, IndexVar coordinateVar, vector<Iterator> iterators, vector<Iterator> mergers, MergeStrategy strategy) {
   if (iterators.size() == 1) {
     Expr ivar = iterators[0].getIteratorVar();
 
@@ -3572,12 +3620,23 @@ Stmt LowererImplImperative::codeToIncIteratorVars(Expr coordinate, IndexVar coor
     for (auto& iterator : levelIterators) {
       Expr ivar = iterator.getIteratorVar();
       if (iterator.isUnique()) {
-        Expr increment = iterator.isFull()
-                       ? 1
-                       : ir::Cast::make(Eq::make(iterator.getCoordVar(),
-                                                 coordinate),
-                                        ivar.type());
-        result.push_back(compoundAssign(ivar, increment));
+        if (iterator.isFull()) {
+          Expr increment = 1;
+          result.push_back(compoundAssign(ivar, increment));
+        } else if (strategy == MergeStrategy::Gallop) {
+          Expr iteratorParentPos = iterator.getParent().getPosVar();
+          ModeFunction iterBounds = iterator.posBounds(iteratorParentPos);
+          result.push_back(iterBounds.compute());
+          vector<Expr> gallopArgs = {
+            iterator.getMode().getModePack().getArray(1),
+            ivar, iterBounds[1],
+            coordinate,
+          };
+          result.push_back(ir::Assign::make(ivar, ir::Call::make("taco_gallop", gallopArgs, ivar.type())));
+        } else { // strategy == MergeStrategy::TwoFinger
+          Expr increment = ir::Cast::make(Eq::make(iterator.getCoordVar(), coordinate), ivar.type());
+          result.push_back(compoundAssign(ivar, increment));
+        }
       } else if (!iterator.isLeaf()) {
         result.push_back(Assign::make(ivar, iterator.getSegendVar()));
       }
@@ -3593,7 +3652,7 @@ Stmt LowererImplImperative::codeToIncIteratorVars(Expr coordinate, IndexVar coor
     }
     else {
      result.push_back(codeToLoadCoordinatesFromPosIterators(iterators, false));
-      Stmt stmt = resolveCoordinate(mergers, coordinate, false);
+      Stmt stmt = resolveCoordinate(mergers, coordinate, false, false);
      taco_iassert(stmt != Stmt());
      result.push_back(stmt);
      result.push_back(codeToRecoverDerivedIndexVar(coordinateVar, iterator.getIndexVar(), false));
diff --git a/test/tests-index_notation.cpp b/test/tests-index_notation.cpp
index 8090e71de..df6cbf938 100644
--- a/test/tests-index_notation.cpp
+++ b/test/tests-index_notation.cpp
@@ -149,7 +149,7 @@ TEST(notation, isomorphic) {
   ASSERT_FALSE(isomorphic(forall(i, forall(j, A(i,j) = B(i,j) + C(i,j))),
                           forall(i, forall(j, A(j,i) = B(j,i) + C(j,i)))));
   ASSERT_FALSE(isomorphic(forall(i, forall(j, A(i,j) = B(i,j) + C(i,j),
-                                             ParallelUnit::DefaultUnit, 
OutputRaceStrategy::NoRaces)), + MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)), forall(j, forall(i, A(j,i) = B(j,i) + C(j,i))))); ASSERT_TRUE(isomorphic(sum(j, B(i,j) + C(i,j)), sum(i, B(j,i) + C(j,i)))); ASSERT_FALSE(isomorphic(sum(j, B(i,j) + C(i,j)), sum(j, B(j,i) + C(j,i)))); diff --git a/test/tests-scheduling.cpp b/test/tests-scheduling.cpp index ca827421b..ee564577b 100644 --- a/test/tests-scheduling.cpp +++ b/test/tests-scheduling.cpp @@ -1053,3 +1053,85 @@ TEST(scheduling, divide) { return stmt.fuse(i, j, f).pos(f, fpos, A(i, j)).divide(fpos, f0, f1, 4).split(f1, i1, i2, 16).split(i2, i3, i4, 8); }); } + +TEST(scheduling, mergeby) { + auto dim = 256; + float sparsity = 0.1; + Tensor A("A", {dim, dim}, {Sparse, Sparse}); + Tensor B("B", {dim, dim}, {Dense, Sparse}); + Tensor x("x", {dim}, Sparse); + IndexVar i("i"), i1("i1"), i2("i2"), ipos("ipos"), j("j"), f("f"), fpos("fpos"), f0("f0"), f1("f1"); + + srand(59393); + for (int i = 0; i < dim; i++) { + for (int j = 0; j < dim; j++) { + auto rand_float = (float)rand()/(float)(RAND_MAX); + if (rand_float < sparsity) { + A.insert({i, j},((int)(rand_float * 10 / sparsity))); + B.insert({i, j},((int)(rand_float * 10 / sparsity))); + } + } + } + + for (int j = 0; j < dim; j++) { + float rand_float = (float)rand()/(float)(RAND_MAX); + x.insert({j}, ((int)(rand_float*10))); + } + + x.pack(); A.pack(); B.pack(); + + auto test = [&](std::function f) { + Tensor y("y", {dim}, Dense); + y(i) = A(i, j) * B(i, j) * x(j); + auto stmt = f(y.getAssignment().concretize()); + y.compile(stmt); + y.evaluate(); + Tensor expected("expected", {dim}, Dense); + expected(i) = A(i, j) * B(i, j) * x(j); + expected.evaluate(); + ASSERT_TRUE(equals(expected, y)) << expected << endl << y << endl; + }; + + // Test that a simple mergeby works. + test([&](IndexStmt stmt) { + return stmt.mergeby(j, MergeStrategy::Gallop); + }); + + // Testing Two Finger merge. + test([&](IndexStmt stmt) { + return stmt.mergeby(j, MergeStrategy::TwoFinger); + }); + + // Merging a dimension with a dense iterator with Gallop should be no-op. 
+ test([&](IndexStmt stmt) { + return stmt.mergeby(i, MergeStrategy::Gallop); + }); + + // Test interaction between mergeby and other directives + test([&](IndexStmt stmt) { + return stmt.mergeby(i, MergeStrategy::Gallop).split(i, i1, i2, 16); + }); + + test([&](IndexStmt stmt) { + return stmt.mergeby(i, MergeStrategy::Gallop).split(i, i1, i2, 32).unroll(i1, 4); + }); + + test([&](IndexStmt stmt) { + return stmt.mergeby(i, MergeStrategy::Gallop).pos(i, ipos, A(i,j)); + }); + + test([&](IndexStmt stmt) { + return stmt.mergeby(i, MergeStrategy::Gallop).parallelize(i, ParallelUnit::CPUThread, OutputRaceStrategy::NoRaces); + }); +} + +TEST(scheduling, mergeby_gallop_error) { + Tensor x("x", {8}, Format({Sparse})); + Tensor y("y", {8}, Format({Dense})); + Tensor z("z", {8}, Format({Sparse})); + IndexVar i("i"), ipos("ipos"); + y(i) = x(i) + z(i); + + IndexStmt stmt = y.getAssignment().concretize(); + ASSERT_THROW(stmt.mergeby(i, MergeStrategy::Gallop), taco::TacoException); +} \ No newline at end of file diff --git a/test/tests-transformation.cpp b/test/tests-transformation.cpp index abfec3d45..83ff16510 100644 --- a/test/tests-transformation.cpp +++ b/test/tests-transformation.cpp @@ -239,15 +239,15 @@ INSTANTIATE_TEST_CASE_P(parallelize, apply, Values( TransformationTest(Parallelize(i), forall(i, w(i) = b(i)), - forall(i, w(i) = b(i), ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces) + forall(i, w(i) = b(i), MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces) ), TransformationTest(Parallelize(i), forall(i, forall(j, W(i,j) = A(i,j))), - forall(i, forall(j, W(i,j) = A(i,j)), ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces) + forall(i, forall(j, W(i,j) = A(i,j)), MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces) ), TransformationTest(Parallelize(j), forall(i, forall(j, W(i,j) = A(i,j))), - forall(i, forall(j, W(i,j) = A(i,j), ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)) + forall(i, forall(j, W(i,j) = A(i,j), MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)) ) ) ); @@ -264,8 +264,8 @@ INSTANTIATE_TEST_CASE_P(misc, reorderLoopsTopologically, Values( NotationTest(forall(i, w(i) = b(i)), forall(i, w(i) = b(i))), - NotationTest(forall(i, w(i) = b(i), ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces), - forall(i, w(i) = b(i), ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)), + NotationTest(forall(i, w(i) = b(i), MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces), + forall(i, w(i) = b(i), MergeStrategy::TwoFinger, ParallelUnit::DefaultUnit, OutputRaceStrategy::NoRaces)), NotationTest(forall(i, forall(j, W(i,j) = A(i,j))), forall(i, forall(j, W(i,j) = A(i,j)))), diff --git a/tools/taco.cpp b/tools/taco.cpp index fbdcfb984..78023e7f6 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -497,6 +497,24 @@ static bool setSchedulingCommands(vector> scheduleCommands, parse stmt = stmt.reorder(reorderedVars); + } else if (command == "mergeby") { + taco_uassert(scheduleCommand.size() == 2) << "'mergeby' scheduling directive takes 2 parameters: mergeby(i, strategy)"; + string i, strat; + MergeStrategy strategy; + + i = scheduleCommand[0]; + strat = scheduleCommand[1]; + if (strat == "TwoFinger") { + strategy = MergeStrategy::TwoFinger; + } else if (strat == "Gallop") { + strategy = MergeStrategy::Gallop; + } else { + taco_uerror << "Merge strategy not defined."; + goto end; + } + + stmt = stmt.mergeby(findVar(i), strategy); + } else if (command == 
"bound") { taco_uassert(scheduleCommand.size() == 4) << "'bound' scheduling directive takes 4 parameters: bound(i, i1, bound, type)"; string i, i1, type;