Skip to content

Commit 2c3dbb5

Browse files
authored
Merge pull request RustPython#5081 from MannarAmuthan/bool-compare
Implemented compare operation for boolean types in JIT engine
2 parents f1b3623 + eb83b72 commit 2c3dbb5

File tree

3 files changed

+215
-3
lines changed

3 files changed

+215
-3
lines changed

jit/src/instructions.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,18 +315,34 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
315315
let b = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
316316
let a = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
317317

318+
let a_type: Option<JitType> = a.to_jit_type();
319+
let b_type: Option<JitType> = b.to_jit_type();
320+
318321
match (a, b) {
319-
(JitValue::Int(a), JitValue::Int(b)) => {
322+
(JitValue::Int(a), JitValue::Int(b))
323+
| (JitValue::Bool(a), JitValue::Bool(b))
324+
| (JitValue::Bool(a), JitValue::Int(b))
325+
| (JitValue::Int(a), JitValue::Bool(b)) => {
326+
let operand_one = match a_type.unwrap() {
327+
JitType::Bool => self.builder.ins().uextend(types::I64, a),
328+
_ => a,
329+
};
330+
331+
let operand_two = match b_type.unwrap() {
332+
JitType::Bool => self.builder.ins().uextend(types::I64, b),
333+
_ => b,
334+
};
335+
320336
let cond = match op {
321337
ComparisonOperator::Equal => IntCC::Equal,
322338
ComparisonOperator::NotEqual => IntCC::NotEqual,
323339
ComparisonOperator::Less => IntCC::SignedLessThan,
324340
ComparisonOperator::LessOrEqual => IntCC::SignedLessThanOrEqual,
325341
ComparisonOperator::Greater => IntCC::SignedGreaterThan,
326-
ComparisonOperator::GreaterOrEqual => IntCC::SignedLessThanOrEqual,
342+
ComparisonOperator::GreaterOrEqual => IntCC::SignedGreaterThanOrEqual,
327343
};
328344

329-
let val = self.builder.ins().icmp(cond, a, b);
345+
let val = self.builder.ins().icmp(cond, operand_one, operand_two);
330346
// TODO: Remove this `bint` in cranelift 0.90 as icmp now returns i8
331347
self.stack
332348
.push(JitValue::Bool(self.builder.ins().bint(types::I8, val)));

jit/tests/bool_tests.rs

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,153 @@ fn test_if_not() {
5050
assert_eq!(if_not(true), Ok(1));
5151
assert_eq!(if_not(false), Ok(0));
5252
}
53+
54+
#[test]
55+
fn test_eq() {
56+
let eq = jit_function! { eq(a:bool, b:bool) -> i64 => r##"
57+
def eq(a: bool, b: bool):
58+
if a == b:
59+
return 1
60+
return 0
61+
"## };
62+
63+
assert_eq!(eq(false, false), Ok(1));
64+
assert_eq!(eq(true, true), Ok(1));
65+
assert_eq!(eq(false, true), Ok(0));
66+
assert_eq!(eq(true, false), Ok(0));
67+
}
68+
69+
#[test]
70+
fn test_eq_with_integers() {
71+
let eq = jit_function! { eq(a:bool, b:i64) -> i64 => r##"
72+
def eq(a: bool, b: int):
73+
if a == b:
74+
return 1
75+
return 0
76+
"## };
77+
78+
assert_eq!(eq(false, 0), Ok(1));
79+
assert_eq!(eq(true, 1), Ok(1));
80+
assert_eq!(eq(false, 1), Ok(0));
81+
assert_eq!(eq(true, 0), Ok(0));
82+
}
83+
84+
#[test]
85+
fn test_gt() {
86+
let gt = jit_function! { gt(a:bool, b:bool) -> i64 => r##"
87+
def gt(a: bool, b: bool):
88+
if a > b:
89+
return 1
90+
return 0
91+
"## };
92+
93+
assert_eq!(gt(false, false), Ok(0));
94+
assert_eq!(gt(true, true), Ok(0));
95+
assert_eq!(gt(false, true), Ok(0));
96+
assert_eq!(gt(true, false), Ok(1));
97+
}
98+
99+
#[test]
100+
fn test_gt_with_integers() {
101+
let gt = jit_function! { gt(a:i64, b:bool) -> i64 => r##"
102+
def gt(a: int, b: bool):
103+
if a > b:
104+
return 1
105+
return 0
106+
"## };
107+
108+
assert_eq!(gt(0, false), Ok(0));
109+
assert_eq!(gt(1, true), Ok(0));
110+
assert_eq!(gt(0, true), Ok(0));
111+
assert_eq!(gt(1, false), Ok(1));
112+
}
113+
114+
#[test]
115+
fn test_lt() {
116+
let lt = jit_function! { lt(a:bool, b:bool) -> i64 => r##"
117+
def lt(a: bool, b: bool):
118+
if a < b:
119+
return 1
120+
return 0
121+
"## };
122+
123+
assert_eq!(lt(false, false), Ok(0));
124+
assert_eq!(lt(true, true), Ok(0));
125+
assert_eq!(lt(false, true), Ok(1));
126+
assert_eq!(lt(true, false), Ok(0));
127+
}
128+
129+
#[test]
130+
fn test_lt_with_integers() {
131+
let lt = jit_function! { lt(a:i64, b:bool) -> i64 => r##"
132+
def lt(a: int, b: bool):
133+
if a < b:
134+
return 1
135+
return 0
136+
"## };
137+
138+
assert_eq!(lt(0, false), Ok(0));
139+
assert_eq!(lt(1, true), Ok(0));
140+
assert_eq!(lt(0, true), Ok(1));
141+
assert_eq!(lt(1, false), Ok(0));
142+
}
143+
144+
#[test]
145+
fn test_gte() {
146+
let gte = jit_function! { gte(a:bool, b:bool) -> i64 => r##"
147+
def gte(a: bool, b: bool):
148+
if a >= b:
149+
return 1
150+
return 0
151+
"## };
152+
153+
assert_eq!(gte(false, false), Ok(1));
154+
assert_eq!(gte(true, true), Ok(1));
155+
assert_eq!(gte(false, true), Ok(0));
156+
assert_eq!(gte(true, false), Ok(1));
157+
}
158+
159+
#[test]
160+
fn test_gte_with_integers() {
161+
let gte = jit_function! { gte(a:bool, b:i64) -> i64 => r##"
162+
def gte(a: bool, b: int):
163+
if a >= b:
164+
return 1
165+
return 0
166+
"## };
167+
168+
assert_eq!(gte(false, 0), Ok(1));
169+
assert_eq!(gte(true, 1), Ok(1));
170+
assert_eq!(gte(false, 1), Ok(0));
171+
assert_eq!(gte(true, 0), Ok(1));
172+
}
173+
174+
#[test]
175+
fn test_lte() {
176+
let lte = jit_function! { lte(a:bool, b:bool) -> i64 => r##"
177+
def lte(a: bool, b: bool):
178+
if a <= b:
179+
return 1
180+
return 0
181+
"## };
182+
183+
assert_eq!(lte(false, false), Ok(1));
184+
assert_eq!(lte(true, true), Ok(1));
185+
assert_eq!(lte(false, true), Ok(1));
186+
assert_eq!(lte(true, false), Ok(0));
187+
}
188+
189+
#[test]
190+
fn test_lte_with_integers() {
191+
let lte = jit_function! { lte(a:bool, b:i64) -> i64 => r##"
192+
def lte(a: bool, b: int):
193+
if a <= b:
194+
return 1
195+
return 0
196+
"## };
197+
198+
assert_eq!(lte(false, 0), Ok(1));
199+
assert_eq!(lte(true, 1), Ok(1));
200+
assert_eq!(lte(false, 1), Ok(1));
201+
assert_eq!(lte(true, 0), Ok(0));
202+
}

jit/tests/int_tests.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,52 @@ fn test_gt() {
160160
assert_eq!(gt(1, -1), Ok(1));
161161
}
162162

163+
#[test]
164+
fn test_lt() {
165+
let lt = jit_function! { lt(a:i64, b:i64) -> i64 => r##"
166+
def lt(a: int, b: int):
167+
if a < b:
168+
return 1
169+
return 0
170+
"## };
171+
172+
assert_eq!(lt(-1, -5), Ok(0));
173+
assert_eq!(lt(10, 0), Ok(0));
174+
assert_eq!(lt(0, 1), Ok(1));
175+
assert_eq!(lt(-10, -1), Ok(1));
176+
assert_eq!(lt(100, 100), Ok(0));
177+
}
178+
179+
#[test]
180+
fn test_gte() {
181+
let gte = jit_function! { gte(a:i64, b:i64) -> i64 => r##"
182+
def gte(a: int, b: int):
183+
if a >= b:
184+
return 1
185+
return 0
186+
"## };
187+
188+
assert_eq!(gte(-64, -64), Ok(1));
189+
assert_eq!(gte(100, -1), Ok(1));
190+
assert_eq!(gte(1, 2), Ok(0));
191+
assert_eq!(gte(1, 0), Ok(1));
192+
}
193+
194+
#[test]
195+
fn test_lte() {
196+
let lte = jit_function! { lte(a:i64, b:i64) -> i64 => r##"
197+
def lte(a: int, b: int):
198+
if a <= b:
199+
return 1
200+
return 0
201+
"## };
202+
203+
assert_eq!(lte(-100, -100), Ok(1));
204+
assert_eq!(lte(-100, 100), Ok(1));
205+
assert_eq!(lte(10, 1), Ok(0));
206+
assert_eq!(lte(0, -2), Ok(0));
207+
}
208+
163209
#[test]
164210
fn test_minus() {
165211
let minus = jit_function! { minus(a:i64) -> i64 => r##"

0 commit comments

Comments
 (0)