Skip to content

Commit

Permalink
Merge branch 'new-test-target' into test-target
Browse files Browse the repository at this point in the history
g especially if it merges an updated upstream into a topic branch.
merge
  • Loading branch information
zhang677 committed May 26, 2022
2 parents 8af953b + 58567e2 commit 1505153
Show file tree
Hide file tree
Showing 29 changed files with 2,259 additions and 1,885 deletions.
4 changes: 3 additions & 1 deletion include/taco/format.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ class ModeFormat {
/// Properties of a mode format
enum Property {
FULL, NOT_FULL, ORDERED, NOT_ORDERED, UNIQUE, NOT_UNIQUE, BRANCHLESS,
NOT_BRANCHLESS, COMPACT, NOT_COMPACT, ZEROLESS, NOT_ZEROLESS
NOT_BRANCHLESS, COMPACT, NOT_COMPACT, ZEROLESS, NOT_ZEROLESS, PADDED,
NOT_PADDED
};

/// Instantiates an undefined mode format
Expand Down Expand Up @@ -129,6 +130,7 @@ class ModeFormat {
bool isBranchless() const;
bool isCompact() const;
bool isZeroless() const;
bool isPadded() const;

/// Returns true if a mode format has a specific capability, false otherwise
bool hasCoordValIter() const;
Expand Down
22 changes: 20 additions & 2 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,23 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
/// reorder takes a new ordering for a set of index variables that are directly nested in the iteration order
IndexStmt reorder(std::vector<IndexVar> reorderedvars) const;

/// The mergeby transformation specifies how to merge iterators on
/// the given index variable. By default, if an iterator is used for windowing
/// it will be merged with the "gallop" strategy.
/// All other iterators are merged with the "two finger" strategy.
/// The two finger strategy merges by advancing each iterator one at a time,
/// while the gallop strategy implements the exponential search algorithm.
///
/// Preconditions:
/// This command applies to variables involving sparse iterators only;
/// it is a no-op if the variable involves any dense iterators.
/// Any variable can be merged with the two finger strategy, whereas gallop
/// only applies to a variable if its merge lattice has a single point
/// (i.e. an intersection). For example, if a variable involves multiplications
/// only, it can be merged with gallop.
/// Furthermore, all iterators must be ordered for gallop to apply.
IndexStmt mergeby(IndexVar i, MergeStrategy strategy) const;

/// The parallelize
/// transformation tags an index variable for parallel execution. The
/// transformation takes as an argument the type of parallel hardware
Expand Down Expand Up @@ -835,13 +852,14 @@ class Forall : public IndexStmt {
Forall() = default;
Forall(const ForallNode*);
Forall(IndexVar indexVar, IndexStmt stmt);
Forall(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
Forall(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);

IndexVar getIndexVar() const;
IndexStmt getStmt() const;

ParallelUnit getParallelUnit() const;
OutputRaceStrategy getOutputRaceStrategy() const;
MergeStrategy getMergeStrategy() const;

size_t getUnrollFactor() const;

Expand All @@ -850,7 +868,7 @@ class Forall : public IndexStmt {

/// Create a forall index statement.
Forall forall(IndexVar i, IndexStmt stmt);
Forall forall(IndexVar i, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);
Forall forall(IndexVar i, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0);


/// A where statment has a producer statement that binds a tensor variable in
Expand Down
5 changes: 3 additions & 2 deletions include/taco/index_notation/index_notation_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,15 +398,16 @@ struct YieldNode : public IndexStmtNode {
};

struct ForallNode : public IndexStmtNode {
ForallNode(IndexVar indexVar, IndexStmt stmt, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0)
: indexVar(indexVar), stmt(stmt), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {}
ForallNode(IndexVar indexVar, IndexStmt stmt, MergeStrategy merge_strategy, ParallelUnit parallel_unit, OutputRaceStrategy output_race_strategy, size_t unrollFactor = 0)
: indexVar(indexVar), stmt(stmt), merge_strategy(merge_strategy), parallel_unit(parallel_unit), output_race_strategy(output_race_strategy), unrollFactor(unrollFactor) {}

void accept(IndexStmtVisitorStrict* v) const {
v->visit(this);
}

IndexVar indexVar;
IndexStmt stmt;
MergeStrategy merge_strategy;
ParallelUnit parallel_unit;
OutputRaceStrategy output_race_strategy;
size_t unrollFactor = 0;
Expand Down
21 changes: 21 additions & 0 deletions include/taco/index_notation/transformations.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class AddSuchThatPredicates;
class Parallelize;
class TopoReorder;
class SetAssembleStrategy;
class SetMergeStrategy;

/// A transformation is an optimization that transforms a statement in the
/// concrete index notation into a new statement that computes the same result
Expand All @@ -36,6 +37,7 @@ class Transformation {
Transformation(TopoReorder);
Transformation(AddSuchThatPredicates);
Transformation(SetAssembleStrategy);
Transformation(SetMergeStrategy);

IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;

Expand Down Expand Up @@ -206,6 +208,25 @@ class SetAssembleStrategy : public TransformationInterface {
/// Print a SetAssembleStrategy command.
std::ostream &operator<<(std::ostream &, const SetAssembleStrategy&);

/// A transformation command constructed from an index variable and a
/// MergeStrategy that can be applied to an index statement.
/// NOTE(review): presumably the command behind IndexStmt::mergeby; confirm
/// against the implementation file.
class SetMergeStrategy : public TransformationInterface {
public:
/// Constructs the command from an index variable and a merge strategy.
SetMergeStrategy(IndexVar i, MergeStrategy strategy);

/// Returns an index variable; presumably the one passed to the constructor
/// (the definition is not visible in this header).
IndexVar geti() const;
/// Returns a merge strategy; presumably the one passed to the constructor
/// (the definition is not visible in this header).
MergeStrategy getMergeStrategy() const;

/// Applies this command to `stmt` and returns an index statement.
/// `reason` is optional and defaults to nullptr.
/// NOTE(review): presumably `reason` receives an explanation when the
/// command cannot be applied; confirm against the implementation.
IndexStmt apply(IndexStmt stmt, std::string *reason = nullptr) const;

/// Prints this command to the given output stream.
void print(std::ostream &os) const;

private:
// State is held behind a shared pointer to a struct that is only
// forward-declared here; its members are defined elsewhere.
struct Content;
std::shared_ptr<Content> content;
};

/// Print a SetMergeStrategy command.
std::ostream &operator<<(std::ostream &, const SetMergeStrategy&);

// Autoscheduling functions

/**
Expand Down
7 changes: 7 additions & 0 deletions include/taco/ir_tags.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ enum class AssembleStrategy {
};
extern const char *AssembleStrategy_NAMES[];

/// MergeStrategy::TwoFinger merges iterators by incrementing one at a time
/// MergeStrategy::Gallop merges iterators by exponential search (galloping)
enum class MergeStrategy {
TwoFinger, Gallop
};
extern const char *MergeStrategy_NAMES[];

}

#endif //TACO_IR_TAGS_H
26 changes: 18 additions & 8 deletions include/taco/lower/lowerer_impl_imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,18 @@ class LowererImplImperative : public LowererImpl {
* \param statement
* A concrete index notation statement to compute at the points in the
* sparse iteration space described by the merge lattice.
* \param mergeStrategy
* A strategy for merging iterators. One of TwoFinger or Gallop.
*
* \return
* IR code to compute the forall loop.
*/
virtual ir::Stmt lowerMergeLattice(MergeLattice lattice, IndexVar coordinateVar,
IndexStmt statement,
const std::set<Access>& reducedAccesses);
const std::set<Access>& reducedAccesses,
MergeStrategy mergeStrategy);

virtual ir::Stmt resolveCoordinate(std::vector<Iterator> mergers, ir::Expr coordinate, bool emitVarDecl);
virtual ir::Stmt resolveCoordinate(std::vector<Iterator> mergers, ir::Expr coordinate, bool emitVarDecl, bool mergeWithMax);

/**
* Lower the merge point at the top of the given lattice to code that iterates
Expand All @@ -169,23 +172,29 @@ class LowererImplImperative : public LowererImpl {
* coordinate the merge point is at.
* A concrete index notation statement to compute at the points in the
* sparse iteration space region described by the merge point.
* \param mergestrategy
* A strategy for merging iterators. One of TwoFinger or Gallop.
* MAX (instead of MIN) is needed to combine coordinates when the iterators are merged with the Gallop strategy.
*/
virtual ir::Stmt lowerMergePoint(MergeLattice pointLattice,
ir::Expr coordinate, IndexVar coordinateVar, IndexStmt statement,
const std::set<Access>& reducedAccesses, bool resolvedCoordDeclared);
const std::set<Access>& reducedAccesses, bool resolvedCoordDeclared,
MergeStrategy mergestrategy);

/// Lower a merge lattice to cases.
virtual ir::Stmt lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt,
MergeLattice lattice,
const std::set<Access>& reducedAccesses);
const std::set<Access>& reducedAccesses,
MergeStrategy mergeStrategy);

/// Lower a forall loop body.
virtual ir::Stmt lowerForallBody(ir::Expr coordinate, IndexStmt stmt,
std::vector<Iterator> locaters,
std::vector<Iterator> inserters,
std::vector<Iterator> appenders,
MergeLattice caseLattice,
const std::set<Access>& reducedAccesses);
const std::set<Access>& reducedAccesses,
MergeStrategy mergeStrategy);


/// Lower a where statement.
Expand Down Expand Up @@ -375,7 +384,7 @@ class LowererImplImperative : public LowererImpl {

/// Conditionally increment iterator position variables.
ir::Stmt codeToIncIteratorVars(ir::Expr coordinate, IndexVar coordinateVar,
std::vector<Iterator> iterators, std::vector<Iterator> mergers);
std::vector<Iterator> iterators, std::vector<Iterator> mergers, MergeStrategy strategy);

ir::Stmt codeToLoadCoordinatesFromPosIterators(std::vector<Iterator> iterators, bool declVars);

Expand Down Expand Up @@ -410,7 +419,8 @@ class LowererImplImperative : public LowererImpl {
/// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt.
/// Will emit checks for explicit zeros for each mode iterator and each locator in the lattice.
ir::Stmt lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt,
MergeLattice lattice, const std::set<Access>& reducedAccesses);
MergeLattice lattice, const std::set<Access>& reducedAccesses,
MergeStrategy mergeStrategy);

/// Constructs cases comparing the coordVar for each iterator to the resolved coordinate.
/// Returns a vector where coordComparisons[i] corresponds to a case for iters[i]
Expand Down Expand Up @@ -444,7 +454,7 @@ class LowererImplImperative : public LowererImpl {
/// The map must be of iterators to exprs of boolean types
std::vector<ir::Stmt> lowerCasesFromMap(std::map<Iterator, ir::Expr> iteratorToCondition,
ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice,
const std::set<Access>& reducedAccesses);
const std::set<Access>& reducedAccesses, MergeStrategy mergeStrategy);

/// Constructs an expression which checks if this access is "zero"
ir::Expr constructCheckForAccessZero(Access);
Expand Down
8 changes: 5 additions & 3 deletions include/taco/lower/mode_format_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,10 @@ class ModeFormatImpl {
public:
ModeFormatImpl(std::string name, bool isFull, bool isOrdered, bool isUnique,
bool isBranchless, bool isCompact, bool isZeroless,
bool hasCoordValIter, bool hasCoordPosIter, bool hasLocate,
bool hasInsert, bool hasAppend, bool hasSeqInsertEdge,
bool hasInsertCoord, bool isYieldPosPure);
bool isPadded, bool hasCoordValIter, bool hasCoordPosIter,
bool hasLocate, bool hasInsert, bool hasAppend,
bool hasSeqInsertEdge, bool hasInsertCoord,
bool isYieldPosPure);

virtual ~ModeFormatImpl();

Expand Down Expand Up @@ -246,6 +247,7 @@ class ModeFormatImpl {
const bool isBranchless;
const bool isCompact;
const bool isZeroless;
const bool isPadded;

const bool hasCoordValIter;
const bool hasCoordPosIter;
Expand Down
5 changes: 3 additions & 2 deletions include/taco/lower/mode_format_singleton.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ class SingletonModeFormat : public ModeFormatImpl {
using ModeFormatImpl::getInsertCoord;

SingletonModeFormat();
SingletonModeFormat(bool isFull, bool isOrdered,
bool isUnique, bool isZeroless, long long allocSize = DEFAULT_ALLOC_SIZE);
SingletonModeFormat(bool isFull, bool isOrdered, bool isUnique,
bool isZeroless, bool isPadded,
long long allocSize = DEFAULT_ALLOC_SIZE);

~SingletonModeFormat() override {}

Expand Down
22 changes: 22 additions & 0 deletions src/codegen/codegen_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@ const string cHeaders =
"int cmp(const void *a, const void *b) {\n"
" return *((const int*)a) - *((const int*)b);\n"
"}\n"
// Emitted C helper: returns the first index in [arrayStart, arrayEnd) whose
// element is >= target, or an index >= arrayEnd if there is none, using an
// exponential search algorithm: https://en.wikipedia.org/wiki/Exponential_search.
// The array is expected to be sorted in ascending order over that range.
// The bounds test comes first in the guard so that an empty range
// (arrayStart >= arrayEnd) returns without reading array[arrayStart], which
// may lie past the end of the array.
"int taco_gallop(int *array, int arrayStart, int arrayEnd, int target) {\n"
" if (arrayStart >= arrayEnd || array[arrayStart] >= target) {\n"
" return arrayStart;\n"
" }\n"
" int step = 1;\n"
" int curr = arrayStart;\n"
// Gallop phase: double the step while the probed element is still < target.
" while (curr + step < arrayEnd && array[curr + step] < target) {\n"
" curr += step;\n"
" step = step * 2;\n"
" }\n"
"\n"
// Refinement phase: halve the step to find the last element < target.
" step = step / 2;\n"
" while (step > 0) {\n"
" if (curr + step < arrayEnd && array[curr + step] < target) {\n"
" curr += step;\n"
" }\n"
" step = step / 2;\n"
" }\n"
// curr is now the last index whose element is < target.
" return curr+1;\n"
"}\n"
"int taco_binarySearchAfter(int *array, int arrayStart, int arrayEnd, int target) {\n"
" if (array[arrayStart] >= target) {\n"
" return arrayStart;\n"
Expand Down
15 changes: 15 additions & 0 deletions src/format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ bool ModeFormat::hasProperties(const std::vector<Property>& properties) const {
return false;
}
break;
case PADDED:
if (!isPadded()) {
return false;
}
break;
case NOT_FULL:
if (isFull()) {
return false;
Expand Down Expand Up @@ -217,6 +222,11 @@ bool ModeFormat::hasProperties(const std::vector<Property>& properties) const {
return false;
}
break;
case NOT_PADDED:
if (isPadded()) {
return false;
}
break;
}
}
return true;
Expand Down Expand Up @@ -252,6 +262,11 @@ bool ModeFormat::isZeroless() const {
return impl->isZeroless;
}

/// Returns the `isPadded` flag stored on the underlying mode format
/// implementation. Asserts (via taco_iassert) that this mode format is defined.
bool ModeFormat::isPadded() const {
taco_iassert(defined());
return impl->isPadded;
}

bool ModeFormat::hasCoordValIter() const {
taco_iassert(defined());
return impl->hasCoordValIter;
Expand Down
Loading

0 comments on commit 1505153

Please sign in to comment.