Skip to content

Commit d83cfb4

Browse files
Merge pull request RustPython#1234 from RustPython/short-circuit-evaluation
Improve the situation regarding boolean operations.
2 parents 0129bb5 + 36d8147 commit d83cfb4

File tree

8 files changed

+191
-96
lines changed

8 files changed

+191
-96
lines changed

bytecode/src/bytecode.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,24 @@ pub enum Instruction {
138138
Jump {
139139
target: Label,
140140
},
141-
JumpIf {
141+
/// Pop the top of the stack, and jump if this value is true.
142+
JumpIfTrue {
142143
target: Label,
143144
},
145+
/// Pop the top of the stack, and jump if this value is false.
144146
JumpIfFalse {
145147
target: Label,
146148
},
149+
/// Peek at the top of the stack, and jump if this value is true.
150+
/// Otherwise, pop top of stack.
151+
JumpIfTrueOrPop {
152+
target: Label,
153+
},
154+
/// Peek at the top of the stack, and jump if this value is false.
155+
/// Otherwise, pop top of stack.
156+
JumpIfFalseOrPop {
157+
target: Label,
158+
},
147159
MakeFunction {
148160
flags: FunctionOpArg,
149161
},
@@ -411,8 +423,10 @@ impl Instruction {
411423
Continue => w!(Continue),
412424
Break => w!(Break),
413425
Jump { target } => w!(Jump, label_map[target]),
414-
JumpIf { target } => w!(JumpIf, label_map[target]),
426+
JumpIfTrue { target } => w!(JumpIfTrue, label_map[target]),
415427
JumpIfFalse { target } => w!(JumpIfFalse, label_map[target]),
428+
JumpIfTrueOrPop { target } => w!(JumpIfTrueOrPop, label_map[target]),
429+
JumpIfFalseOrPop { target } => w!(JumpIfFalseOrPop, label_map[target]),
416430
MakeFunction { flags } => w!(MakeFunction, format!("{:?}", flags)),
417431
CallFunction { typ } => w!(CallFunction, format!("{:?}", typ)),
418432
ForIter { target } => w!(ForIter, label_map[target]),

compiler/src/compile.rs

Lines changed: 110 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use rustpython_parser::{ast, parser};
1515

1616
type BasicOutputStream = PeepholeOptimizer<CodeObjectStream>;
1717

18+
/// Main structure holding the state of compilation.
1819
struct Compiler<O: OutputStream = BasicOutputStream> {
1920
output_stack: Vec<O>,
2021
scope_stack: Vec<SymbolScope>,
@@ -107,12 +108,6 @@ pub enum Mode {
107108
Single,
108109
}
109110

110-
#[derive(Clone, Copy)]
111-
enum EvalContext {
112-
Statement,
113-
Expression,
114-
}
115-
116111
pub(crate) type Label = usize;
117112

118113
impl<O> Default for Compiler<O>
@@ -350,14 +345,14 @@ impl<O: OutputStream> Compiler<O> {
350345
match orelse {
351346
None => {
352347
// Only if:
353-
self.compile_test(test, None, Some(end_label), EvalContext::Statement)?;
348+
self.compile_jump_if(test, false, end_label)?;
354349
self.compile_statements(body)?;
355350
self.set_label(end_label);
356351
}
357352
Some(statements) => {
358353
// if - else:
359354
let else_label = self.new_label();
360-
self.compile_test(test, None, Some(else_label), EvalContext::Statement)?;
355+
self.compile_jump_if(test, false, else_label)?;
361356
self.compile_statements(body)?;
362357
self.emit(Instruction::Jump { target: end_label });
363358

@@ -459,7 +454,7 @@ impl<O: OutputStream> Compiler<O> {
459454
// if some flag, ignore all assert statements!
460455
if self.optimize == 0 {
461456
let end_label = self.new_label();
462-
self.compile_test(test, Some(end_label), None, EvalContext::Statement)?;
457+
self.compile_jump_if(test, true, end_label)?;
463458
self.emit(Instruction::LoadName {
464459
name: String::from("AssertionError"),
465460
scope: bytecode::NameScope::Local,
@@ -1006,7 +1001,7 @@ impl<O: OutputStream> Compiler<O> {
10061001

10071002
self.set_label(start_label);
10081003

1009-
self.compile_test(test, None, Some(else_label), EvalContext::Statement)?;
1004+
self.compile_jump_if(test, false, else_label)?;
10101005

10111006
let was_in_loop = self.in_loop;
10121007
self.in_loop = true;
@@ -1118,12 +1113,9 @@ impl<O: OutputStream> Compiler<O> {
11181113
});
11191114

11201115
// if comparison result is false, we break with this value; if true, try the next one.
1121-
// (CPython compresses these three opcodes into JUMP_IF_FALSE_OR_POP)
1122-
self.emit(Instruction::Duplicate);
1123-
self.emit(Instruction::JumpIfFalse {
1116+
self.emit(Instruction::JumpIfFalseOrPop {
11241117
target: break_label,
11251118
});
1126-
self.emit(Instruction::Pop);
11271119
}
11281120

11291121
// handle the last comparison
@@ -1256,66 +1248,120 @@ impl<O: OutputStream> Compiler<O> {
12561248
self.emit(Instruction::BinaryOperation { op: i, inplace });
12571249
}
12581250

1259-
fn compile_test(
1251+
/// Implement boolean short circuit evaluation logic.
1252+
/// https://en.wikipedia.org/wiki/Short-circuit_evaluation
1253+
///
1254+
/// This means that in a boolean expression 'x and y', the operand y will
1255+
/// not be evaluated when x is false.
1256+
///
1257+
/// The idea is to jump to a label if the expression is either true or false
1258+
/// (indicated by the condition parameter).
1259+
fn compile_jump_if(
12601260
&mut self,
12611261
expression: &ast::Expression,
1262-
true_label: Option<Label>,
1263-
false_label: Option<Label>,
1264-
context: EvalContext,
1262+
condition: bool,
1263+
target_label: Label,
12651264
) -> Result<(), CompileError> {
12661265
// Compile the test expression, and jump to target_label when its truth value equals `condition`
12671266
match &expression.node {
1268-
ast::ExpressionType::BoolOp { a, op, b } => match op {
1269-
ast::BooleanOperator::And => {
1270-
let f = false_label.unwrap_or_else(|| self.new_label());
1271-
self.compile_test(a, None, Some(f), context)?;
1272-
self.compile_test(b, true_label, false_label, context)?;
1273-
if false_label.is_none() {
1274-
self.set_label(f);
1267+
ast::ExpressionType::BoolOp { op, values } => {
1268+
match op {
1269+
ast::BooleanOperator::And => {
1270+
if condition {
1271+
// If all values are true.
1272+
let end_label = self.new_label();
1273+
let (last_value, values) = values.split_last().unwrap();
1274+
1275+
// If any of the values is false, we can short-circuit.
1276+
for value in values {
1277+
self.compile_jump_if(value, false, end_label)?;
1278+
}
1279+
1280+
// It depends upon the last value now: will it be true?
1281+
self.compile_jump_if(last_value, true, target_label)?;
1282+
self.set_label(end_label);
1283+
} else {
1284+
// If any value is false, the whole condition is false.
1285+
for value in values {
1286+
self.compile_jump_if(value, false, target_label)?;
1287+
}
1288+
}
12751289
}
1276-
}
1277-
ast::BooleanOperator::Or => {
1278-
let t = true_label.unwrap_or_else(|| self.new_label());
1279-
self.compile_test(a, Some(t), None, context)?;
1280-
self.compile_test(b, true_label, false_label, context)?;
1281-
if true_label.is_none() {
1282-
self.set_label(t);
1290+
ast::BooleanOperator::Or => {
1291+
if condition {
1292+
// If any of the values is true.
1293+
for value in values {
1294+
self.compile_jump_if(value, true, target_label)?;
1295+
}
1296+
} else {
1297+
// If all of the values are false.
1298+
let end_label = self.new_label();
1299+
let (last_value, values) = values.split_last().unwrap();
1300+
1301+
// If any value is true, we can short-circuit:
1302+
for value in values {
1303+
self.compile_jump_if(value, true, end_label)?;
1304+
}
1305+
1306+
// It all depends upon the last value now!
1307+
self.compile_jump_if(last_value, false, target_label)?;
1308+
self.set_label(end_label);
1309+
}
12831310
}
12841311
}
1285-
},
1312+
}
1313+
ast::ExpressionType::Unop {
1314+
op: ast::UnaryOperator::Not,
1315+
a,
1316+
} => {
1317+
self.compile_jump_if(a, !condition, target_label)?;
1318+
}
12861319
_ => {
1320+
// Fallback case, which always works!
12871321
self.compile_expression(expression)?;
1288-
match context {
1289-
EvalContext::Statement => {
1290-
if let Some(true_label) = true_label {
1291-
self.emit(Instruction::JumpIf { target: true_label });
1292-
}
1293-
if let Some(false_label) = false_label {
1294-
self.emit(Instruction::JumpIfFalse {
1295-
target: false_label,
1296-
});
1297-
}
1298-
}
1299-
EvalContext::Expression => {
1300-
if let Some(true_label) = true_label {
1301-
self.emit(Instruction::Duplicate);
1302-
self.emit(Instruction::JumpIf { target: true_label });
1303-
self.emit(Instruction::Pop);
1304-
}
1305-
if let Some(false_label) = false_label {
1306-
self.emit(Instruction::Duplicate);
1307-
self.emit(Instruction::JumpIfFalse {
1308-
target: false_label,
1309-
});
1310-
self.emit(Instruction::Pop);
1311-
}
1312-
}
1322+
if condition {
1323+
self.emit(Instruction::JumpIfTrue {
1324+
target: target_label,
1325+
});
1326+
} else {
1327+
self.emit(Instruction::JumpIfFalse {
1328+
target: target_label,
1329+
});
13131330
}
13141331
}
13151332
}
13161333
Ok(())
13171334
}
13181335

1336+
/// Compile a boolean operation as an expression.
1337+
/// This means that the value of the last evaluated operand remains on the stack.
1338+
fn compile_bool_op(
1339+
&mut self,
1340+
op: &ast::BooleanOperator,
1341+
values: &[ast::Expression],
1342+
) -> Result<(), CompileError> {
1343+
let end_label = self.new_label();
1344+
1345+
let (last_value, values) = values.split_last().unwrap();
1346+
for value in values {
1347+
self.compile_expression(value)?;
1348+
1349+
match op {
1350+
ast::BooleanOperator::And => {
1351+
self.emit(Instruction::JumpIfFalseOrPop { target: end_label });
1352+
}
1353+
ast::BooleanOperator::Or => {
1354+
self.emit(Instruction::JumpIfTrueOrPop { target: end_label });
1355+
}
1356+
}
1357+
}
1358+
1359+
// If none of the earlier values short-circuited, the result is the last value:
1360+
self.compile_expression(last_value)?;
1361+
self.set_label(end_label);
1362+
Ok(())
1363+
}
1364+
13191365
fn compile_expression(&mut self, expression: &ast::Expression) -> Result<(), CompileError> {
13201366
trace!("Compiling {:?}", expression);
13211367
self.set_source_location(&expression.location);
@@ -1327,12 +1373,7 @@ impl<O: OutputStream> Compiler<O> {
13271373
args,
13281374
keywords,
13291375
} => self.compile_call(function, args, keywords)?,
1330-
BoolOp { .. } => self.compile_test(
1331-
expression,
1332-
Option::None,
1333-
Option::None,
1334-
EvalContext::Expression,
1335-
)?,
1376+
BoolOp { op, values } => self.compile_bool_op(op, values)?,
13361377
Binop { a, op, b } => {
13371378
self.compile_expression(a)?;
13381379
self.compile_expression(b)?;
@@ -1527,8 +1568,7 @@ impl<O: OutputStream> Compiler<O> {
15271568
IfExpression { test, body, orelse } => {
15281569
let no_label = self.new_label();
15291570
let end_label = self.new_label();
1530-
self.compile_test(test, Option::None, Option::None, EvalContext::Expression)?;
1531-
self.emit(Instruction::JumpIfFalse { target: no_label });
1571+
self.compile_jump_if(test, false, no_label)?;
15321572
// True case
15331573
self.compile_expression(body)?;
15341574
self.emit(Instruction::Jump { target: end_label });
@@ -1745,12 +1785,7 @@ impl<O: OutputStream> Compiler<O> {
17451785

17461786
// Now evaluate the ifs:
17471787
for if_condition in &generator.ifs {
1748-
self.compile_test(
1749-
if_condition,
1750-
None,
1751-
Some(start_label),
1752-
EvalContext::Statement,
1753-
)?
1788+
self.compile_jump_if(if_condition, false, start_label)?
17541789
}
17551790
}
17561791

@@ -1988,11 +2023,11 @@ mod tests {
19882023
LoadConst {
19892024
value: Boolean { value: true }
19902025
},
1991-
JumpIf { target: 1 },
2026+
JumpIfTrue { target: 1 },
19922027
LoadConst {
19932028
value: Boolean { value: false }
19942029
},
1995-
JumpIf { target: 1 },
2030+
JumpIfTrue { target: 1 },
19962031
LoadConst {
19972032
value: Boolean { value: false }
19982033
},
@@ -2042,7 +2077,7 @@ mod tests {
20422077
LoadConst {
20432078
value: Boolean { value: false }
20442079
},
2045-
JumpIf { target: 1 },
2080+
JumpIfTrue { target: 1 },
20462081
LoadConst {
20472082
value: Boolean { value: false }
20482083
},

compiler/src/symboltable.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,9 +404,8 @@ impl SymbolTableBuilder {
404404
self.scan_expression(a)?;
405405
self.scan_expression(b)?;
406406
}
407-
BoolOp { a, b, .. } => {
408-
self.scan_expression(a)?;
409-
self.scan_expression(b)?;
407+
BoolOp { values, .. } => {
408+
self.scan_expressions(values)?;
410409
}
411410
Compare { vals, .. } => {
412411
self.scan_expressions(vals)?;

parser/src/ast.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,8 @@ pub type Expression = Located<ExpressionType>;
148148
#[derive(Debug, PartialEq)]
149149
pub enum ExpressionType {
150150
BoolOp {
151-
a: Box<Expression>,
152151
op: BooleanOperator,
153-
b: Box<Expression>,
152+
values: Vec<Expression>,
154153
},
155154
Binop {
156155
a: Box<Expression>,

parser/src/python.lalrpop

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -659,18 +659,32 @@ LambdaDef: ast::Expression = {
659659
}
660660

661661
OrTest: ast::Expression = {
662-
AndTest,
663-
<e1:OrTest> <location:@L> "or" <e2:AndTest> => ast::Expression {
664-
location,
665-
node: ast::ExpressionType::BoolOp { a: Box::new(e1), op: ast::BooleanOperator::Or, b: Box::new(e2) }
662+
<e1:AndTest> <location:@L> <e2:("or" AndTest)*> => {
663+
if e2.is_empty() {
664+
e1
665+
} else {
666+
let mut values = vec![e1];
667+
values.extend(e2.into_iter().map(|e| e.1));
668+
ast::Expression {
669+
location,
670+
node: ast::ExpressionType::BoolOp { op: ast::BooleanOperator::Or, values }
671+
}
672+
}
666673
},
667674
};
668675

669676
AndTest: ast::Expression = {
670-
NotTest,
671-
<e1:AndTest> <location:@L> "and" <e2:NotTest> => ast::Expression {
672-
location,
673-
node: ast::ExpressionType::BoolOp { a: Box::new(e1), op: ast::BooleanOperator::And, b: Box::new(e2) }
677+
<e1:NotTest> <location:@L> <e2:("and" NotTest)*> => {
678+
if e2.is_empty() {
679+
e1
680+
} else {
681+
let mut values = vec![e1];
682+
values.extend(e2.into_iter().map(|e| e.1));
683+
ast::Expression {
684+
location,
685+
node: ast::ExpressionType::BoolOp { op: ast::BooleanOperator::And, values }
686+
}
687+
}
674688
},
675689
};
676690

0 commit comments

Comments
 (0)