Skip to content

Commit

Permalink
[gccjit] rename visit to visitExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Oct 24, 2024
1 parent b9fa376 commit 9d373e4
Showing 1 changed file with 69 additions and 68 deletions.
137 changes: 69 additions & 68 deletions src/Translation/TranslateToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,24 @@ class RegionVisitor {
void translateIntoContext();

private:
Expr visit(Operation *op);
void visitAsRValue(ValueRange operands,
llvm::SmallVectorImpl<gcc_jit_rvalue *> &result);
gcc_jit_rvalue *visitWithoutCache(ConstantOp op);
gcc_jit_rvalue *visitWithoutCache(LiteralOp op);
gcc_jit_rvalue *visitWithoutCache(SizeOfOp op);
gcc_jit_rvalue *visitWithoutCache(AlignOfOp op);
gcc_jit_rvalue *visitWithoutCache(AsRValueOp op);
gcc_jit_rvalue *visitWithoutCache(BinaryOp op);
gcc_jit_rvalue *visitWithoutCache(UnaryOp op);
gcc_jit_rvalue *visitWithoutCache(CompareOp op);
gcc_jit_rvalue *visitWithoutCache(CallOp op);
gcc_jit_rvalue *visitWithoutCache(CastOp op);
gcc_jit_rvalue *visitWithoutCache(BitCastOp op);
gcc_jit_rvalue *visitWithoutCache(PtrCallOp op);
gcc_jit_rvalue *visitWithoutCache(AddrOp op);
gcc_jit_rvalue *visitWithoutCache(FnAddrOp op);
gcc_jit_lvalue *visitWithoutCache(GetGlobalOp op);
Expr visitExpr(Operation *op);
void visitExprAsRValue(ValueRange operands,
llvm::SmallVectorImpl<gcc_jit_rvalue *> &result);
gcc_jit_rvalue *visitExprWithoutCache(ConstantOp op);
gcc_jit_rvalue *visitExprWithoutCache(LiteralOp op);
gcc_jit_rvalue *visitExprWithoutCache(SizeOfOp op);
gcc_jit_rvalue *visitExprWithoutCache(AlignOfOp op);
gcc_jit_rvalue *visitExprWithoutCache(AsRValueOp op);
gcc_jit_rvalue *visitExprWithoutCache(BinaryOp op);
gcc_jit_rvalue *visitExprWithoutCache(UnaryOp op);
gcc_jit_rvalue *visitExprWithoutCache(CompareOp op);
gcc_jit_rvalue *visitExprWithoutCache(CallOp op);
gcc_jit_rvalue *visitExprWithoutCache(CastOp op);
gcc_jit_rvalue *visitExprWithoutCache(BitCastOp op);
gcc_jit_rvalue *visitExprWithoutCache(PtrCallOp op);
gcc_jit_rvalue *visitExprWithoutCache(AddrOp op);
gcc_jit_rvalue *visitExprWithoutCache(FnAddrOp op);
gcc_jit_lvalue *visitExprWithoutCache(GetGlobalOp op);
};

} // namespace
Expand Down Expand Up @@ -458,7 +458,7 @@ void RegionVisitor::translateIntoContext() {
Block &block = region.getBlocks().front();
auto terminator = cast<gccjit::ReturnOp>(block.getTerminator());
auto value = terminator->getOperand(0);
auto rvalue = visit(value.getDefiningOp());
auto rvalue = visitExpr(value.getDefiningOp());
auto symName = SymbolRefAttr::get(getMLIRContext(), globalOp.getSymName());
auto *lvalue = getTranslator().getGlobalLValue(symName);
gcc_jit_global_set_initializer_rvalue(lvalue, rvalue);
Expand All @@ -467,41 +467,42 @@ void RegionVisitor::translateIntoContext() {
llvm_unreachable("unknown region parent");
}

Expr RegionVisitor::visit(Operation *op) {
Expr RegionVisitor::visitExpr(Operation *op) {
if (op->getNumResults() != 1)
llvm_unreachable("expected single result operation");

auto &cached = exprCache[op->getResult(0)];
if (!cached)
cached = llvm::TypeSwitch<Operation *, Expr>(op)
.Case([&](ConstantOp op) { return visitWithoutCache(op); })
.Case([&](LiteralOp op) { return visitWithoutCache(op); })
.Case([&](SizeOfOp op) { return visitWithoutCache(op); })
.Case([&](AlignOfOp op) { return visitWithoutCache(op); })
.Case([&](AsRValueOp op) { return visitWithoutCache(op); })
.Case([&](BinaryOp op) { return visitWithoutCache(op); })
.Case([&](UnaryOp op) { return visitWithoutCache(op); })
.Case([&](CompareOp op) { return visitWithoutCache(op); })
.Case([&](CallOp op) { return visitWithoutCache(op); })
.Case([&](CastOp op) { return visitWithoutCache(op); })
.Case([&](BitCastOp op) { return visitWithoutCache(op); })
.Case([&](PtrCallOp op) { return visitWithoutCache(op); })
.Case([&](AddrOp op) { return visitWithoutCache(op); })
.Case([&](FnAddrOp op) { return visitWithoutCache(op); })
.Case([&](GetGlobalOp op) { return visitWithoutCache(op); })
.Default([](Operation *) -> Expr {
llvm_unreachable("unknown expression type");
});
cached =
llvm::TypeSwitch<Operation *, Expr>(op)
.Case([&](ConstantOp op) { return visitExprWithoutCache(op); })
.Case([&](LiteralOp op) { return visitExprWithoutCache(op); })
.Case([&](SizeOfOp op) { return visitExprWithoutCache(op); })
.Case([&](AlignOfOp op) { return visitExprWithoutCache(op); })
.Case([&](AsRValueOp op) { return visitExprWithoutCache(op); })
.Case([&](BinaryOp op) { return visitExprWithoutCache(op); })
.Case([&](UnaryOp op) { return visitExprWithoutCache(op); })
.Case([&](CompareOp op) { return visitExprWithoutCache(op); })
.Case([&](CallOp op) { return visitExprWithoutCache(op); })
.Case([&](CastOp op) { return visitExprWithoutCache(op); })
.Case([&](BitCastOp op) { return visitExprWithoutCache(op); })
.Case([&](PtrCallOp op) { return visitExprWithoutCache(op); })
.Case([&](AddrOp op) { return visitExprWithoutCache(op); })
.Case([&](FnAddrOp op) { return visitExprWithoutCache(op); })
.Case([&](GetGlobalOp op) { return visitExprWithoutCache(op); })
.Default([](Operation *) -> Expr {
llvm_unreachable("unknown expression type");
});
return cached;
}

void RegionVisitor::visitAsRValue(
void RegionVisitor::visitExprAsRValue(
ValueRange operands, llvm::SmallVectorImpl<gcc_jit_rvalue *> &result) {
for (auto operand : operands)
result.push_back(visit(operand.getDefiningOp()));
result.push_back(visitExpr(operand.getDefiningOp()));
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(ConstantOp op) {
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(ConstantOp op) {
auto type = op.getType();
auto *typeHandle = getTranslator().convertType(type);
return llvm::TypeSwitch<TypedAttr, gcc_jit_rvalue *>(op.getValue())
Expand Down Expand Up @@ -530,23 +531,23 @@ gcc_jit_rvalue *RegionVisitor::visitWithoutCache(ConstantOp op) {
});
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(LiteralOp op) {
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(LiteralOp op) {
auto string = op.getValue().getInitializer().str();
return gcc_jit_context_new_string_literal(getContext(), string.c_str());
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(SizeOfOp op) {
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(SizeOfOp op) {
auto type = op.getType();
auto *typeHandle = getTranslator().convertType(type);
return gcc_jit_context_new_sizeof(getContext(), typeHandle);
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(AlignOfOp op) {
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(AlignOfOp op) {
llvm_unreachable("GCCJIT does not support alignof yet");
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(AsRValueOp op) {
auto lvalue = visit(op.getLvalue().getDefiningOp());
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(AsRValueOp op) {
auto lvalue = visitExpr(op.getLvalue().getDefiningOp());
return gcc_jit_lvalue_as_rvalue(lvalue);
}

Expand Down Expand Up @@ -581,9 +582,9 @@ static gcc_jit_binary_op convertBinaryOp(BOp kind) {
}

// RValue always has a defining operation
gcc_jit_rvalue *RegionVisitor::visitWithoutCache(BinaryOp op) {
auto lhs = visit(op.getLhs().getDefiningOp());
auto rhs = visit(op.getRhs().getDefiningOp());
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(BinaryOp op) {
auto lhs = visitExpr(op.getLhs().getDefiningOp());
auto rhs = visitExpr(op.getRhs().getDefiningOp());
auto kind = convertBinaryOp(op.getOp());
auto *loc = getTranslator().getLocation(op.getLoc());
auto *ctxt = getContext();
Expand All @@ -605,8 +606,8 @@ static gcc_jit_unary_op convertUnaryOp(UOp kind) {
llvm_unreachable("unknown unary operation");
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(UnaryOp op) {
auto operand = visit(op.getOperand().getDefiningOp());
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(UnaryOp op) {
auto operand = visitExpr(op.getOperand().getDefiningOp());
auto kind = convertUnaryOp(op.getOp());
auto *loc = getTranslator().getLocation(op.getLoc());
auto *ctxt = getContext();
Expand All @@ -632,16 +633,16 @@ static gcc_jit_comparison convertCompareOp(CmpOp kind) {
llvm_unreachable("unknown compare operation");
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(CompareOp op) {
auto lhs = visit(op.getLhs().getDefiningOp());
auto rhs = visit(op.getRhs().getDefiningOp());
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(CompareOp op) {
auto lhs = visitExpr(op.getLhs().getDefiningOp());
auto rhs = visitExpr(op.getRhs().getDefiningOp());
auto kind = convertCompareOp(op.getOp());
auto *loc = getTranslator().getLocation(op.getLoc());
auto *ctxt = getContext();
return gcc_jit_context_new_comparison(ctxt, loc, kind, lhs, rhs);
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(CallOp op) {
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(CallOp op) {
gcc_jit_function *callee = nullptr;
if (op.getBuiltin()) {
callee = gcc_jit_context_get_builtin_function(
Expand All @@ -651,7 +652,7 @@ gcc_jit_rvalue *RegionVisitor::visitWithoutCache(CallOp op) {
}
assert(callee && "function not found");
llvm::SmallVector<gcc_jit_rvalue *> args;
visitAsRValue(op.getArgs(), args);
visitExprAsRValue(op.getArgs(), args);
auto *loc = getTranslator().getLocation(op.getLoc());
auto *ctxt = getContext();
auto *call =
Expand All @@ -660,26 +661,26 @@ gcc_jit_rvalue *RegionVisitor::visitWithoutCache(CallOp op) {
return call;
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(CastOp op) {
auto operand = visit(op.getOperand().getDefiningOp());
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(CastOp op) {
auto operand = visitExpr(op.getOperand().getDefiningOp());
auto *loc = getTranslator().getLocation(op.getLoc());
auto *ctxt = getContext();
auto *type = getTranslator().convertType(op.getType());
return gcc_jit_context_new_cast(ctxt, loc, operand, type);
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(BitCastOp op) {
auto operand = visit(op.getOperand().getDefiningOp());
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(BitCastOp op) {
auto operand = visitExpr(op.getOperand().getDefiningOp());
auto *loc = getTranslator().getLocation(op.getLoc());
auto *ctxt = getContext();
auto *type = getTranslator().convertType(op.getType());
return gcc_jit_context_new_bitcast(ctxt, loc, operand, type);
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(PtrCallOp op) {
auto callee = visit(op.getCallee().getDefiningOp());
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(PtrCallOp op) {
auto callee = visitExpr(op.getCallee().getDefiningOp());
llvm::SmallVector<gcc_jit_rvalue *> args;
visitAsRValue(op.getArgs(), args);
visitExprAsRValue(op.getArgs(), args);
auto *loc = getTranslator().getLocation(op.getLoc());
auto *ctxt = getContext();
auto *call = gcc_jit_context_new_call_through_ptr(ctxt, loc, callee,
Expand All @@ -688,20 +689,20 @@ gcc_jit_rvalue *RegionVisitor::visitWithoutCache(PtrCallOp op) {
return call;
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(AddrOp op) {
auto lvalue = visit(op.getOperand().getDefiningOp());
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(AddrOp op) {
auto lvalue = visitExpr(op.getOperand().getDefiningOp());
auto *loc = getTranslator().getLocation(op.getLoc());
return gcc_jit_lvalue_get_address(lvalue, loc);
}

gcc_jit_rvalue *RegionVisitor::visitWithoutCache(FnAddrOp op) {
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(FnAddrOp op) {
auto *fn = getTranslator().getFunction(op.getCallee());
assert(fn && "function not found");
auto *loc = getTranslator().getLocation(op.getLoc());
return gcc_jit_function_get_address(fn, loc);
}

gcc_jit_lvalue *RegionVisitor::visitWithoutCache(GetGlobalOp op) {
gcc_jit_lvalue *RegionVisitor::visitExprWithoutCache(GetGlobalOp op) {
auto *lvalue = getTranslator().getGlobalLValue(op.getSym());
assert(lvalue && "global not found");
return lvalue;
Expand Down

0 comments on commit 9d373e4

Please sign in to comment.