Skip to content

Commit 3b2a16f

Browse files
authored
Merge pull request RustPython#2164 from skinny121/jit-args
Support arguments in jitted functions
2 parents f98a381 + b79d7ed commit 3b2a16f

File tree

10 files changed

+406
-81
lines changed

10 files changed

+406
-81
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Lib/test/test_baseexception.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def test_inheritance(self):
7777
last_depth = depth
7878
finally:
7979
inheritance_tree.close()
80+
81+
# RUSTPYTHON specific
82+
exc_set.discard("JitError")
83+
8084
self.assertEqual(len(exc_set), 0, "%s not accounted for" % exc_set)
8185

8286
interface_tests = ("length", "args", "str", "repr")

jit/src/instructions.rs

Lines changed: 49 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,7 @@ use num_traits::cast::ToPrimitive;
33
use rustpython_bytecode::bytecode::{BinaryOperator, Constant, Instruction, NameScope};
44
use std::collections::HashMap;
55

6-
use super::JitCompileError;
7-
8-
#[derive(Default)]
9-
pub struct JitSig {
10-
pub ret: Option<JitType>,
11-
}
12-
13-
impl JitSig {
14-
pub fn to_cif(&self) -> libffi::middle::Cif {
15-
let ret = match self.ret {
16-
Some(ref ty) => ty.to_libffi(),
17-
None => libffi::middle::Type::void(),
18-
};
19-
libffi::middle::Cif::new(Vec::new(), ret)
20-
}
21-
}
22-
23-
#[derive(Clone, PartialEq)]
24-
pub enum JitType {
25-
Int,
26-
Float,
27-
}
28-
29-
impl JitType {
30-
fn to_cranelift(&self) -> types::Type {
31-
match self {
32-
Self::Int => types::I64,
33-
Self::Float => types::F64,
34-
}
35-
}
36-
37-
fn to_libffi(&self) -> libffi::middle::Type {
38-
match self {
39-
Self::Int => libffi::middle::Type::i64(),
40-
Self::Float => libffi::middle::Type::f64(),
41-
}
42-
}
43-
}
6+
use super::{JitCompileError, JitSig, JitType};
447

