Use std::optional instead of llvm::Optional
Note that llvm::Optional is nowadays just an alias for std::optional
and has since been deprecated upstream in favor of std::optional.

PiperOrigin-RevId: 514100010
tensorflower-gardener authored and copybara-github committed Mar 5, 2023
1 parent 6c6bd0a commit 03cfac4
Showing 21 changed files with 125 additions and 112 deletions.
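The change is mechanical throughout; a minimal before/after sketch (hypothetical function name and values, not taken from the diffs below — since llvm::Optional is a plain alias, only the spelling and the include change):

#include <optional>  // replaces #include "llvm/ADT/Optional.h"

// Before: llvm::Optional<unsigned> lookupOrdinal(bool present);
// After: spell the standard type directly; std::nullopt signals "no value".
std::optional<unsigned> lookupOrdinal(bool present) {
  if (present) return 7;  // arbitrary illustrative value
  return std::nullopt;
}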
3 changes: 2 additions & 1 deletion xla/mlir/backends/cpu/transforms/legalize_collective_ops.cc
@@ -48,7 +48,8 @@ class LegalizeCollectiveOpsPass
   void runOnOperation() override;
 };
 
-Optional<xla_cpu::ReductionKind> MatchReductionComputation(Region& region) {
+std::optional<xla_cpu::ReductionKind> MatchReductionComputation(
+    Region& region) {
   if (!region.hasOneBlock()) {
     return std::nullopt;
   }
4 changes: 1 addition & 3 deletions xla/mlir/runtime/ir/rt_ops.cc
@@ -29,8 +29,6 @@ namespace runtime {
 
 using namespace mlir;  // NOLINT
 
-using llvm::Optional;
-
 //===----------------------------------------------------------------------===//
 // ExportOp
 //===----------------------------------------------------------------------===//
@@ -60,7 +58,7 @@ LogicalResult ExportOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
-Optional<unsigned> ExportOp::ordinal() {
+std::optional<unsigned> ExportOp::ordinal() {
   if (auto ordinal = getOrdinal()) return ordinal->getLimitedValue();
   return std::nullopt;
 }
4 changes: 2 additions & 2 deletions xla/mlir/runtime/transforms/rt_to_llvm.cc
@@ -700,7 +700,7 @@ void ConvertRuntimeToLLVMPass::runOnOperation() {
   // rewriter function into the CFG and they interact badly.
 
   // Convert all async types to opaque pointers.
-  llvm_converter.addConversion([&](Type type) -> Optional<Type> {
+  llvm_converter.addConversion([&](Type type) -> std::optional<Type> {
     if (type.isa<async::TokenType, async::GroupType, async::ValueType>())
       return llvm_converter.getPointerType(
           IntegerType::get(type.getContext(), 8));
@@ -712,7 +712,7 @@ void ConvertRuntimeToLLVMPass::runOnOperation() {
   auto add_unrealized_cast = [](OpBuilder &builder, Type type,
                                 ValueRange inputs, Location loc) {
     auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return Optional<Value>(cast.getResult(0));
+    return std::optional<Value>(cast.getResult(0));
   };
   converter.addSourceMaterialization(add_unrealized_cast);
 
2 changes: 1 addition & 1 deletion xla/mlir_hlo/gml_st/IR/gml_st_ops.cc
@@ -365,7 +365,7 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
     newLowerBounds.reserve(op.getLowerBound().size());
     newUpperBounds.reserve(op.getUpperBound().size());
     newSteps.reserve(op.getStep().size());
-    auto getConstant = [](Value v) -> Optional<int64_t> {
+    auto getConstant = [](Value v) -> std::optional<int64_t> {
       auto constant =
           dyn_cast_or_null<arith::ConstantIndexOp>(v.getDefiningOp());
       if (constant) return constant.value();
@@ -106,7 +106,7 @@ SmallVector<OpFoldResult> getDims(OpBuilder &builder, Location loc,
       [&](int64_t dim) { return getDim(builder, loc, shapedTypeValue, dim); }));
 }
 
-Optional<Value> getPaddingValue(Value &source) {
+std::optional<Value> getPaddingValue(Value &source) {
   auto padOp = source.getDefiningOp<tensor::PadOp>();
   if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
     return std::nullopt;
@@ -127,7 +127,7 @@ Value pack(Location loc, PatternRewriter &rewriter, Value source,
       getAsOpFoldResult(rewriter.getI64ArrayAttr(innerTileSizes));
   auto empty = tensor::PackOp::createDestinationTensor(
       rewriter, loc, source, innerTileSizesOfr, innerDimsPos, outerDimsPerm);
-  Optional<Value> paddingValue = getPaddingValue(source);
+  std::optional<Value> paddingValue = getPaddingValue(source);
   return rewriter.create<tensor::PackOp>(loc, source, empty, innerDimsPos,
                                          innerTileSizesOfr, paddingValue,
                                          outerDimsPerm);
@@ -140,7 +140,7 @@ struct ScalarizeLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
 
 // Returns `startIndices`[0, :] for `startIndices` of shape 1xn. Returns None if
 // startIndices has a different shape.
-Optional<SmallVector<Value>> extractStartIndices(
+std::optional<SmallVector<Value>> extractStartIndices(
     ImplicitLocOpBuilder &b, TypedValue<ShapedType> startIndices) {
   if (startIndices.getType().getRank() != 2 ||
       startIndices.getType().getDimSize(0) != 1) {
@@ -59,8 +59,8 @@ LogicalResult tilePartialSoftmax(
   //   i)  by a reduction and subsequent bcast in one dimension, or
   //   ii) by using the source value as is.
   Value commonSource;
-  Optional<int64_t> commonReductionDim;
-  SmallVector<Optional<SimpleBcastReduction>> simpleBcastReductions;
+  std::optional<int64_t> commonReductionDim;
+  SmallVector<std::optional<SimpleBcastReduction>> simpleBcastReductions;
   auto mapOp = llvm::dyn_cast_or_null<linalg::MapOp>(op.getOperation());
   if (!mapOp || mapOp.getNumDpsInits() != 1)
     return rewriter.notifyMatchFailure(op, "no mapOp");
152 changes: 79 additions & 73 deletions xla/mlir_hlo/mhlo/IR/hlo_ops.cc

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc
@@ -740,7 +740,7 @@ void ShapeComponentAnalysis::compute(ShapeOrValueInfo requestedInfo) {
       .visit(requestedInfo);
 }
 
-Optional<ArrayRef<SymbolicExpr>>
+std::optional<ArrayRef<SymbolicExpr>>
 ShapeComponentAnalysis::ShapeComponentAnalysis::GetShapeInfo(Value value) {
   auto request = ShapeOrValueInfo::getShapeInfoOf(value);
   compute(request);
@@ -749,7 +749,7 @@ ShapeComponentAnalysis::ShapeComponentAnalysis::GetShapeInfo(Value value) {
   return llvm::ArrayRef(found->second);
 }
 
-Optional<ArrayRef<SymbolicExpr>>
+std::optional<ArrayRef<SymbolicExpr>>
 ShapeComponentAnalysis::ShapeComponentAnalysis::GetValueInfo(Value shape) {
   auto request = ShapeOrValueInfo::getValueInfoOf(shape);
   compute(request);
8 changes: 5 additions & 3 deletions xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h
@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef MLIR_HLO_MHLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H
 #define MLIR_HLO_MHLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H
 
+#include <optional>
+
 #include "llvm/Support/raw_ostream.h"
 #include "mhlo/IR/hlo_ops.h"
 #include "mlir/IR/AffineMap.h"
@@ -97,7 +99,7 @@ class ShapeComponentAnalysis {
     // `1`. This is useful for broadcasts.
     bool isKnownNotOne() const;
     // If this is a reference to a singular symbol, return it.
-    Optional<Symbol> singleton() const;
+    std::optional<Symbol> singleton() const;
 
     bool operator==(const SymbolicExpr &rhs) const {
       return expr == rhs.expr && symbols == rhs.symbols;
@@ -127,10 +129,10 @@ class ShapeComponentAnalysis {
  public:
   // Return the computed components for the shape of a value, e.g., the
   // dimensions of a tensor.
-  Optional<ArrayRef<SymbolicExpr>> GetShapeInfo(Value value);
+  std::optional<ArrayRef<SymbolicExpr>> GetShapeInfo(Value value);
   // Return the computed components for the value of a value, e.g, the elements
   // of a shape tensor.
-  Optional<ArrayRef<SymbolicExpr>> GetValueInfo(Value shape);
+  std::optional<ArrayRef<SymbolicExpr>> GetValueInfo(Value shape);
 
   // Clear analysis data structures.
   void reset();
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include <algorithm>
+#include <optional>
 #include <utility>
 
 #include "llvm/ADT/STLExtras.h"
@@ -156,7 +157,7 @@ LogicalResult tryLowerTo1DOr2DReduction(
 
   // Reify the result shape early so that the pattern can fail without altering
   // the IR.
-  Optional<Value> resultShape;
+  std::optional<Value> resultShape;
   if (requiresDynamicReshape) {
     llvm::SmallVector<Value, 1> reifiedShapes;
     if (failed(llvm::cast<InferShapedTypeOpInterface>(op.getOperation())
@@ -16,6 +16,7 @@ limitations under the License.
 // This file implements logic for lowering HLO dialect to LHLO dialect.
 
 #include <memory>
+#include <optional>
 #include <utility>
 
 #include "mhlo/IR/hlo_ops.h"
@@ -132,7 +133,7 @@ struct ScalarHloToArithmeticPattern : public OpConversionPattern<OpTy> {
 
     auto loc = op.getLoc();
 
-    Optional<ShapedType> resultTy;
+    std::optional<ShapedType> resultTy;
     resultTy = this->typeConverter->convertType(op->getResultTypes().front())
                    .template dyn_cast<ShapedType>();
 
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <optional>
 #include <string>
 
 #include "llvm/ADT/ArrayRef.h"
@@ -72,7 +73,7 @@ bool hasPrivateFeaturesNotInStablehlo(HloOpTy hloOp) {
   return false;
 }
 
-bool hasPackedNibble(Optional<ArrayAttr> precisionConfigAttr) {
+bool hasPackedNibble(std::optional<ArrayAttr> precisionConfigAttr) {
   if (!precisionConfigAttr) return false;
   return llvm::any_of(*precisionConfigAttr, [&](Attribute attr) {
     auto precisionAttr = attr.cast<mhlo::PrecisionAttr>();
@@ -985,7 +985,8 @@ class DynamicBroadcastInDimOpToBroadcastConverter
     SmallVector<int64_t> broadcastDimensions =
         llvm::to_vector(op.getBroadcastDimensions().getValues<int64_t>());
 
-    SmallVector<Optional<bool>> expansionBehavior(broadcastDimensions.size());
+    SmallVector<std::optional<bool>> expansionBehavior(
+        broadcastDimensions.size());
 
     // Use static type info.
     for (const auto& [idx, dim] : llvm::enumerate(operandTy.getShape())) {
@@ -1375,7 +1376,7 @@ class ReshapeOpConverter : public OpConversionPattern<mhlo::ReshapeOp> {
     // Compute the reassociation maps for the linalg operation. This will
     // succeed if the reshape can be done with a single expand_shape or
     // collapse_shape.
-    if (Optional<SmallVector<ReassociationIndices>> reassociationMap =
+    if (std::optional<SmallVector<ReassociationIndices>> reassociationMap =
             getReassociationIndicesForReshape(operandType, resultType)) {
       if (resultType.getRank() < operandType.getRank()) {
         // We have found a working reassociation map. If the operand is dynamic,
@@ -4214,7 +4215,7 @@ class PointwiseToLinalgMapConverter : public OpConversionPattern<OpTy> {
     }
 
     // Find result type, if on tensors.
-    Optional<ShapedType> resultTy;
+    std::optional<ShapedType> resultTy;
     resultTy = this->typeConverter->convertType(op->getResultTypes().front())
                    .template dyn_cast<ShapedType>();
 
@@ -57,7 +57,7 @@ class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
         !rhsType.getElementType().isSignlessInteger())
       return failure();
 
-    Optional<arith::CmpIPredicate> comparePredicate = std::nullopt;
+    std::optional<arith::CmpIPredicate> comparePredicate = std::nullopt;
     switch (op.getComparisonDirection()) {
       case ComparisonDirection::EQ:
         comparePredicate = arith::CmpIPredicate::eq;
@@ -105,7 +105,7 @@ class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
         !rhsType.getElementType().isa<FloatType>())
       return failure();
 
-    Optional<arith::CmpFPredicate> comparePredicate = std::nullopt;
+    std::optional<arith::CmpFPredicate> comparePredicate = std::nullopt;
     switch (op.getComparisonDirection()) {
       case ComparisonDirection::EQ:
         comparePredicate = arith::CmpFPredicate::OEQ;
18 changes: 10 additions & 8 deletions xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h
@@ -319,16 +319,17 @@ inline Value getConstantOrSplat(OpBuilder* b, Location loc, Type t,
 }
 
 template <typename PredicateType>
-inline Optional<PredicateType> getCmpPredicate(mhlo::ComparisonDirection,
-                                               bool) {
+inline std::optional<PredicateType> getCmpPredicate(mhlo::ComparisonDirection,
+                                                    bool) {
   return std::nullopt;
 }
 
 template <>
-inline Optional<arith::CmpFPredicate> getCmpPredicate<arith::CmpFPredicate>(
+inline std::optional<arith::CmpFPredicate>
+getCmpPredicate<arith::CmpFPredicate>(
     mhlo::ComparisonDirection comparisonDirection, bool isSigned) {
   assert(isSigned && "cannot have an unsigned float!");
-  return llvm::StringSwitch<Optional<arith::CmpFPredicate>>(
+  return llvm::StringSwitch<std::optional<arith::CmpFPredicate>>(
       stringifyComparisonDirection(comparisonDirection))
       .Case("EQ", arith::CmpFPredicate::OEQ)
       .Case("NE", arith::CmpFPredicate::UNE)
@@ -340,9 +341,10 @@ inline Optional<arith::CmpFPredicate> getCmpPredicate<arith::CmpFPredicate>(
 }
 
 template <>
-inline Optional<arith::CmpIPredicate> getCmpPredicate<arith::CmpIPredicate>(
+inline std::optional<arith::CmpIPredicate>
+getCmpPredicate<arith::CmpIPredicate>(
     mhlo::ComparisonDirection comparisonDirection, bool isSigned) {
-  return llvm::StringSwitch<Optional<arith::CmpIPredicate>>(
+  return llvm::StringSwitch<std::optional<arith::CmpIPredicate>>(
       stringifyComparisonDirection(comparisonDirection))
       .Case("EQ", arith::CmpIPredicate::eq)
       .Case("NE", arith::CmpIPredicate::ne)
@@ -406,7 +408,7 @@ inline Value mapMhloOpToStdScalarOp<mhlo::CompareOp>(
   Type elementType = getElementTypeOrSelf(argTypes.front());
   if (elementType.isa<IntegerType>()) {
     bool isUnsigned = IsUnsignedIntegerType{}(elementType);
-    Optional<arith::CmpIPredicate> predicate =
+    std::optional<arith::CmpIPredicate> predicate =
         getCmpPredicate<arith::CmpIPredicate>(comparisonDirection, !isUnsigned);
     assert(predicate.has_value() && "expected valid comparison direction");
     return b->create<ScalarIOp<mhlo::CompareOp>>(loc, predicate.value(), lhs,
@@ -449,7 +451,7 @@ inline Value mapMhloOpToStdScalarOp<mhlo::CompareOp>(
     assert(predicate.has_value() && "expected valid comparison direction");
     return b->create<arith::CmpIOp>(loc, *predicate, lhsInt, rhsInt);
   }
-  Optional<arith::CmpFPredicate> predicate =
+  std::optional<arith::CmpFPredicate> predicate =
       getCmpPredicate<arith::CmpFPredicate>(comparisonDirection,
                                             /*is_signed=*/true);
   assert(predicate.has_value() && "expected valid comparison direction");
@@ -17,7 +17,6 @@ limitations under the License.
 #include <utility>
 
 #include "llvm/ADT/EquivalenceClasses.h"
-#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -19,7 +19,6 @@ limitations under the License.
 #include <memory>
 #include <utility>
 
-#include "llvm/ADT/Optional.h"
 #include "mhlo/IR/hlo_ops.h"
 #include "mhlo/transforms/passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -79,8 +79,8 @@ struct SimplifyBroadcasts : public mlir::OpRewritePattern<shape::BroadcastOp> {
     for (const auto &sInfo : shapesInfo) rank = std::max(rank, sInfo.size());
 
     // Compute broadcast symbolically.
-    SmallVector<Optional<SymbolicBroadcastDimension>> symResult(rank,
-                                                                std::nullopt);
+    SmallVector<std::optional<SymbolicBroadcastDimension>> symResult(
+        rank, std::nullopt);
     for (const auto &sInfo : llvm::enumerate(shapesInfo)) {
       size_t dimOffset = rank - sInfo.value().size();
       for (const auto &symExpr : llvm::enumerate(sInfo.value())) {
3 changes: 2 additions & 1 deletion xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include <algorithm>
 #include <numeric>
+#include <optional>
 #include <string>
 #include <utility>
 
@@ -127,7 +128,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
   }
 
   // Find result type, if on tensors.
-  Optional<ShapedType> resultTy;
+  std::optional<ShapedType> resultTy;
   resultTy = this->typeConverter->convertType(op->getResultTypes().front())
                  .template dyn_cast<ShapedType>();
 
3 changes: 2 additions & 1 deletion xla/mlir_hlo/transforms/propagate_static_shapes_to_kernel.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <iterator>
 #include <memory>
 #include <numeric>
+#include <optional>
 #include <utility>
 
 #include "llvm/ADT/STLExtras.h"
@@ -139,7 +140,7 @@ LogicalResult PropagateStaticShapesPattern::matchAndRewrite(
   }
 
   // Collect gpu.launch_func ops which launch the func_op kernel.
-  Optional<SymbolTable::UseRange> symUses =
+  std::optional<SymbolTable::UseRange> symUses =
       symbolTable.getSymbolUses(funcOp, symbolTable.getOp());
   if (!symUses)
     return rewriter.notifyMatchFailure(funcOp, "failed to find symbol uses");
