Skip to content

Commit

Permalink
Add a subtype to PassManager to prevent it from being run on the
Browse files Browse the repository at this point in the history
incorrect type which would have led to a memory error
  • Loading branch information
TheDan64 committed May 27, 2019
1 parent b95c72d commit b5d34a1
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 48 deletions.
8 changes: 4 additions & 4 deletions examples/kaleidoscope/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ impl<'a> Parser<'a> {
pub struct Compiler<'a> {
pub context: &'a Context,
pub builder: &'a Builder,
pub fpm: &'a PassManager,
pub fpm: &'a PassManager<FunctionValue>,
pub module: &'a Module,
pub function: &'a Function,

Expand Down Expand Up @@ -1146,7 +1146,7 @@ impl<'a> Compiler<'a> {

// return the whole thing after verification and optimization
if function.verify(true) {
self.fpm.run_on_function(&function);
self.fpm.run_on(&function);

Ok(function)
} else {
Expand All @@ -1159,7 +1159,7 @@ impl<'a> Compiler<'a> {
}

/// Compiles the specified `Function` in the given `Context` and using the specified `Builder`, `PassManager`, and `Module`.
pub fn compile(context: &'a Context, builder: &'a Builder, pass_manager: &'a PassManager, module: &'a Module, function: &Function) -> Result<FunctionValue, &'static str> {
pub fn compile(context: &'a Context, builder: &'a Builder, pass_manager: &'a PassManager<FunctionValue>, module: &'a Module, function: &Function) -> Result<FunctionValue, &'static str> {
let mut compiler = Compiler {
context: context,
builder: builder,
Expand Down Expand Up @@ -1228,7 +1228,7 @@ pub fn main() {
let builder = context.create_builder();

// Create FPM
let fpm = PassManager::create_for_function(&module);
let fpm = PassManager::create(&module);

fpm.add_instruction_combining_pass();
fpm.add_reassociate_pass();
Expand Down
6 changes: 4 additions & 2 deletions src/builder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! A `Builder` enables you to build instructions.
use either::{Either, Left, Right};
use llvm_sys::core::{LLVMBuildAdd, LLVMBuildAlloca, LLVMBuildAnd, LLVMBuildArrayAlloca, LLVMBuildArrayMalloc, LLVMBuildBr, LLVMBuildCall, LLVMBuildCast, LLVMBuildCondBr, LLVMBuildExtractValue, LLVMBuildFAdd, LLVMBuildFCmp, LLVMBuildFDiv, LLVMBuildFence, LLVMBuildFMul, LLVMBuildFNeg, LLVMBuildFree, LLVMBuildFSub, LLVMBuildGEP, LLVMBuildICmp, LLVMBuildInsertValue, LLVMBuildIsNotNull, LLVMBuildIsNull, LLVMBuildLoad, LLVMBuildMalloc, LLVMBuildMul, LLVMBuildNeg, LLVMBuildNot, LLVMBuildOr, LLVMBuildPhi, LLVMBuildPointerCast, LLVMBuildRet, LLVMBuildRetVoid, LLVMBuildStore, LLVMBuildSub, LLVMBuildUDiv, LLVMBuildUnreachable, LLVMBuildXor, LLVMDisposeBuilder, LLVMGetElementType, LLVMGetInsertBlock, LLVMGetReturnType, LLVMGetTypeKind, LLVMInsertIntoBuilder, LLVMPositionBuilderAtEnd, LLVMTypeOf, LLVMBuildExtractElement, LLVMBuildInsertElement, LLVMBuildIntToPtr, LLVMBuildPtrToInt, LLVMInsertIntoBuilderWithName, LLVMClearInsertionPosition, LLVMCreateBuilder, LLVMPositionBuilder, LLVMPositionBuilderBefore, LLVMBuildAggregateRet, LLVMBuildStructGEP, LLVMBuildInBoundsGEP, LLVMBuildPtrDiff, LLVMBuildNSWAdd, LLVMBuildNUWAdd, LLVMBuildNSWSub, LLVMBuildNUWSub, LLVMBuildNSWMul, LLVMBuildNUWMul, LLVMBuildSDiv, LLVMBuildSRem, LLVMBuildURem, LLVMBuildFRem, LLVMBuildNSWNeg, LLVMBuildNUWNeg, LLVMBuildFPToUI, LLVMBuildFPToSI, LLVMBuildSIToFP, LLVMBuildUIToFP, LLVMBuildFPTrunc, LLVMBuildFPExt, LLVMBuildIntCast, LLVMBuildFPCast, LLVMBuildSExtOrBitCast, LLVMBuildZExtOrBitCast, LLVMBuildTruncOrBitCast, LLVMBuildSwitch, LLVMAddCase, LLVMBuildShl, LLVMBuildAShr, LLVMBuildLShr, LLVMBuildGlobalString, LLVMBuildGlobalStringPtr, LLVMBuildExactSDiv, LLVMBuildTrunc, LLVMBuildSExt, LLVMBuildZExt, LLVMBuildSelect, LLVMBuildAddrSpaceCast, LLVMBuildBitCast, LLVMBuildShuffleVector, LLVMBuildVAArg, LLVMBuildIndirectBr, LLVMAddDestination};
use llvm_sys::prelude::{LLVMBuilderRef, LLVMValueRef};
Expand All @@ -17,7 +19,7 @@ pub struct Builder {

impl Builder {
pub(crate) fn new(builder: LLVMBuilderRef) -> Self {
assert!(!builder.is_null());
debug_assert!(!builder.is_null());

Builder {
builder: builder
Expand Down Expand Up @@ -397,7 +399,7 @@ impl Builder {
},
None => unsafe {
LLVMInsertIntoBuilder(self.builder, instruction.as_value_ref());
}
},
}
}

Expand Down
139 changes: 107 additions & 32 deletions src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ use crate::module::Module;
use crate::targets::TargetData;
use crate::values::{AsValueRef, FunctionValue};

use std::borrow::Borrow;
use std::marker::PhantomData;

// REVIEW: Opt Level might be identical to targets::Option<CodeGenOptLevel>
#[derive(Debug)]
pub struct PassManagerBuilder {
Expand Down Expand Up @@ -74,22 +77,72 @@ impl PassManagerBuilder {
}
}

// SubType: pass_manager: &PassManager<FunctionValue>
pub fn populate_function_pass_manager(&self, pass_manager: &PassManager) {
/// Populates a PassManager<FunctionValue> with the expectation of function
/// transformations.
///
/// # Example
///
/// ```
/// use inkwell::OptimizationLevel::Aggressive;
/// use inkwell::module::Module;
/// use inkwell::passes::{PassManager, PassManagerBuilder};
///
/// let module = Module::create("mod");
/// let pass_manager_builder = PassManagerBuilder::create();
///
/// pass_manager_builder.set_optimization_level(Aggressive);
///
/// let fpm = PassManager::create(&module);
///
/// pass_manager_builder.populate_function_pass_manager(&fpm);
/// ```
pub fn populate_function_pass_manager(&self, pass_manager: &PassManager<FunctionValue>) {
unsafe {
LLVMPassManagerBuilderPopulateFunctionPassManager(self.pass_manager_builder, pass_manager.pass_manager)
}
}

// SubType: pass_manager: &PassManager<Module>
pub fn populate_module_pass_manager(&self, pass_manager: &PassManager) {
/// Populates a PassManager<Module> with the expectation of whole module
/// transformations.
///
/// # Example
///
/// ```
/// use inkwell::OptimizationLevel::Aggressive;
/// use inkwell::passes::{PassManager, PassManagerBuilder};
///
/// let pass_manager_builder = PassManagerBuilder::create();
///
/// pass_manager_builder.set_optimization_level(Aggressive);
///
/// let fpm = PassManager::create(());
///
/// pass_manager_builder.populate_module_pass_manager(&fpm);
/// ```
pub fn populate_module_pass_manager(&self, pass_manager: &PassManager<Module>) {
unsafe {
LLVMPassManagerBuilderPopulateModulePassManager(self.pass_manager_builder, pass_manager.pass_manager)
}
}

// SubType: Need LTO subtype?
pub fn populate_lto_pass_manager(&self, pass_manager: &PassManager, internalize: bool, run_inliner: bool) {
/// Populates a PassManager<Module> with the expectation of link time
/// optimization transformations.
///
/// # Example
///
/// ```
/// use inkwell::OptimizationLevel::Aggressive;
/// use inkwell::passes::{PassManager, PassManagerBuilder};
///
/// let pass_manager_builder = PassManagerBuilder::create();
///
/// pass_manager_builder.set_optimization_level(Aggressive);
///
/// let lpm = PassManager::create(());
///
/// pass_manager_builder.populate_lto_pass_manager(&lpm, false, false);
/// ```
pub fn populate_lto_pass_manager(&self, pass_manager: &PassManager<Module>, internalize: bool, run_inliner: bool) {
unsafe {
LLVMPassManagerBuilderPopulateLTOPassManager(self.pass_manager_builder, pass_manager.pass_manager, internalize as i32, run_inliner as i32)
}
Expand All @@ -104,37 +157,65 @@ impl Drop for PassManagerBuilder {
}
}

// This is an ugly privacy hack so that PassManagerSubType can stay private
// to this module and so that super traits using this trait will be not be
// implementable outside this library
pub trait PassManagerSubType {
type Input;

unsafe fn create<I: Borrow<Self::Input>>(input: I) -> LLVMPassManagerRef;
unsafe fn run_in_pass_manager(&self, pass_manager: &PassManager<Self>) -> bool where Self: Sized;
}

impl PassManagerSubType for Module {
type Input = ();

unsafe fn create<I: Borrow<Self::Input>>(_: I) -> LLVMPassManagerRef {
LLVMCreatePassManager()
}

unsafe fn run_in_pass_manager(&self, pass_manager: &PassManager<Self>) -> bool {
LLVMRunPassManager(pass_manager.pass_manager, self.module.get()) == 1
}
}

// With GATs https://github.com/rust-lang/rust/issues/44265 this could be
// type Input<'a> = &'a Module;
impl PassManagerSubType for FunctionValue {
type Input = Module;

unsafe fn create<I: Borrow<Self::Input>>(input: I) -> LLVMPassManagerRef {
LLVMCreateFunctionPassManagerForModule(input.borrow().module.get())
}

unsafe fn run_in_pass_manager(&self, pass_manager: &PassManager<Self>) -> bool {
LLVMRunFunctionPassManager(pass_manager.pass_manager, self.as_value_ref()) == 1
}
}

// SubTypes: PassManager<Module>, PassManager<FunctionValue>
/// A manager for running optimization and simplification passes. Much of the
/// documenation for specific passes is directly from the [LLVM
/// documentation](https://llvm.org/docs/Passes.html).
#[derive(Debug)]
pub struct PassManager {
pub struct PassManager<T> {
pub(crate) pass_manager: LLVMPassManagerRef,
sub_type: PhantomData<T>,
}

impl PassManager {
pub(crate) fn new(pass_manager: LLVMPassManagerRef) -> PassManager {
impl<T: PassManagerSubType> PassManager<T> {
pub(crate) fn new(pass_manager: LLVMPassManagerRef) -> Self {
assert!(!pass_manager.is_null());

PassManager {
pass_manager,
sub_type: PhantomData,
}
}

// SubTypes: PassManager<Module>::create()
pub fn create_for_module() -> Self {
let pass_manager = unsafe {
LLVMCreatePassManager()
};

PassManager::new(pass_manager)
}

// SubTypes: PassManager<FunctionValue>::create()
pub fn create_for_function(module: &Module) -> Self {
pub fn create<I: Borrow<T::Input>>(input: I) -> PassManager<T> {
let pass_manager = unsafe {
LLVMCreateFunctionPassManagerForModule(module.module.get())
T::create(input)
};

PassManager::new(pass_manager)
Expand All @@ -153,17 +234,11 @@ impl PassManager {
}
}

// SubTypes: For PassManager<FunctionValue> only, rename run_on
pub fn run_on_function(&self, fn_value: &FunctionValue) -> bool {
unsafe {
LLVMRunFunctionPassManager(self.pass_manager, fn_value.as_value_ref()) == 1
}
}

// SubTypes: For PassManager<Module> only, rename run_on
pub fn run_on_module(&self, module: &Module) -> bool {
/// This method returns true if any of the passes modified the module and
/// false otherwise.
pub fn run_on(&self, input: &T) -> bool {
unsafe {
LLVMRunPassManager(self.pass_manager, module.module.get()) == 1
input.run_in_pass_manager(self)
}
}

Expand Down Expand Up @@ -984,7 +1059,7 @@ impl PassManager {
}
}

impl Drop for PassManager {
impl<T> Drop for PassManager<T> {
fn drop(&mut self) {
unsafe {
LLVMDisposePassManager(self.pass_manager)
Expand Down
2 changes: 1 addition & 1 deletion src/targets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ impl TargetMachine {
}

// TODO: Move to PassManager?
pub fn add_analysis_passes(&self, pass_manager: &PassManager) {
pub fn add_analysis_passes<T>(&self, pass_manager: &PassManager<T>) {
unsafe {
LLVMAddAnalysisPasses(self.target_machine, pass_manager.pass_manager)
}
Expand Down
28 changes: 19 additions & 9 deletions tests/all/test_passes.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
extern crate inkwell;

use self::inkwell::OptimizationLevel::Aggressive;
use self::inkwell::context::Context;
use self::inkwell::passes::{PassManagerBuilder, PassManager, PassRegistry};
use self::inkwell::OptimizationLevel::Aggressive;

#[test]
fn test_init_all_passes_for_module() {
let context = Context::create();
let module = context.create_module("my_module");
let pass_manager = PassManager::create_for_module();
let pass_manager = PassManager::create(());

pass_manager.add_argument_promotion_pass();
pass_manager.add_constant_merge_pass();
Expand Down Expand Up @@ -85,7 +85,7 @@ fn test_init_all_passes_for_module() {
assert!(!pass_manager.initialize());
assert!(!pass_manager.finalize());

pass_manager.run_on_module(&module);
pass_manager.run_on(&module);

assert!(!pass_manager.initialize());
assert!(!pass_manager.finalize());
Expand All @@ -107,7 +107,7 @@ fn test_pass_manager_builder() {
let context = Context::create();
let module = context.create_module("my_module");

let fn_pass_manager = PassManager::create_for_function(&module);
let fn_pass_manager = PassManager::create(&module);

pass_manager_builder.populate_function_pass_manager(&fn_pass_manager);

Expand All @@ -123,20 +123,30 @@ fn test_pass_manager_builder() {
// TODO: Test with actual changes? Would be true in that case
// REVIEW: Segfaults in 4.0
#[cfg(not(feature = "llvm4-0"))]
assert!(!fn_pass_manager.run_on_function(&fn_value));
assert!(!fn_pass_manager.run_on(&fn_value));

let module_pass_manager = PassManager::create_for_module();
let module_pass_manager = PassManager::create(());

pass_manager_builder.populate_module_pass_manager(&module_pass_manager);

let module2 = module.clone();

// TODOC: Seems to return true in 3.7, 6.0, & 7.0 even though no changes were made.
// In 3.6, 3.8, & 3.9 it returns false. Seems like a LLVM bug?
#[cfg(not(any(feature = "llvm3-7", feature = "llvm6-0", feature = "llvm7-0")))]
assert!(!module_pass_manager.run_on_module(&module));
assert!(!module_pass_manager.run_on(&module));
#[cfg(any(feature = "llvm3-7", feature = "llvm6-0", feature = "llvm7-0"))]
assert!(module_pass_manager.run_on_module(&module));
assert!(module_pass_manager.run_on(&module));

let lto_pass_manager = PassManager::create(());

// TODO: Populate LTO pass manager?
pass_manager_builder.populate_lto_pass_manager(&lto_pass_manager, false, false);

// See above note on version differences
#[cfg(not(any(feature = "llvm3-7", feature = "llvm6-0", feature = "llvm7-0")))]
assert!(!lto_pass_manager.run_on(&module2));
#[cfg(any(feature = "llvm3-7", feature = "llvm6-0", feature = "llvm7-0"))]
assert!(lto_pass_manager.run_on(&module2));
}

#[test]
Expand Down

0 comments on commit b5d34a1

Please sign in to comment.