458
#[derive(Clone)]
469
struct Local {
@@ -53,20 +16,63 @@ struct JitValue {
5316
ty: JitType,
5417
}
5518

19+
impl JitValue {
20+
fn new(val: Value, ty: JitType) -> JitValue {
21+
JitValue { val, ty }
22+
}
23+
}
24+
5625
pub struct FunctionCompiler<'a, 'b> {
5726
builder: &'a mut FunctionBuilder<'b>,
5827
stack: Vec<JitValue>,
5928
variables: HashMap<String, Local>,
60-
pub sig: JitSig,
29+
pub(crate) sig: JitSig,
6130
}
6231

6332
impl<'a, 'b> FunctionCompiler<'a, 'b> {
64-
pub fn new(builder: &'a mut FunctionBuilder<'b>) -> FunctionCompiler<'a, 'b> {
65-
FunctionCompiler {
33+
pub fn new(
34+
builder: &'a mut FunctionBuilder<'b>,
35+
arg_names: &[String],
36+
arg_types: &[JitType],
37+
entry_block: Block,
38+
) -> FunctionCompiler<'a, 'b> {
39+
let mut compiler = FunctionCompiler {
6640
builder,
6741
stack: Vec::new(),
6842
variables: HashMap::new(),
69-
sig: JitSig::default(),
43+
sig: JitSig {
44+
args: arg_types.to_vec(),
45+
ret: None,
46+
},
47+
};
48+
let params = compiler.builder.func.dfg.block_params(entry_block).to_vec();
49+
debug_assert_eq!(arg_names.len(), arg_types.len());
50+
debug_assert_eq!(arg_names.len(), params.len());
51+
for ((name, ty), val) in arg_names.iter().zip(arg_types).zip(params) {
52+
compiler
53+
.store_variable(name.clone(), JitValue::new(val, ty.clone()))
54+
.unwrap();
55+
}
56+
compiler
57+
}
58+
59+
fn store_variable(&mut self, name: String, val: JitValue) -> Result<(), JitCompileError> {
60+
let len = self.variables.len();
61+
let builder = &mut self.builder;
62+
let local = self.variables.entry(name).or_insert_with(|| {
63+
let var = Variable::new(len);
64+
let local = Local {
65+
var,
66+
ty: val.ty.clone(),
67+
};
68+
builder.declare_var(var, val.ty.to_cranelift());
69+
local
70+
});
71+
if val.ty != local.ty {
72+
Err(JitCompileError::NotSupported)
73+
} else {
74+
self.builder.def_var(local.var, val.val);
75+
Ok(())
7076
}
7177
}
7278

@@ -91,22 +97,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
9197
scope: NameScope::Local,
9298
} => {
9399
let val = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
94-
let len = self.variables.len();
95-
let builder = &mut self.builder;
96-
let local = self.variables.entry(name.clone()).or_insert_with(|| {
97-
let var = Variable::new(len);
98-
let local = Local {
99-
var,
100-
ty: val.ty.clone(),
101-
};
102-
builder.declare_var(var, val.ty.to_cranelift());
103-
local
104-
});
105-
if val.ty != local.ty {
106-
return Err(JitCompileError::NotSupported);
107-
}
108-
self.builder.def_var(local.var, val.val);
109-
Ok(())
100+
self.store_variable(name.clone(), val)
110101
}
111102
Instruction::LoadConst {
112103
value: Constant::Integer { value },

jit/src/lib.rs

Lines changed: 122 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rustpython_bytecode::bytecode;
88

99
mod instructions;
1010

11-
use instructions::{FunctionCompiler, JitSig, JitType};
11+
use instructions::FunctionCompiler;
1212

1313
#[derive(Debug, thiserror::Error)]
1414
pub enum JitCompileError {
@@ -20,6 +20,12 @@ pub enum JitCompileError {
2020
CraneliftError(#[from] ModuleError),
2121
}
2222

23+
#[derive(Debug, thiserror::Error)]
24+
pub enum JitArgumentError {
25+
#[error("argument is of wrong type")]
26+
ArgumentTypeMismatch,
27+
}
28+
2329
struct Jit {
2430
builder_context: FunctionBuilderContext,
2531
ctx: codegen::Context,
@@ -40,15 +46,26 @@ impl Jit {
4046
fn build_function(
4147
&mut self,
4248
bytecode: &bytecode::CodeObject,
49+
args: &[JitType],
4350
) -> Result<(FuncId, JitSig), JitCompileError> {
51+
for arg in args {
52+
self.ctx
53+
.func
54+
.signature
55+
.params
56+
.push(AbiParam::new(arg.to_cranelift()));
57+
}
58+
4459
let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
4560
let entry_block = builder.create_block();
46-
// builder.append_block_params_for_function_params(entry_block);
61+
builder.append_block_params_for_function_params(entry_block);
4762
builder.switch_to_block(entry_block);
4863
builder.seal_block(entry_block);
4964

5065
let sig = {
51-
let mut compiler = FunctionCompiler::new(&mut builder);
66+
let mut arg_names = bytecode.arg_names.clone();
67+
arg_names.extend(bytecode.kwonlyarg_names.iter().cloned());
68+
let mut compiler = FunctionCompiler::new(&mut builder, &arg_names, args, entry_block);
5269

5370
for instruction in &bytecode.instructions {
5471
compiler.add_instruction(instruction)?;
@@ -74,10 +91,13 @@ impl Jit {
7491
}
7592
}
7693

77-
pub fn compile(bytecode: &bytecode::CodeObject) -> Result<CompiledCode, JitCompileError> {
94+
pub fn compile(
95+
bytecode: &bytecode::CodeObject,
96+
args: &[JitType],
97+
) -> Result<CompiledCode, JitCompileError> {
7898
let mut jit = Jit::new();
7999

80-
let (id, sig) = jit.build_function(bytecode)?;
100+
let (id, sig) = jit.build_function(bytecode, args)?;
81101

82102
jit.module.finalize_definitions();
83103

@@ -96,23 +116,67 @@ pub struct CompiledCode {
96116
}
97117

98118
impl CompiledCode {
99-
pub fn invoke(&self) -> Option<AbiValue> {
119+
pub fn args_builder(&self) -> ArgsBuilder<'_> {
120+
ArgsBuilder::new(self)
121+
}
122+
123+
pub fn invoke<'a>(&self, args: &Args<'a>) -> Option<AbiValue> {
124+
debug_assert_eq!(self as *const _, args.code as *const _);
100125
let cif = self.sig.to_cif();
101126
unsafe {
102127
let value = cif.call::<UnTypedAbiValue>(
103128
libffi::middle::CodePtr::from_ptr(self.code as *const _),
104-
&[],
129+
&args.cif_args,
105130
);
106131
self.sig.ret.as_ref().map(|ty| value.to_typed(ty))
107132
}
108133
}
109134
}
110135

136+
struct JitSig {
137+
args: Vec<JitType>,
138+
ret: Option<JitType>,
139+
}
140+
141+
impl JitSig {
142+
fn to_cif(&self) -> libffi::middle::Cif {
143+
let ret = match self.ret {
144+
Some(ref ty) => ty.to_libffi(),
145+
None => libffi::middle::Type::void(),
146+
};
147+
libffi::middle::Cif::new(self.args.iter().map(JitType::to_libffi), ret)
148+
}
149+
}
150+
151+
#[derive(Clone, PartialEq)]
152+
pub enum JitType {
153+
Int,
154+
Float,
155+
}
156+
157+
impl JitType {
158+
fn to_cranelift(&self) -> types::Type {
159+
match self {
160+
Self::Int => types::I64,
161+
Self::Float => types::F64,
162+
}
163+
}
164+
165+
fn to_libffi(&self) -> libffi::middle::Type {
166+
match self {
167+
Self::Int => libffi::middle::Type::i64(),
168+
Self::Float => libffi::middle::Type::f64(),
169+
}
170+
}
171+
}
172+
173+
#[derive(Clone)]
111174
pub enum AbiValue {
112175
Float(f64),
113176
Int(i64),
114177
}
115178

179+
#[derive(Copy, Clone)]
116180
union UnTypedAbiValue {
117181
float: f64,
118182
int: i64,
@@ -143,3 +207,54 @@ impl fmt::Debug for CompiledCode {
143207
f.write_str("[compiled code]")
144208
}
145209
}
210+
211+
pub struct ArgsBuilder<'a> {
212+
values: Vec<Option<AbiValue>>,
213+
code: &'a CompiledCode,
214+
}
215+
216+
impl<'a> ArgsBuilder<'a> {
217+
fn new(code: &'a CompiledCode) -> ArgsBuilder<'a> {
218+
ArgsBuilder {
219+
values: vec![None; code.sig.args.len()],
220+
code,
221+
}
222+
}
223+
224+
pub fn set(&mut self, idx: usize, value: AbiValue) -> Result<(), JitArgumentError> {
225+
match (&self.code.sig.args[idx], &value) {
226+
(JitType::Int, AbiValue::Int(_)) | (JitType::Float, AbiValue::Float(_)) => {
227+
self.values[idx] = Some(value);
228+
Ok(())
229+
}
230+
_ => Err(JitArgumentError::ArgumentTypeMismatch),
231+
}
232+
}
233+
234+
pub fn is_set(&self, idx: usize) -> bool {
235+
self.values[idx].is_some()
236+
}
237+
238+
pub fn into_args(self) -> Option<Args<'a>> {
239+
self.values
240+
.iter()
241+
.map(|v| {
242+
v.as_ref().map(|v| match v {
243+
AbiValue::Int(ref i) => libffi::middle::Arg::new(i),
244+
AbiValue::Float(ref f) => libffi::middle::Arg::new(f),
245+
})
246+
})
247+
.collect::<Option<_>>()
248+
.map(|cif_args| Args {
249+
_values: self.values,
250+
cif_args,
251+
code: self.code,
252+
})
253+
}
254+
}
255+
256+
pub struct Args<'a> {
257+
_values: Vec<Option<AbiValue>>,
258+
cif_args: Vec<libffi::middle::Arg>,
259+
code: &'a CompiledCode,
260+
}

tests/snippets/jit.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,15 @@ def bar():
99
return a / 5.0
1010

1111

12+
def baz(a: int, b: int):
13+
return a + b + 12
14+
15+
1216
def tests():
1317
assert foo() == 15
1418
assert bar() == 2e5
19+
assert baz(17, 20) == 49
20+
assert baz(17, 22.5) == 51.5
1521

1622

1723
tests()
@@ -20,4 +26,5 @@ def tests():
2026
print("Has jit")
2127
foo.__jit__()
2228
bar.__jit__()
29+
baz.__jit__()
2330
tests()

vm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ parking_lot = "0.11"
7777
thread_local = "1.0"
7878
cfg-if = "0.1.10"
7979
timsort = "0.1"
80+
thiserror = "1.0"
8081

8182
## unicode stuff
8283
unicode_names2 = "0.4"

vm/src/builtins.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,4 +977,9 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) {
977977
"BytesWarning" => ctx.exceptions.bytes_warning.clone(),
978978
"ResourceWarning" => ctx.exceptions.resource_warning.clone(),
979979
});
980+
981+
#[cfg(feature = "jit")]
982+
extend_module!(vm, module, {
983+
"JitError" => ctx.exceptions.jit_error.clone(),
984+
});
980985
}

0 commit comments

Comments
 (0)