From e5c6d5ff011ddd6d395c8edf10ad31ed8086ea4b Mon Sep 17 00:00:00 2001 From: Santiago Romero Date: Tue, 20 Dec 2022 22:51:41 -0600 Subject: [PATCH] Chapter 4: JIT Support --- run | 15 ++- src/ast/ast_codegen.cpp | 42 +++++--- src/jit/KaleidoscopeJIT.h | 104 ++++++++++++++++++++ src/kaleidoscope.cpp | 45 ++++++++- src/linux_kaleidoscope.cpp | 2 + src/parser/parser.cpp | 2 +- src/platform/externs/linux_extern_table.hpp | 3 + src/platform/externs/linux_putchard.cpp | 12 +++ src/platform/externs/win32_extern_table.hpp | 4 + src/platform/externs/win32_putchard.cpp | 12 +++ src/platform/llvm/linux_llvm_include.hpp | 2 + src/platform/llvm/llvm_include.hpp | 8 ++ 12 files changed, 228 insertions(+), 23 deletions(-) create mode 100644 src/jit/KaleidoscopeJIT.h create mode 100644 src/platform/externs/linux_extern_table.hpp create mode 100644 src/platform/externs/linux_putchard.cpp create mode 100644 src/platform/externs/win32_extern_table.hpp create mode 100644 src/platform/externs/win32_putchard.cpp diff --git a/run b/run index 8f4b5f3..f8f0f39 100755 --- a/run +++ b/run @@ -1,19 +1,28 @@ BUILD_DIR="build" PLATFORM="linux" +LLVM_COMPILE_FLAGS="llvm-config --cxxflags --ldflags --system-libs --libs core orcjit native" +EXTRA_COMPILE_FLAGS="-rdynamic" +COMPILE_OUTPUT="./${BUILD_DIR}/${PLATFORM}_kaleidoscope" + # Clean echo "Cleaning..." mkdir ./${BUILD_DIR} rm -r ./${BUILD_DIR}/* echo "Done cleaning." +echo "" # Compile -echo "Compiling..." -clang++ -g -O3 ./src/${PLATFORM}_kaleidoscope.cpp `llvm-config --cxxflags --ldflags --system-libs --libs core` -o ./${BUILD_DIR}/${PLATFORM}_kaleidoscope +echo "Compiling for ${PLATFORM}..." +echo "LLVM FLAGS: ${LLVM_COMPILE_FLAGS}" +echo "EXTRA FLAGS: ${EXTRA_COMPILE_FLAGS}" +clang++ -g -O3 ./src/${PLATFORM}_kaleidoscope.cpp `${LLVM_COMPILE_FLAGS}` ${EXTRA_COMPILE_FLAGS} -o ${COMPILE_OUTPUT} echo "Done compiling." +echo "OUTPUT: ${COMPILE_OUTPUT}" +echo "" # Run echo "Running..." -./${BUILD_DIR}/${PLATFORM}_kaleidoscope +eval ${COMPILE_OUTPUT} echo "" echo "Done running." diff --git a/src/ast/ast_codegen.cpp b/src/ast/ast_codegen.cpp index 109c510..fe2bbec 100644 --- a/src/ast/ast_codegen.cpp +++ b/src/ast/ast_codegen.cpp @@ -6,6 +6,28 @@ #include "./ast.cpp" #include "../logging/ast_err.cpp" +#include + +llvm::Function *getFunction(std::string Name) +{ + // First, see if the function has already been added to the current module. + if (auto *F = TheModule->getFunction(Name)) + { + return F; + } + + // If not, check whether we can codegen the declaration from some existing + // prototype. + auto FI = FunctionProtos.find(Name); + if (FI != FunctionProtos.end()) + { + return FI->second->codegen(); + } + + // If no existing prototype exists, return null. + return nullptr; +} + llvm::Value* NumberExprAST::codegen() { @@ -54,7 +76,7 @@ llvm::Value * CallExprAST::codegen() { // Look up the name in the global module table. - llvm::Function *CalleeF = TheModule->getFunction(Callee); + llvm::Function *CalleeF = getFunction(Callee); if (!CalleeF) { return LogErrorV("Unknown function referenced"); @@ -101,23 +123,15 @@ PrototypeAST::codegen() llvm::Function * FunctionAST::codegen() { - // First, check for an existing function from a previous 'extern' declaration - llvm::Function *TheFunction = TheModule->getFunction(Proto->getName()); - - if (!TheFunction) - { - TheFunction = Proto->codegen(); - } - + // Transfer ownership of the prototype to the FunctionProtos map, but keep a + // reference to it for use below. + auto &P = *Proto; + FunctionProtos[Proto->getName()] = std::move(Proto); + llvm::Function *TheFunction = getFunction(P.getName()); if (!TheFunction) { return nullptr; } - - if (!TheFunction->empty()) - { - return LogErrorF("Function cannot be redefined"); - } // Create a new basic block to start insertion into llvm::BasicBlock *BB = llvm::BasicBlock::Create(*TheContext, "entry", TheFunction); diff --git a/src/jit/KaleidoscopeJIT.h b/src/jit/KaleidoscopeJIT.h new file mode 100644 index 0000000..457a2d4 --- /dev/null +++ b/src/jit/KaleidoscopeJIT.h @@ -0,0 +1,104 @@ +//===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains a simple JIT definition for use in the kaleidoscope tutorials. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H +#define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include + +namespace llvm { +namespace orc { + +class KaleidoscopeJIT { +private: + std::unique_ptr ES; + + DataLayout DL; + MangleAndInterner Mangle; + + RTDyldObjectLinkingLayer ObjectLayer; + IRCompileLayer CompileLayer; + + JITDylib &MainJD; + +public: + KaleidoscopeJIT(std::unique_ptr ES, + JITTargetMachineBuilder JTMB, DataLayout DL) + : ES(std::move(ES)), DL(std::move(DL)), Mangle(*this->ES, this->DL), + ObjectLayer(*this->ES, + []() { return std::make_unique(); }), + CompileLayer(*this->ES, ObjectLayer, + std::make_unique(std::move(JTMB))), + MainJD(this->ES->createBareJITDylib("
")) { + MainJD.addGenerator( + cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess( + DL.getGlobalPrefix()))); + if (JTMB.getTargetTriple().isOSBinFormatCOFF()) { + ObjectLayer.setOverrideObjectFlagsWithResponsibilityFlags(true); + ObjectLayer.setAutoClaimResponsibilityForObjectSymbols(true); + } + } + + ~KaleidoscopeJIT() { + if (auto Err = ES->endSession()) + ES->reportError(std::move(Err)); + } + + static Expected> Create() { + auto EPC = SelfExecutorProcessControl::Create(); + if (!EPC) + return EPC.takeError(); + + auto ES = std::make_unique(std::move(*EPC)); + + JITTargetMachineBuilder JTMB( + ES->getExecutorProcessControl().getTargetTriple()); + + auto DL = JTMB.getDefaultDataLayoutForTarget(); + if (!DL) + return DL.takeError(); + + return std::make_unique(std::move(ES), std::move(JTMB), + std::move(*DL)); + } + + const DataLayout &getDataLayout() const { return DL; } + + JITDylib &getMainJITDylib() { return MainJD; } + + Error addModule(ThreadSafeModule TSM, ResourceTrackerSP RT = nullptr) { + if (!RT) + RT = MainJD.getDefaultResourceTracker(); + return CompileLayer.add(RT, std::move(TSM)); + } + + Expected lookup(StringRef Name) { + return ES->lookup({&MainJD}, Mangle(Name.str())); + } +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H diff --git a/src/kaleidoscope.cpp b/src/kaleidoscope.cpp index c2d6ff1..369e5ea 100644 --- a/src/kaleidoscope.cpp +++ b/src/kaleidoscope.cpp @@ -11,12 +11,23 @@ // NOTE(srp): Top-level parsing and JIT driver +internal void +InitializeJIT() +{ + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + + TheJIT = ExitOnErr(llvmo::KaleidoscopeJIT::Create()); +} + internal void InitializeModule() { // Open a new context and module. TheContext = std::make_unique(); TheModule = std::make_unique("my cool jit", *TheContext); + TheModule->setDataLayout(TheJIT->getDataLayout()); // Create a new builder for the module. Builder = std::make_unique>(*TheContext); @@ -44,6 +55,7 @@ InitializePassManager() internal void InitializeLLVM() { + InitializeJIT(); InitializeModule(); InitializePassManager(); } @@ -55,9 +67,12 @@ HandleDefinition() { if (auto *FnIR = FnAST->codegen()) { - fprintf(stderr, "Read function definition:"); + fprintf(stderr, "Read function definition:\n"); FnIR->print(llvm::errs()); fprintf(stderr, "\n"); + ExitOnErr(TheJIT->addModule(llvmo::ThreadSafeModule(std::move(TheModule), std::move(TheContext)))); + InitializeModule(); + InitializePassManager(); } } else @@ -75,9 +90,10 @@ HandleExtern() { if (auto *FnIR = ProtoAST->codegen()) { - fprintf(stderr, "Read extern: "); + fprintf(stderr, "Read extern: \n"); FnIR->print(llvm::errs()); fprintf(stderr, "\n"); + FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST); } } else @@ -95,12 +111,31 @@ HandleTopLevelExpression() { if (auto *FnIR = FnAST->codegen()) { - fprintf(stderr, "Read top-level expression: "); + // Prints IR + fprintf(stderr, "Read top level expression: \n"); FnIR->print(llvm::errs()); fprintf(stderr, "\n"); - // Remove the anonymous expression - FnIR->eraseFromParent(); + // Create a ResourceTracker to track JIT'd memory allocated to our + // anonymous expression -- that way we can free it after executing. + auto RT = TheJIT->getMainJITDylib().createResourceTracker(); + + auto TSM = llvmo::ThreadSafeModule(std::move(TheModule), std::move(TheContext)); + ExitOnErr(TheJIT->addModule(std::move(TSM), RT)); + InitializeModule(); + InitializePassManager(); + + // Search the JIT for the __anon_expr symbol. + auto ExprSymbol = ExitOnErr(TheJIT->lookup("{__anon_expr}")); + assert(ExprSymbol && "Function not found\n"); + + // Get the symbol's address and cast it to the right type (takes no + // arguments, returns a real64) so we can call it as a native function. + real64 (*FP)() = (real64 (*)())(intptr_t)ExprSymbol.getAddress(); + fprintf(stderr, "\nEvaluated to %f\n", FP()); + + // Delete the anonymous expression module from the JIT. + ExitOnErr(RT->remove()); } } else diff --git a/src/linux_kaleidoscope.cpp b/src/linux_kaleidoscope.cpp index 0d8dd19..853ba99 100644 --- a/src/linux_kaleidoscope.cpp +++ b/src/linux_kaleidoscope.cpp @@ -13,6 +13,8 @@ #include "kaleidoscope.cpp" +#include "platform/externs/linux_extern_table.hpp" + int main() { // Install standard binary operators. diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index 36fedb2..3e77dee 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -289,7 +289,7 @@ ParseTopLevelExpr() if (auto E = ParseExpression()) { // Make an anonymous proto. - auto Proto = std::make_unique("", std::vector()); + auto Proto = std::make_unique("{__anon_expr}", std::vector()); return std::make_unique(std::move(Proto), std::move(E)); } return nullptr; diff --git a/src/platform/externs/linux_extern_table.hpp b/src/platform/externs/linux_extern_table.hpp new file mode 100644 index 0000000..1e8242d --- /dev/null +++ b/src/platform/externs/linux_extern_table.hpp @@ -0,0 +1,3 @@ +#pragma once + +#include "linux_putchard.cpp" diff --git a/src/platform/externs/linux_putchard.cpp b/src/platform/externs/linux_putchard.cpp new file mode 100644 index 0000000..2df74f4 --- /dev/null +++ b/src/platform/externs/linux_putchard.cpp @@ -0,0 +1,12 @@ +#pragma once + +#include +#include "../typedefs/typedefs.hpp" + +/// putchard - putchar that takes a double and returns 0. +extern "C" real64 +putchard(real64 X) +{ + fputc((char)X, stderr); + return 0; +} diff --git a/src/platform/externs/win32_extern_table.hpp b/src/platform/externs/win32_extern_table.hpp new file mode 100644 index 0000000..a054640 --- /dev/null +++ b/src/platform/externs/win32_extern_table.hpp @@ -0,0 +1,4 @@ +#pragma once + +#include "win32_putchard.cpp" + diff --git a/src/platform/externs/win32_putchard.cpp b/src/platform/externs/win32_putchard.cpp new file mode 100644 index 0000000..cec7817 --- /dev/null +++ b/src/platform/externs/win32_putchard.cpp @@ -0,0 +1,12 @@ +#pragma once + +#include "../typedefs/typedefs.hpp" + +/// putchard - putchar that takes a double and returns 0. +extern "C" __declspec(dllexport) real64 +putchard(real64 X) +{ + fputc((char)X, stderr); + return 0; +} + diff --git a/src/platform/llvm/linux_llvm_include.hpp b/src/platform/llvm/linux_llvm_include.hpp index d826cb1..ea54d54 100644 --- a/src/platform/llvm/linux_llvm_include.hpp +++ b/src/platform/llvm/linux_llvm_include.hpp @@ -20,3 +20,5 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" diff --git a/src/platform/llvm/llvm_include.hpp b/src/platform/llvm/llvm_include.hpp index 17d6b21..567cf0c 100644 --- a/src/platform/llvm/llvm_include.hpp +++ b/src/platform/llvm/llvm_include.hpp @@ -7,14 +7,22 @@ #include "linux_llvm_include.hpp" // Default to Linux llvm .h files #endif +#include "../../jit/KaleidoscopeJIT.h" + #include "../typedefs/typedefs.hpp" #include #include namespace llvml = llvm::legacy; +namespace llvmo = llvm::orc; + +class PrototypeAST; global_variable std::unique_ptr TheContext; global_variable std::unique_ptr TheModule; global_variable std::unique_ptr> Builder; global_variable std::map NamedValues; global_variable std::unique_ptr TheFPM; +global_variable std::unique_ptr TheJIT; +global_variable std::map> FunctionProtos; +global_variable llvm::ExitOnError ExitOnErr;