Skip to content

Commit a10936c

Browse files
committed
Fix atomic iterator increments
1 parent 219802a commit a10936c

File tree

4 files changed

+33
-50
lines changed

4 files changed

+33
-50
lines changed

vm/src/obj/objiter.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -168,20 +168,15 @@ impl PySequenceIterator {
168168

169169
#[pymethod(name = "__next__")]
170170
fn next(&self, vm: &VirtualMachine) -> PyResult {
171-
let pos = self.position.load();
171+
let step: isize = if self.reversed { -1 } else { 1 };
172+
let pos = self.position.fetch_add(step);
172173
if pos >= 0 {
173-
let step: isize = if self.reversed { -1 } else { 1 };
174-
let number = vm.ctx.new_int(pos);
175-
match vm.call_method(&self.obj, "__getitem__", vec![number]) {
176-
Ok(val) => {
177-
self.position.store(pos + step);
178-
Ok(val)
179-
}
174+
match vm.call_method(&self.obj, "__getitem__", vec![vm.new_int(pos)]) {
180175
Err(ref e) if objtype::isinstance(&e, &vm.ctx.exceptions.index_error) => {
181176
Err(new_stop_iteration(vm))
182177
}
183178
// also catches stop_iteration => stop_iteration
184-
Err(e) => Err(e),
179+
ret => ret,
185180
}
186181
} else {
187182
Err(new_stop_iteration(vm))
@@ -217,7 +212,7 @@ pub fn seq_iter_method(obj: PyObjectRef) -> PySequenceIterator {
217212
pub struct PyCallableIterator {
218213
callable: PyCallable,
219214
sentinel: PyObjectRef,
220-
done: Cell<bool>,
215+
done: AtomicCell<bool>,
221216
}
222217

223218
impl PyValue for PyCallableIterator {
@@ -232,20 +227,20 @@ impl PyCallableIterator {
232227
Self {
233228
callable,
234229
sentinel,
235-
done: Cell::new(false),
230+
done: AtomicCell::new(false),
236231
}
237232
}
238233

239234
#[pymethod(magic)]
240235
fn next(&self, vm: &VirtualMachine) -> PyResult {
241-
if self.done.get() {
236+
if self.done.load() {
242237
return Err(new_stop_iteration(vm));
243238
}
244239

245240
let ret = self.callable.invoke(vec![], vm)?;
246241

247242
if vm.bool_eq(ret.clone(), self.sentinel.clone())? {
248-
self.done.set(true);
243+
self.done.store(true);
249244
Err(new_stop_iteration(vm))
250245
} else {
251246
Ok(ret)

vm/src/obj/objlist.rs

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ impl PyList {
236236
fn reversed(zelf: PyRef<Self>) -> PyListReverseIterator {
237237
let final_position = zelf.elements.borrow().len();
238238
PyListReverseIterator {
239-
position: AtomicCell::new(final_position),
239+
position: AtomicCell::new(final_position as isize),
240240
list: zelf,
241241
}
242242
}
@@ -861,9 +861,8 @@ impl PyListIterator {
861861
#[pymethod(name = "__next__")]
862862
fn next(&self, vm: &VirtualMachine) -> PyResult {
863863
let list = self.list.elements.borrow();
864-
let pos = self.position.load();
864+
let pos = self.position.fetch_add(1);
865865
if let Some(obj) = list.get(pos) {
866-
self.position.store(pos + 1);
867866
Ok(obj.clone())
868867
} else {
869868
Err(objiter::new_stop_iteration(vm))
@@ -879,14 +878,14 @@ impl PyListIterator {
879878
fn length_hint(&self) -> usize {
880879
let list = self.list.elements.borrow();
881880
let pos = self.position.load();
882-
list.len() - pos
881+
list.len().saturating_sub(pos)
883882
}
884883
}
885884

886885
#[pyclass]
887886
#[derive(Debug)]
888887
pub struct PyListReverseIterator {
889-
pub position: AtomicCell<usize>,
888+
pub position: AtomicCell<isize>,
890889
pub list: PyListRef,
891890
}
892891

@@ -900,16 +899,14 @@ impl PyValue for PyListReverseIterator {
900899
impl PyListReverseIterator {
901900
#[pymethod(name = "__next__")]
902901
fn next(&self, vm: &VirtualMachine) -> PyResult {
903-
let pos = self.position.load();
902+
let list = self.list.elements.borrow();
903+
let pos = self.position.fetch_sub(1);
904904
if pos > 0 {
905-
let pos = pos - 1;
906-
let list = self.list.elements.borrow();
907-
let ret = list[pos].clone();
908-
self.position.store(pos);
909-
Ok(ret)
910-
} else {
911-
Err(objiter::new_stop_iteration(vm))
905+
if let Some(ret) = list.get(pos as usize - 1) {
906+
return Ok(ret.clone());
907+
}
912908
}
909+
Err(objiter::new_stop_iteration(vm))
913910
}
914911

915912
#[pymethod(name = "__iter__")]
@@ -919,7 +916,7 @@ impl PyListReverseIterator {
919916

920917
#[pymethod(name = "__length_hint__")]
921918
fn length_hint(&self) -> usize {
922-
self.position.load()
919+
std::cmp::max(self.position.load(), 0) as usize
923920
}
924921
}
925922

vm/src/obj/objstr.rs

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ impl TryIntoRef<PyString> for &str {
9797
#[derive(Debug)]
9898
pub struct PyStringIterator {
9999
pub string: PyStringRef,
100-
byte_position: AtomicCell<usize>,
100+
position: AtomicCell<usize>,
101101
}
102102

103103
impl PyValue for PyStringIterator {
@@ -109,16 +109,12 @@ impl PyValue for PyStringIterator {
109109
#[pyimpl]
110110
impl PyStringIterator {
111111
#[pymethod(name = "__next__")]
112-
fn next(&self, vm: &VirtualMachine) -> PyResult {
113-
let pos = self.byte_position.load();
112+
fn next(&self, vm: &VirtualMachine) -> PyResult<String> {
113+
// TODO: use something more performant than chars().nth() that's still atomic
114+
let pos = self.position.fetch_add(1);
114115

115-
if pos < self.string.value.len() {
116-
// We can be sure that chars() has a value, because of the pos check above.
117-
let char_ = self.string.value[pos..].chars().next().unwrap();
118-
119-
self.byte_position.store(pos + char_.len_utf8());
120-
121-
char_.to_string().into_pyobject(vm)
116+
if let Some(c) = self.string.value.chars().nth(pos) {
117+
Ok(c.to_string())
122118
} else {
123119
Err(objiter::new_stop_iteration(vm))
124120
}
@@ -133,7 +129,7 @@ impl PyStringIterator {
133129
#[pyclass]
134130
#[derive(Debug)]
135131
pub struct PyStringReverseIterator {
136-
pub position: AtomicCell<usize>,
132+
pub position: AtomicCell<isize>,
137133
pub string: PyStringRef,
138134
}
139135

@@ -146,14 +142,10 @@ impl PyValue for PyStringReverseIterator {
146142
#[pyimpl]
147143
impl PyStringReverseIterator {
148144
#[pymethod(name = "__next__")]
149-
fn next(&self, vm: &VirtualMachine) -> PyResult {
150-
let pos = self.position.load();
151-
152-
if pos > 0 {
153-
let value = self.string.value.do_slice(pos - 1..pos);
154-
155-
self.position.store(pos - 1);
156-
value.into_pyobject(vm)
145+
fn next(&self, vm: &VirtualMachine) -> PyResult<String> {
146+
let pos = self.position.fetch_sub(1) as usize - 1;
147+
if let Some(c) = self.string.value.chars().nth(pos) {
148+
Ok(c.to_string())
157149
} else {
158150
Err(objiter::new_stop_iteration(vm))
159151
}
@@ -1288,7 +1280,7 @@ impl PyString {
12881280
#[pymethod(magic)]
12891281
fn iter(zelf: PyRef<Self>) -> PyStringIterator {
12901282
PyStringIterator {
1291-
byte_position: AtomicCell::new(0),
1283+
position: AtomicCell::new(0),
12921284
string: zelf,
12931285
}
12941286
}
@@ -1298,7 +1290,7 @@ impl PyString {
12981290
let begin = zelf.value.chars().count();
12991291

13001292
PyStringReverseIterator {
1301-
position: AtomicCell::new(begin),
1293+
position: AtomicCell::new(begin as isize),
13021294
string: zelf,
13031295
}
13041296
}

vm/src/obj/objtuple.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,8 @@ impl PyValue for PyTupleIterator {
257257
impl PyTupleIterator {
258258
#[pymethod(name = "__next__")]
259259
fn next(&self, vm: &VirtualMachine) -> PyResult {
260-
let pos = self.position.load();
260+
let pos = self.position.fetch_add(1);
261261
if let Some(obj) = self.tuple.as_slice().get(pos) {
262-
self.position.store(pos + 1);
263262
Ok(obj.clone())
264263
} else {
265264
Err(objiter::new_stop_iteration(vm))

0 commit comments

Comments
 (0)