Skip to content

Commit 4e6172b

Browse files
authored
Add object protocol correspoinding to PyObject_GetAIter (RustPython#5090)
1 parent 9241e2e commit 4e6172b

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

vm/src/protocol/object.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
44
use crate::{
55
builtins::{
6-
pystr::AsPyStr, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, PyStr, PyStrRef,
7-
PyTuple, PyTupleRef, PyType, PyTypeRef,
6+
pystr::AsPyStr, PyAsyncGen, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList,
7+
PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef,
88
},
99
bytesinner::ByteInnerNewOptions,
1010
common::{hash::PyHash, str::to_ascii},
@@ -92,6 +92,13 @@ impl PyObject {
9292
}
9393

9494
// PyObject *PyObject_GetAIter(PyObject *o)
95+
pub fn get_aiter(&self, vm: &VirtualMachine) -> PyResult {
96+
if self.payload_is::<PyAsyncGen>() {
97+
vm.call_special_method(self, identifier!(vm, __aiter__), ())
98+
} else {
99+
Err(vm.new_type_error("wrong argument type".to_owned()))
100+
}
101+
}
95102

96103
pub fn has_attr<'a>(&self, attr_name: impl AsPyStr<'a>, vm: &VirtualMachine) -> PyResult<bool> {
97104
self.get_attr(attr_name, vm).map(|o| !vm.is_none(&o))

vm/src/stdlib/builtins.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ pub use builtins::{ascii, print, reversed};
99
mod builtins {
1010
use crate::{
1111
builtins::{
12-
asyncgenerator::PyAsyncGen,
1312
enumerate::PyReverseSequenceIterator,
1413
function::{PyCellRef, PyFunction},
1514
int::PyIntRef,
@@ -459,11 +458,7 @@ mod builtins {
459458

460459
#[pyfunction]
461460
fn aiter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult {
462-
if iter_target.payload_is::<PyAsyncGen>() {
463-
vm.call_special_method(&iter_target, identifier!(vm, __aiter__), ())
464-
} else {
465-
Err(vm.new_type_error("wrong argument type".to_owned()))
466-
}
461+
iter_target.get_aiter(vm)
467462
}
468463

469464
#[pyfunction]

0 commit comments

Comments
 (0)