Skip to content

Commit

Permalink
Chapter 4: JIT Support
Browse files Browse the repository at this point in the history
  • Loading branch information
srp-mx committed Dec 21, 2022
1 parent 99fdc00 commit e5c6d5f
Show file tree
Hide file tree
Showing 12 changed files with 228 additions and 23 deletions.
15 changes: 12 additions & 3 deletions run
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
BUILD_DIR="build"
PLATFORM="linux"

LLVM_COMPILE_FLAGS="llvm-config --cxxflags --ldflags --system-libs --libs core orcjit native"
EXTRA_COMPILE_FLAGS="-rdynamic"
COMPILE_OUTPUT="./${BUILD_DIR}/${PLATFORM}_kaleidoscope"

# Clean
echo "Cleaning..."
mkdir ./${BUILD_DIR}
rm -r ./${BUILD_DIR}/*
echo "Done cleaning."
echo ""

# Compile
echo "Compiling..."
clang++ -g -O3 ./src/${PLATFORM}_kaleidoscope.cpp `llvm-config --cxxflags --ldflags --system-libs --libs core` -o ./${BUILD_DIR}/${PLATFORM}_kaleidoscope
echo "Compiling for ${PLATFORM}..."
echo "LLVM FLAGS: ${LLVM_COMPILE_FLAGS}"
echo "EXTRA FLAGS: ${EXTRA_COMPILE_FLAGS}"
clang++ -g -O3 ./src/${PLATFORM}_kaleidoscope.cpp `${LLVM_COMPILE_FLAGS}` ${EXTRA_COMPILE_FLAGS} -o ${COMPILE_OUTPUT}
echo "Done compiling."
echo "OUTPUT: ${COMPILE_OUTPUT}"
echo ""

# Run
echo "Running..."
./${BUILD_DIR}/${PLATFORM}_kaleidoscope
eval ${COMPILE_OUTPUT}
echo ""
echo "Done running."
42 changes: 28 additions & 14 deletions src/ast/ast_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,28 @@
#include "./ast.cpp"
#include "../logging/ast_err.cpp"

#include <string>

llvm::Function *getFunction(std::string Name)
{
// First, see if the function has already been added to the current module.
if (auto *F = TheModule->getFunction(Name))
{
return F;
}

// If not, check whether we can codegen the declaration from some existing
// prototype.
auto FI = FunctionProtos.find(Name);
if (FI != FunctionProtos.end())
{
return FI->second->codegen();
}

// If no existing prototype exists, return null.
return nullptr;
}

llvm::Value*
NumberExprAST::codegen()
{
Expand Down Expand Up @@ -54,7 +76,7 @@ llvm::Value *
CallExprAST::codegen()
{
// Look up the name in the global module table.
llvm::Function *CalleeF = TheModule->getFunction(Callee);
llvm::Function *CalleeF = getFunction(Callee);
if (!CalleeF)
{
return LogErrorV("Unknown function referenced");
Expand Down Expand Up @@ -101,23 +123,15 @@ PrototypeAST::codegen()
llvm::Function *
FunctionAST::codegen()
{
// First, check for an existing function from a previous 'extern' declaration
llvm::Function *TheFunction = TheModule->getFunction(Proto->getName());

if (!TheFunction)
{
TheFunction = Proto->codegen();
}

// Transfer ownership of the prototype to the FunctionProtos map, but keep a
// reference to it for use below.
auto &P = *Proto;
FunctionProtos[Proto->getName()] = std::move(Proto);
llvm::Function *TheFunction = getFunction(P.getName());
if (!TheFunction)
{
return nullptr;
}

if (!TheFunction->empty())
{
return LogErrorF("Function cannot be redefined");
}

// Create a new basic block to start insertion into
llvm::BasicBlock *BB = llvm::BasicBlock::Create(*TheContext, "entry", TheFunction);
Expand Down
104 changes: 104 additions & 0 deletions src/jit/KaleidoscopeJIT.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains a simple JIT definition for use in the kaleidoscope tutorials.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
#define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H

#include "llvm/ADT/StringRef.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/LLVMContext.h"
#include <memory>

namespace llvm {
namespace orc {

class KaleidoscopeJIT {
private:
std::unique_ptr<ExecutionSession> ES;

DataLayout DL;
MangleAndInterner Mangle;

RTDyldObjectLinkingLayer ObjectLayer;
IRCompileLayer CompileLayer;

JITDylib &MainJD;

public:
KaleidoscopeJIT(std::unique_ptr<ExecutionSession> ES,
JITTargetMachineBuilder JTMB, DataLayout DL)
: ES(std::move(ES)), DL(std::move(DL)), Mangle(*this->ES, this->DL),
ObjectLayer(*this->ES,
[]() { return std::make_unique<SectionMemoryManager>(); }),
CompileLayer(*this->ES, ObjectLayer,
std::make_unique<ConcurrentIRCompiler>(std::move(JTMB))),
MainJD(this->ES->createBareJITDylib("<main>")) {
MainJD.addGenerator(
cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
DL.getGlobalPrefix())));
if (JTMB.getTargetTriple().isOSBinFormatCOFF()) {
ObjectLayer.setOverrideObjectFlagsWithResponsibilityFlags(true);
ObjectLayer.setAutoClaimResponsibilityForObjectSymbols(true);
}
}

~KaleidoscopeJIT() {
if (auto Err = ES->endSession())
ES->reportError(std::move(Err));
}

static Expected<std::unique_ptr<KaleidoscopeJIT>> Create() {
auto EPC = SelfExecutorProcessControl::Create();
if (!EPC)
return EPC.takeError();

auto ES = std::make_unique<ExecutionSession>(std::move(*EPC));

JITTargetMachineBuilder JTMB(
ES->getExecutorProcessControl().getTargetTriple());

auto DL = JTMB.getDefaultDataLayoutForTarget();
if (!DL)
return DL.takeError();

return std::make_unique<KaleidoscopeJIT>(std::move(ES), std::move(JTMB),
std::move(*DL));
}

const DataLayout &getDataLayout() const { return DL; }

JITDylib &getMainJITDylib() { return MainJD; }

Error addModule(ThreadSafeModule TSM, ResourceTrackerSP RT = nullptr) {
if (!RT)
RT = MainJD.getDefaultResourceTracker();
return CompileLayer.add(RT, std::move(TSM));
}

Expected<JITEvaluatedSymbol> lookup(StringRef Name) {
return ES->lookup({&MainJD}, Mangle(Name.str()));
}
};

} // end namespace orc
} // end namespace llvm

#endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
45 changes: 40 additions & 5 deletions src/kaleidoscope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,23 @@

// NOTE(srp): Top-level parsing and JIT driver

internal void
InitializeJIT()
{
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();

TheJIT = ExitOnErr(llvmo::KaleidoscopeJIT::Create());
}

internal void
InitializeModule()
{
// Open a new context and module.
TheContext = std::make_unique<llvm::LLVMContext>();
TheModule = std::make_unique<llvm::Module>("my cool jit", *TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());

// Create a new builder for the module.
Builder = std::make_unique<llvm::IRBuilder<>>(*TheContext);
Expand Down Expand Up @@ -44,6 +55,7 @@ InitializePassManager()
internal void
InitializeLLVM()
{
InitializeJIT();
InitializeModule();
InitializePassManager();
}
Expand All @@ -55,9 +67,12 @@ HandleDefinition()
{
if (auto *FnIR = FnAST->codegen())
{
fprintf(stderr, "Read function definition:");
fprintf(stderr, "Read function definition:\n");
FnIR->print(llvm::errs());
fprintf(stderr, "\n");
ExitOnErr(TheJIT->addModule(llvmo::ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
InitializeModule();
InitializePassManager();
}
}
else
Expand All @@ -75,9 +90,10 @@ HandleExtern()
{
if (auto *FnIR = ProtoAST->codegen())
{
fprintf(stderr, "Read extern: ");
fprintf(stderr, "Read extern: \n");
FnIR->print(llvm::errs());
fprintf(stderr, "\n");
FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
}
}
else
Expand All @@ -95,12 +111,31 @@ HandleTopLevelExpression()
{
if (auto *FnIR = FnAST->codegen())
{
fprintf(stderr, "Read top-level expression: ");
// Prints IR
fprintf(stderr, "Read top level expression: \n");
FnIR->print(llvm::errs());
fprintf(stderr, "\n");

// Remove the anonymous expression
FnIR->eraseFromParent();
// Create a ResourceTracker to track JIT'd memory allocated to our
// anonymous expression -- that way we can free it after executing.
auto RT = TheJIT->getMainJITDylib().createResourceTracker();

auto TSM = llvmo::ThreadSafeModule(std::move(TheModule), std::move(TheContext));
ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
InitializeModule();
InitializePassManager();

// Search the JIT for the __anon_expr symbol.
auto ExprSymbol = ExitOnErr(TheJIT->lookup("{__anon_expr}"));
assert(ExprSymbol && "Function not found\n");

// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a real64) so we can call it as a native function.
real64 (*FP)() = (real64 (*)())(intptr_t)ExprSymbol.getAddress();
fprintf(stderr, "\nEvaluated to %f\n", FP());

// Delete the anonymous expression module from the JIT.
ExitOnErr(RT->remove());
}
}
else
Expand Down
2 changes: 2 additions & 0 deletions src/linux_kaleidoscope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

#include "kaleidoscope.cpp"

#include "platform/externs/linux_extern_table.hpp"

int main()
{
// Install standard binary operators.
Expand Down
2 changes: 1 addition & 1 deletion src/parser/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ ParseTopLevelExpr()
if (auto E = ParseExpression())
{
// Make an anonymous proto.
auto Proto = std::make_unique<PrototypeAST>("", std::vector<std::string>());
auto Proto = std::make_unique<PrototypeAST>("{__anon_expr}", std::vector<std::string>());
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
}
return nullptr;
Expand Down
3 changes: 3 additions & 0 deletions src/platform/externs/linux_extern_table.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#pragma once

#include "linux_putchard.cpp"
12 changes: 12 additions & 0 deletions src/platform/externs/linux_putchard.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <stdio.h>
#include "../typedefs/typedefs.hpp"

/// putchard - putchar that takes a double and returns 0.
extern "C" real64
putchard(real64 X)
{
fputc((char)X, stderr);
return 0;
}
4 changes: 4 additions & 0 deletions src/platform/externs/win32_extern_table.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#pragma once

#include "win32_putchard.cpp"

12 changes: 12 additions & 0 deletions src/platform/externs/win32_putchard.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include "../typedefs/typedefs.hpp"

/// putchard - putchar that takes a double and returns 0.
extern "C" __declspec(dllexport) real64
putchard(real64 X)
{
fputc((char)X, stderr);
return 0;
}

2 changes: 2 additions & 0 deletions src/platform/llvm/linux_llvm_include.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"

#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
8 changes: 8 additions & 0 deletions src/platform/llvm/llvm_include.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,22 @@
#include "linux_llvm_include.hpp" // Default to Linux llvm .h files
#endif

#include "../../jit/KaleidoscopeJIT.h"

#include "../typedefs/typedefs.hpp"
#include <map>
#include <string>

namespace llvml = llvm::legacy;
namespace llvmo = llvm::orc;

class PrototypeAST;

global_variable std::unique_ptr<llvm::LLVMContext> TheContext;
global_variable std::unique_ptr<llvm::Module> TheModule;
global_variable std::unique_ptr<llvm::IRBuilder<>> Builder;
global_variable std::map<std::string, llvm::Value*> NamedValues;
global_variable std::unique_ptr<llvml::FunctionPassManager> TheFPM;
global_variable std::unique_ptr<llvmo::KaleidoscopeJIT> TheJIT;
global_variable std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
global_variable llvm::ExitOnError ExitOnErr;

0 comments on commit e5c6d5f

Please sign in to comment.