Skip to content

Commit 6bc33fb

Browse files
committed
Add async for
1 parent 960e8de commit 6bc33fb

File tree

4 files changed

+122
-33
lines changed

4 files changed

+122
-33
lines changed

bytecode/src/bytecode.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@ pub enum Instruction {
278278
SetupAsyncWith {
279279
end: Label,
280280
},
281+
GetAIter,
282+
GetANext,
281283
}
282284

283285
use self::Instruction::*;
@@ -315,6 +317,7 @@ pub enum ComparisonOperator {
315317
NotIn,
316318
Is,
317319
IsNot,
320+
ExceptionMatch,
318321
}
319322

320323
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
@@ -558,6 +561,8 @@ impl Instruction {
558561
PopException => w!(PopException),
559562
Reverse { amount } => w!(Reverse, amount),
560563
GetAwaitable => w!(GetAwaitable),
564+
GetAIter => w!(GetAIter),
565+
GetANext => w!(GetANext),
561566
}
562567
}
563568
}

compiler/src/compile.rs

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -442,13 +442,7 @@ impl<O: OutputStream> Compiler<O> {
442442
iter,
443443
body,
444444
orelse,
445-
} => {
446-
if *is_async {
447-
unimplemented!("async for");
448-
} else {
449-
self.compile_for(target, iter, body, orelse)?
450-
}
451-
}
445+
} => self.compile_for(target, iter, body, orelse, *is_async)?,
452446
Raise { exception, cause } => match exception {
453447
Some(value) => {
454448
self.compile_expression(value)?;
@@ -736,14 +730,9 @@ impl<O: OutputStream> Compiler<O> {
736730
self.emit(Instruction::Duplicate);
737731

738732
// Check exception type:
739-
self.emit(Instruction::LoadName {
740-
name: String::from("isinstance"),
741-
scope: bytecode::NameScope::Global,
742-
});
743-
self.emit(Instruction::Rotate { amount: 2 });
744733
self.compile_expression(exc_type)?;
745-
self.emit(Instruction::CallFunction {
746-
typ: CallType::Positional(2),
734+
self.emit(Instruction::CompareOperation {
735+
op: bytecode::ComparisonOperator::ExceptionMatch,
747736
});
748737

749738
// We cannot handle this exception type:
@@ -1108,11 +1097,13 @@ impl<O: OutputStream> Compiler<O> {
11081097
iter: &ast::Expression,
11091098
body: &[ast::Statement],
11101099
orelse: &Option<Vec<ast::Statement>>,
1100+
is_async: bool,
11111101
) -> Result<(), CompileError> {
11121102
// Start loop
11131103
let start_label = self.new_label();
11141104
let else_label = self.new_label();
11151105
let end_label = self.new_label();
1106+
11161107
self.emit(Instruction::SetupLoop {
11171108
start: start_label,
11181109
end: end_label,
@@ -1121,19 +1112,57 @@ impl<O: OutputStream> Compiler<O> {
11211112
// The thing iterated:
11221113
self.compile_expression(iter)?;
11231114

1124-
// Retrieve Iterator
1125-
self.emit(Instruction::GetIter);
1115+
if is_async {
1116+
let check_asynciter_label = self.new_label();
1117+
let body_label = self.new_label();
11261118

1127-
self.set_label(start_label);
1128-
self.emit(Instruction::ForIter { target: else_label });
1119+
self.emit(Instruction::GetAIter);
11291120

1130-
// Start of loop iteration, set targets:
1131-
self.compile_store(target)?;
1121+
self.set_label(start_label);
1122+
self.emit(Instruction::SetupExcept {
1123+
handler: check_asynciter_label,
1124+
});
1125+
self.emit(Instruction::GetANext);
1126+
self.emit(Instruction::LoadConst {
1127+
value: bytecode::Constant::None,
1128+
});
1129+
self.emit(Instruction::YieldFrom);
1130+
self.compile_store(target)?;
1131+
self.emit(Instruction::PopBlock);
1132+
self.emit(Instruction::Jump { target: body_label });
11321133

1133-
let was_in_loop = self.in_loop;
1134-
self.in_loop = true;
1135-
self.compile_statements(body)?;
1136-
self.in_loop = was_in_loop;
1134+
self.set_label(check_asynciter_label);
1135+
self.emit(Instruction::Duplicate);
1136+
self.emit(Instruction::LoadName {
1137+
name: "StopAsyncIteration".to_string(),
1138+
scope: bytecode::NameScope::Global,
1139+
});
1140+
self.emit(Instruction::CompareOperation {
1141+
op: bytecode::ComparisonOperator::ExceptionMatch,
1142+
});
1143+
self.emit(Instruction::JumpIfTrue { target: else_label });
1144+
self.emit(Instruction::Raise { argc: 0 });
1145+
1146+
let was_in_loop = self.in_loop;
1147+
self.in_loop = true;
1148+
self.set_label(body_label);
1149+
self.compile_statements(body)?;
1150+
self.in_loop = was_in_loop;
1151+
} else {
1152+
// Retrieve Iterator
1153+
self.emit(Instruction::GetIter);
1154+
1155+
self.set_label(start_label);
1156+
self.emit(Instruction::ForIter { target: else_label });
1157+
1158+
// Start of loop iteration, set targets:
1159+
self.compile_store(target)?;
1160+
1161+
let was_in_loop = self.in_loop;
1162+
self.in_loop = true;
1163+
self.compile_statements(body)?;
1164+
self.in_loop = was_in_loop;
1165+
}
11371166

11381167
self.emit(Instruction::Jump {
11391168
target: start_label,
@@ -1144,6 +1173,9 @@ impl<O: OutputStream> Compiler<O> {
11441173
self.compile_statements(orelse)?;
11451174
}
11461175
self.set_label(end_label);
1176+
if is_async {
1177+
self.emit(Instruction::Pop);
1178+
}
11471179
Ok(())
11481180
}
11491181

tests/snippets/async_stuff.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,27 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
1919
ls = []
2020

2121

22+
class AIterWrap:
23+
def __init__(self, obj):
24+
self._it = iter(obj)
25+
26+
def __aiter__(self):
27+
return self
28+
29+
async def __anext__(self):
30+
try:
31+
value = next(self._it)
32+
except StopIteration:
33+
raise StopAsyncIteration
34+
return value
35+
36+
2237
async def a(s, m):
2338
async with ContextManager() as b:
2439
print(f"val = {b}")
2540
await asyncio.sleep(s)
26-
for _ in range(0, 2):
41+
async for i in AIterWrap(range(0, 2)):
42+
print(i)
2743
ls.append(m)
2844
await asyncio.sleep(1)
2945

vm/src/frame.rs

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use indexmap::IndexMap;
55
use itertools::Itertools;
66

77
use crate::bytecode;
8-
use crate::function::PyFuncArgs;
8+
use crate::function::{single_or_tuple_any, PyFuncArgs};
99
use crate::obj::objbool;
1010
use crate::obj::objcode::PyCodeRef;
1111
use crate::obj::objcoroutine::PyCoroutine;
@@ -278,7 +278,7 @@ impl Frame {
278278
trace!("=======");
279279
}
280280

281-
match &instruction {
281+
match instruction {
282282
bytecode::Instruction::LoadConst { ref value } => {
283283
let obj = vm.ctx.unwrap_constant(value);
284284
self.push_value(obj);
@@ -503,8 +503,7 @@ impl Frame {
503503
}
504504
bytecode::Instruction::GetAwaitable => {
505505
let awaited_obj = self.pop_value();
506-
let awaitable = if awaited_obj.payload_is::<crate::obj::objcoroutine::PyCoroutine>()
507-
{
506+
let awaitable = if awaited_obj.payload_is::<PyCoroutine>() {
508507
awaited_obj
509508
} else {
510509
let await_method =
@@ -519,6 +518,23 @@ impl Frame {
519518
self.push_value(awaitable);
520519
Ok(None)
521520
}
521+
bytecode::Instruction::GetAIter => {
522+
let aiterable = self.pop_value();
523+
let aiter = vm.call_method(&aiterable, "__aiter__", vec![])?;
524+
self.push_value(aiter);
525+
Ok(None)
526+
}
527+
bytecode::Instruction::GetANext => {
528+
let aiter = self.last_value();
529+
let awaitable = vm.call_method(&aiter, "__anext__", vec![])?;
530+
let awaitable = if awaitable.payload_is::<PyCoroutine>() {
531+
awaitable
532+
} else {
533+
vm.call_method(&awaitable, "__await__", vec![])?
534+
};
535+
self.push_value(awaitable);
536+
Ok(None)
537+
}
522538
bytecode::Instruction::ForIter { target } => self.execute_for_iter(vm, *target),
523539
bytecode::Instruction::MakeFunction => self.execute_make_function(vm),
524540
bytecode::Instruction::CallFunction { typ } => self.execute_call_function(vm, typ),
@@ -1251,6 +1267,25 @@ impl Frame {
12511267
!a.is(&b)
12521268
}
12531269

1270+
fn exc_match(
1271+
&self,
1272+
vm: &VirtualMachine,
1273+
exc: PyObjectRef,
1274+
exc_type: PyObjectRef,
1275+
) -> PyResult<bool> {
1276+
single_or_tuple_any(
1277+
exc_type,
1278+
|cls: PyClassRef| vm.isinstance(&exc, &cls),
1279+
|o| {
1280+
format!(
1281+
"isinstance() arg 2 must be a type or tuple of types, not {}",
1282+
o.class()
1283+
)
1284+
},
1285+
vm,
1286+
)
1287+
}
1288+
12541289
#[cfg_attr(feature = "flame-it", flame("Frame"))]
12551290
fn execute_compare(
12561291
&self,
@@ -1266,10 +1301,11 @@ impl Frame {
12661301
bytecode::ComparisonOperator::LessOrEqual => vm._le(a, b)?,
12671302
bytecode::ComparisonOperator::Greater => vm._gt(a, b)?,
12681303
bytecode::ComparisonOperator::GreaterOrEqual => vm._ge(a, b)?,
1269-
bytecode::ComparisonOperator::Is => vm.ctx.new_bool(self._is(a, b)),
1270-
bytecode::ComparisonOperator::IsNot => vm.ctx.new_bool(self._is_not(a, b)),
1271-
bytecode::ComparisonOperator::In => vm.ctx.new_bool(self._in(vm, a, b)?),
1272-
bytecode::ComparisonOperator::NotIn => vm.ctx.new_bool(self._not_in(vm, a, b)?),
1304+
bytecode::ComparisonOperator::Is => vm.new_bool(self._is(a, b)),
1305+
bytecode::ComparisonOperator::IsNot => vm.new_bool(self._is_not(a, b)),
1306+
bytecode::ComparisonOperator::In => vm.new_bool(self._in(vm, a, b)?),
1307+
bytecode::ComparisonOperator::NotIn => vm.new_bool(self._not_in(vm, a, b)?),
1308+
bytecode::ComparisonOperator::ExceptionMatch => vm.new_bool(self.exc_match(vm, a, b)?),
12731309
};
12741310

12751311
self.push_value(value);

0 commit comments

Comments
 (0)