From d6192cd927de41bb169fde6e693a3dd9b0ae2216 Mon Sep 17 00:00:00 2001 From: TheAnyKey Date: Thu, 7 May 2020 19:48:23 +0000 Subject: [PATCH 01/19] merged from other brach --- tests/snippets/dict_union.py | 83 ++++++++++++++++++++++++++++++++++++ vm/src/obj/objdict.rs | 43 +++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 tests/snippets/dict_union.py diff --git a/tests/snippets/dict_union.py b/tests/snippets/dict_union.py new file mode 100644 index 0000000000..29e0718d45 --- /dev/null +++ b/tests/snippets/dict_union.py @@ -0,0 +1,83 @@ + +import testutils + +def test_dunion_ior0(): + a={1:2,2:3} + b={3:4,5:6} + a|=b + + assert a == {1:2,2:3,3:4,5:6}, f"wrong value assigned {a=}" + assert b == {3:4,5:6}, f"right hand side modified, {b=}" + +def test_dunion_or0(): + a={1:2,2:3} + b={3:4,5:6} + c=a|b + + assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}" + assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}" + assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}" + + +def test_dunion_or1(): + a={1:2,2:3} + b={3:4,5:6} + c=a.__or__(b) + + assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}" + assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}" + assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}" + + +def test_dunion_ror0(): + a={1:2,2:3} + b={3:4,5:6} + c=b.__ror__(a) + + assert a == {1:2,2:3}, f"left hand side of non-assignment operator modified {a=}" + assert b == {3:4,5:6}, f"right hand side of non-assignment operator modified, {b=}" + assert c == {1:2,2:3, 3:4, 5:6}, f"unexpected result of dict union {c=}" + + +def test_dunion_other_types(): + def perf_test_or(other_obj): + d={1:2} + try: + d.__or__(other_obj) + except: + return True + return False + + def perf_test_ior(other_obj): + d={1:2} + try: + d.__ior__(other_obj) + except: + return True + return False + + def 
perf_test_ror(other_obj): + d={1:2} + try: + d.__ror__(other_obj) + except: + return True + return False + + test_fct={'__or__':perf_test_or, '__ror__':perf_test_ror, '__ior__':perf_test_ior} + others=['FooBar', 42, [36], set([19]), ['aa'], None] + for tfn,tf in test_fct.items(): + for other in others: + assert tf(other), f"Failed: dict {tfn}, accepted {other}" + + + + +testutils.skip_if_unsupported(3,9,test_dunion_ior0) +testutils.skip_if_unsupported(3,9,test_dunion_or0) +testutils.skip_if_unsupported(3,9,test_dunion_or1) +testutils.skip_if_unsupported(3,9,test_dunion_ror0) +testutils.skip_if_unsupported(3,9,test_dunion_other_types) + + + diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index a5330bdbcf..e2cdb108db 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -104,6 +104,17 @@ impl PyDictRef { Ok(()) } + fn merge_dict( + dict: &DictContentType, + dict_other: PyDictRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + for (key, value) in dict_other { + dict.insert(vm, &key, value)?; + } + Ok(()) + } + #[pyclassmethod] fn fromkeys( class: PyClassRef, @@ -320,6 +331,38 @@ impl PyDictRef { PyDictRef::merge(&self.entries, dict_obj, kwargs, vm) } + #[pymethod(name = "__ior__")] + fn ior(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let dicted: Result = other.clone().downcast(); + if let Ok(other) = dicted { + PyDictRef::merge_dict(&self.entries, other, vm)?; + return Ok(self.into_object()); + } + Err(vm.new_type_error("__ior__ not implemented for non-dict type".to_owned())) + } + + #[pymethod(name = "__ror__")] + fn ror(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let dicted: Result = other.clone().downcast(); + if let Ok(other) = dicted { + let other_cp = other.copy(); + PyDictRef::merge_dict(&other_cp.entries, self, vm)?; + return Ok(other_cp); + } + Err(vm.new_type_error("__ror__ not implemented for non-dict type".to_owned())) + } + + #[pymethod(name = "__or__")] + fn or(self, other: PyObjectRef, vm: 
&VirtualMachine) -> PyResult { + let dicted: Result = other.clone().downcast(); + if let Ok(other) = dicted { + let self_cp = self.copy(); + PyDictRef::merge_dict(&self_cp.entries, other, vm)?; + return Ok(self_cp); + } + Err(vm.new_type_error("__or__ not implemented for non-dict type".to_owned())) + } + #[pymethod] fn pop( self, From a447b886148174abc73a4a8f949b8231cb5932ec Mon Sep 17 00:00:00 2001 From: TheAnyKey Date: Fri, 8 May 2020 08:49:09 +0000 Subject: [PATCH 02/19] merges testutils --- tests/snippets/testutils.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/snippets/testutils.py b/tests/snippets/testutils.py index 8a9fdddb2f..c779d2c898 100644 --- a/tests/snippets/testutils.py +++ b/tests/snippets/testutils.py @@ -1,3 +1,6 @@ +import platform +import sys + def assert_raises(expected, *args, _msg=None, **kw): if args: f, f_args = args[0], args[1:] @@ -67,3 +70,26 @@ def assert_isinstance(obj, klass): def assert_in(a, b): _assert_print(lambda: a in b, [a, 'in', b]) + +def skip_if_unsupported(req_maj_vers, req_min_vers, test_fct): + def exec(): + test_fct() + + if platform.python_implementation() == 'RustPython': + exec() + elif sys.version_info.major>=req_maj_vers and sys.version_info.minor>=req_min_vers: + exec() + else: + print(f'Skipping test as a higher python version is required. Using {platform.python_implementation()} {platform.python_version()}') + +def fail_if_unsupported(req_maj_vers, req_min_vers, test_fct): + def exec(): + test_fct() + + if platform.python_implementation() == 'RustPython': + exec() + elif sys.version_info.major>=req_maj_vers and sys.version_info.minor>=req_min_vers: + exec() + else: + assert False, f'Test cannot performed on this python version. 
{platform.python_implementation()} {platform.python_version()}' + From 84a6e8e8df72a7ceeb7f45781857b70ed0a7b06e Mon Sep 17 00:00:00 2001 From: TheAnyKey Date: Fri, 8 May 2020 15:28:46 +0000 Subject: [PATCH 03/19] merged from TheAnyKey/p39_string_rem_pre_suffix --- tests/snippets/strings.py | 72 ++++++++++++++++++++++++++++++++++++++- vm/src/obj/objstr.rs | 16 +++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 7471b70050..7ceb653c1c 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -1,4 +1,4 @@ -from testutils import assert_raises, AssertRaises +from testutils import assert_raises, AssertRaises, skip_if_unsupported assert "".__eq__(1) == NotImplemented assert "a" == 'a' @@ -471,3 +471,73 @@ def try_mutate_str(): assert '{:e}'.format(float('inf')) == 'inf' assert '{:e}'.format(float('-inf')) == '-inf' assert '{:E}'.format(float('inf')) == 'INF' + + +# remove*fix test +def test_removeprefix(): + s='foobarfoo' + s_ref='foobarfoo' + assert s.removeprefix('f') == s_ref[1:] + assert s.removeprefix('fo') == s_ref[2:] + assert s.removeprefix('foo') == s_ref[3:] + + assert s.removeprefix('') == s_ref + assert s.removeprefix('bar') == s_ref + assert s.removeprefix('lol') == s_ref + assert s.removeprefix('_foo') == s_ref + assert s.removeprefix('-foo') == s_ref + assert s.removeprefix('afoo') == s_ref + assert s.removeprefix('*foo') == s_ref + + assert s==s_ref, 'undefined test fail' + +def test_removeprefix_types(): + s='0123456' + s_ref='0123456' + others=[0,['012']] + found=False + for o in others: + try: + s.removeprefix(o) + except: + found=True + + assert found, f'Removeprefix accepts other type: {type(o)}: {o=}' + +def test_removesuffix(): + s='foobarfoo' + s_ref='foobarfoo' + assert s.removesuffix('o') == s_ref[:-1] + assert s.removesuffix('oo') == s_ref[:-2] + assert s.removesuffix('foo') == s_ref[:-3] + + assert s.removesuffix('') == s_ref + assert 
s.removesuffix('bar') == s_ref + assert s.removesuffix('lol') == s_ref + assert s.removesuffix('foo_') == s_ref + assert s.removesuffix('foo-') == s_ref + assert s.removesuffix('foo*') == s_ref + assert s.removesuffix('fooa') == s_ref + + assert s==s_ref, 'undefined test fail' + +def test_removesuffix_types(): + s='0123456' + s_ref='0123456' + others=[0,6,['6']] + found=False + for o in others: + try: + s.removesuffix(o) + except: + found=True + + assert found, f'Removesuffix accepts other type: {type(o)}: {o=}' + + +skip_if_unsupported(3,9,test_removeprefix) +skip_if_unsupported(3,9,test_removeprefix_types) +skip_if_unsupported(3,9,test_removesuffix) +skip_if_unsupported(3,9,test_removesuffix_types) + + diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index eb3a434db8..d4579cf095 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -510,6 +510,22 @@ impl PyString { .to_owned() } + #[pymethod] + fn removeprefix(&self, pref: PyStringRef) -> PyResult { + if self.value.as_str().starts_with(&pref.value) { + return Ok(self.value[pref.len()..].to_string()); + } + Ok(self.value.to_string()) + } + + #[pymethod] + fn removesuffix(&self, suff: PyStringRef) -> PyResult { + if self.value.as_str().ends_with(&suff.value) { + return Ok(self.value[..self.value.len() - suff.len()].to_string()); + } + Ok(self.value.to_string()) + } + #[pymethod] fn endswith(&self, args: pystr::StartsEndsWithArgs, vm: &VirtualMachine) -> PyResult { self.value.as_str().py_startsendswith( From da09370fb6fa0043b352a14fea2d34256304d137 Mon Sep 17 00:00:00 2001 From: TheAnyKey Date: Fri, 8 May 2020 15:32:12 +0000 Subject: [PATCH 04/19] Revert "merged from TheAnyKey/p39_string_rem_pre_suffix" This reverts commit 84a6e8e8df72a7ceeb7f45781857b70ed0a7b06e. 
--- tests/snippets/strings.py | 72 +-------------------------------------- vm/src/obj/objstr.rs | 16 --------- 2 files changed, 1 insertion(+), 87 deletions(-) diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 7ceb653c1c..7471b70050 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -1,4 +1,4 @@ -from testutils import assert_raises, AssertRaises, skip_if_unsupported +from testutils import assert_raises, AssertRaises assert "".__eq__(1) == NotImplemented assert "a" == 'a' @@ -471,73 +471,3 @@ def try_mutate_str(): assert '{:e}'.format(float('inf')) == 'inf' assert '{:e}'.format(float('-inf')) == '-inf' assert '{:E}'.format(float('inf')) == 'INF' - - -# remove*fix test -def test_removeprefix(): - s='foobarfoo' - s_ref='foobarfoo' - assert s.removeprefix('f') == s_ref[1:] - assert s.removeprefix('fo') == s_ref[2:] - assert s.removeprefix('foo') == s_ref[3:] - - assert s.removeprefix('') == s_ref - assert s.removeprefix('bar') == s_ref - assert s.removeprefix('lol') == s_ref - assert s.removeprefix('_foo') == s_ref - assert s.removeprefix('-foo') == s_ref - assert s.removeprefix('afoo') == s_ref - assert s.removeprefix('*foo') == s_ref - - assert s==s_ref, 'undefined test fail' - -def test_removeprefix_types(): - s='0123456' - s_ref='0123456' - others=[0,['012']] - found=False - for o in others: - try: - s.removeprefix(o) - except: - found=True - - assert found, f'Removeprefix accepts other type: {type(o)}: {o=}' - -def test_removesuffix(): - s='foobarfoo' - s_ref='foobarfoo' - assert s.removesuffix('o') == s_ref[:-1] - assert s.removesuffix('oo') == s_ref[:-2] - assert s.removesuffix('foo') == s_ref[:-3] - - assert s.removesuffix('') == s_ref - assert s.removesuffix('bar') == s_ref - assert s.removesuffix('lol') == s_ref - assert s.removesuffix('foo_') == s_ref - assert s.removesuffix('foo-') == s_ref - assert s.removesuffix('foo*') == s_ref - assert s.removesuffix('fooa') == s_ref - - assert s==s_ref, 'undefined test 
fail' - -def test_removesuffix_types(): - s='0123456' - s_ref='0123456' - others=[0,6,['6']] - found=False - for o in others: - try: - s.removesuffix(o) - except: - found=True - - assert found, f'Removesuffix accepts other type: {type(o)}: {o=}' - - -skip_if_unsupported(3,9,test_removeprefix) -skip_if_unsupported(3,9,test_removeprefix_types) -skip_if_unsupported(3,9,test_removesuffix) -skip_if_unsupported(3,9,test_removesuffix_types) - - diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index d4579cf095..eb3a434db8 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -510,22 +510,6 @@ impl PyString { .to_owned() } - #[pymethod] - fn removeprefix(&self, pref: PyStringRef) -> PyResult { - if self.value.as_str().starts_with(&pref.value) { - return Ok(self.value[pref.len()..].to_string()); - } - Ok(self.value.to_string()) - } - - #[pymethod] - fn removesuffix(&self, suff: PyStringRef) -> PyResult { - if self.value.as_str().ends_with(&suff.value) { - return Ok(self.value[..self.value.len() - suff.len()].to_string()); - } - Ok(self.value.to_string()) - } - #[pymethod] fn endswith(&self, args: pystr::StartsEndsWithArgs, vm: &VirtualMachine) -> PyResult { self.value.as_str().py_startsendswith( From f448654229e670ea459f71775680af823a363299 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 10:45:12 +0300 Subject: [PATCH 05/19] Make PyArrayIter ThreadSafe --- vm/src/stdlib/array.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/vm/src/stdlib/array.rs b/vm/src/stdlib/array.rs index b4a437477e..9c96659c71 100644 --- a/vm/src/stdlib/array.rs +++ b/vm/src/stdlib/array.rs @@ -10,10 +10,11 @@ use crate::pyobject::{ }; use crate::VirtualMachine; -use std::cell::Cell; use std::fmt; use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +use crossbeam_utils::atomic::AtomicCell; + struct ArrayTypeSpecifierError { _priv: (), } @@ -421,7 +422,7 @@ impl PyArray { #[pymethod(name = "__iter__")] fn iter(zelf: 
PyRef) -> PyArrayIter { PyArrayIter { - position: Cell::new(0), + position: AtomicCell::new(0), array: zelf, } } @@ -430,10 +431,12 @@ impl PyArray { #[pyclass] #[derive(Debug)] pub struct PyArrayIter { - position: Cell, + position: AtomicCell, array: PyArrayRef, } +impl ThreadSafe for PyArrayIter {} + impl PyValue for PyArrayIter { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("array", "arrayiterator") @@ -444,14 +447,9 @@ impl PyValue for PyArrayIter { impl PyArrayIter { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.array.borrow_value().len() { - let ret = self - .array - .borrow_value() - .getitem_by_idx(self.position.get(), vm) - .unwrap()?; - self.position.set(self.position.get() + 1); - Ok(ret) + let pos = self.position.fetch_add(1); + if let Some(item) = self.array.borrow_value().getitem_by_idx(pos, vm) { + Ok(item?) } else { Err(objiter::new_stop_iteration(vm)) } From 94e93f72625ce7ce81f9d68e735b64d00544318c Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 11:36:44 +0300 Subject: [PATCH 06/19] Make PyDeque, PyDequeIterator ThreadSafe --- vm/src/stdlib/collections.rs | 95 ++++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 4ec3633b5c..0bfb21d2c7 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -6,23 +6,27 @@ mod _collections { use crate::obj::{objiter, objtype::PyClassRef}; use crate::pyobject::{ IdProtocol, PyArithmaticValue::*, PyClassImpl, PyComparisonValue, PyIterable, PyObjectRef, - PyRef, PyResult, PyValue, + PyRef, PyResult, PyValue, ThreadSafe, }; use crate::sequence; use crate::vm::ReprGuard; use crate::VirtualMachine; use itertools::Itertools; - use std::cell::{Cell, RefCell}; use std::collections::VecDeque; + use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + + use crossbeam_utils::atomic::AtomicCell; 
#[pyclass(name = "deque")] - #[derive(Debug, Clone)] + #[derive(Debug)] struct PyDeque { - deque: RefCell>, - maxlen: Cell>, + deque: RwLock>, + maxlen: AtomicCell>, } type PyDequeRef = PyRef; + impl ThreadSafe for PyDeque {} + impl PyValue for PyDeque { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("_collections", "deque") @@ -36,8 +40,12 @@ mod _collections { } impl PyDeque { - fn borrow_deque<'a>(&'a self) -> impl std::ops::Deref> + 'a { - self.deque.borrow() + fn borrow_deque(&self) -> RwLockReadGuard<'_, VecDeque> { + self.deque.read().unwrap() + } + + fn borrow_deque_mut(&self) -> RwLockWriteGuard<'_, VecDeque> { + self.deque.write().unwrap() } } @@ -51,8 +59,8 @@ mod _collections { vm: &VirtualMachine, ) -> PyResult> { let py_deque = PyDeque { - deque: RefCell::default(), - maxlen: maxlen.into(), + deque: RwLock::default(), + maxlen: AtomicCell::new(maxlen), }; if let OptionalArg::Present(iter) = iter { py_deque.extend(iter, vm)?; @@ -62,8 +70,8 @@ mod _collections { #[pymethod] fn append(&self, obj: PyObjectRef) { - let mut deque = self.deque.borrow_mut(); - if self.maxlen.get() == Some(deque.len()) { + let mut deque = self.borrow_deque_mut(); + if self.maxlen.load() == Some(deque.len()) { deque.pop_front(); } deque.push_back(obj); @@ -71,8 +79,8 @@ mod _collections { #[pymethod] fn appendleft(&self, obj: PyObjectRef) { - let mut deque = self.deque.borrow_mut(); - if self.maxlen.get() == Some(deque.len()) { + let mut deque = self.borrow_deque_mut(); + if self.maxlen.load() == Some(deque.len()) { deque.pop_back(); } deque.push_front(obj); @@ -80,18 +88,21 @@ mod _collections { #[pymethod] fn clear(&self) { - self.deque.borrow_mut().clear() + self.borrow_deque_mut().clear() } #[pymethod] fn copy(&self) -> Self { - self.clone() + PyDeque { + deque: RwLock::new(self.borrow_deque().clone()), + maxlen: AtomicCell::new(self.maxlen.load()), + } } #[pymethod] fn count(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut count = 0; - for elem 
in self.deque.borrow().iter() { + for elem in self.borrow_deque().iter() { if vm.identical_or_equal(elem, &obj)? { count += 1; } @@ -124,7 +135,7 @@ mod _collections { stop: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - let deque = self.deque.borrow(); + let deque = self.borrow_deque(); let start = start.unwrap_or(0); let stop = stop.unwrap_or_else(|| deque.len()); for (i, elem) in deque.iter().skip(start).take(stop - start).enumerate() { @@ -141,9 +152,9 @@ mod _collections { #[pymethod] fn insert(&self, idx: i32, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let mut deque = self.deque.borrow_mut(); + let mut deque = self.borrow_deque_mut(); - if self.maxlen.get() == Some(deque.len()) { + if self.maxlen.load() == Some(deque.len()) { return Err(vm.new_index_error("deque already at its maximum size".to_owned())); } @@ -166,23 +177,21 @@ mod _collections { #[pymethod] fn pop(&self, vm: &VirtualMachine) -> PyResult { - self.deque - .borrow_mut() + self.borrow_deque_mut() .pop_back() .ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned())) } #[pymethod] fn popleft(&self, vm: &VirtualMachine) -> PyResult { - self.deque - .borrow_mut() + self.borrow_deque_mut() .pop_front() .ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned())) } #[pymethod] fn remove(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let mut deque = self.deque.borrow_mut(); + let mut deque = self.borrow_deque_mut(); let mut idx = None; for (i, elem) in deque.iter().enumerate() { if vm.identical_or_equal(elem, &obj)? 
{ @@ -196,13 +205,13 @@ mod _collections { #[pymethod] fn reverse(&self) { - self.deque - .replace_with(|deque| deque.iter().cloned().rev().collect()); + let rev: VecDeque<_> = self.borrow_deque().iter().cloned().rev().collect(); + *self.borrow_deque_mut() = rev; } #[pymethod] fn rotate(&self, mid: OptionalArg) { - let mut deque = self.deque.borrow_mut(); + let mut deque = self.borrow_deque_mut(); let mid = mid.unwrap_or(1); if mid < 0 { deque.rotate_left(-mid as usize); @@ -213,26 +222,25 @@ mod _collections { #[pyproperty] fn maxlen(&self) -> Option { - self.maxlen.get() + self.maxlen.load() } #[pyproperty(setter)] fn set_maxlen(&self, maxlen: Option) { - self.maxlen.set(maxlen); + self.maxlen.store(maxlen); } #[pymethod(name = "__repr__")] fn repr(zelf: PyRef, vm: &VirtualMachine) -> PyResult { let repr = if let Some(_guard) = ReprGuard::enter(zelf.as_object()) { let elements = zelf - .deque - .borrow() + .borrow_deque() .iter() .map(|obj| vm.to_repr(obj)) .collect::, _>>()?; let maxlen = zelf .maxlen - .get() + .load() .map(|maxlen| format!(", maxlen={}", maxlen)) .unwrap_or_default(); format!("deque([{}]{})", elements.into_iter().format(", "), maxlen) @@ -336,29 +344,29 @@ mod _collections { #[pymethod(name = "__mul__")] fn mul(&self, n: isize) -> Self { - let deque: &VecDeque<_> = &self.deque.borrow(); + let deque: &VecDeque<_> = &self.borrow_deque(); let mul = sequence::seq_mul(deque, n); - let skipped = if let Some(maxlen) = self.maxlen.get() { + let skipped = if let Some(maxlen) = self.maxlen.load() { mul.len() - maxlen } else { 0 }; let deque = mul.skip(skipped).cloned().collect(); PyDeque { - deque: RefCell::new(deque), - maxlen: self.maxlen.clone(), + deque: RwLock::new(deque), + maxlen: AtomicCell::new(self.maxlen.load()), } } #[pymethod(name = "__len__")] fn len(&self) -> usize { - self.deque.borrow().len() + self.borrow_deque().len() } #[pymethod(name = "__iter__")] fn iter(zelf: PyRef) -> PyDequeIterator { PyDequeIterator { - position: Cell::new(0), 
+ position: AtomicCell::new(0), deque: zelf, } } @@ -367,10 +375,12 @@ mod _collections { #[pyclass(name = "_deque_iterator")] #[derive(Debug)] struct PyDequeIterator { - position: Cell, + position: AtomicCell, deque: PyDequeRef, } + impl ThreadSafe for PyDequeIterator {} + impl PyValue for PyDequeIterator { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("_collections", "_deque_iterator") @@ -381,9 +391,10 @@ mod _collections { impl PyDequeIterator { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.deque.deque.borrow().len() { - let ret = self.deque.deque.borrow()[self.position.get()].clone(); - self.position.set(self.position.get() + 1); + let pos = self.position.fetch_add(1); + let deque = self.deque.borrow_deque(); + if pos < deque.len() { + let ret = deque[pos].clone(); Ok(ret) } else { Err(objiter::new_stop_iteration(vm)) From 1b585bda61c97673ff87cff510f758ddf303c960 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 11:41:55 +0300 Subject: [PATCH 07/19] Make Reader ThreadSafe --- vm/src/stdlib/csv.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vm/src/stdlib/csv.rs b/vm/src/stdlib/csv.rs index fc42643a2a..b2ff582268 100644 --- a/vm/src/stdlib/csv.rs +++ b/vm/src/stdlib/csv.rs @@ -1,5 +1,5 @@ -use std::cell::RefCell; use std::fmt::{self, Debug, Formatter}; +use std::sync::RwLock; use csv as rust_csv; use itertools::join; @@ -10,7 +10,7 @@ use crate::obj::objiter; use crate::obj::objstr::{self, PyString}; use crate::obj::objtype::PyClassRef; use crate::pyobject::{IntoPyObject, TryFromObject, TypeProtocol}; -use crate::pyobject::{PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; use crate::types::create_type; use crate::VirtualMachine; @@ -126,9 +126,11 @@ impl ReadState { #[pyclass(name = "Reader")] struct Reader { - state: 
RefCell, + state: RwLock, } +impl ThreadSafe for Reader {} + impl Debug for Reader { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "_csv.reader") @@ -143,7 +145,7 @@ impl PyValue for Reader { impl Reader { fn new(iter: PyIterable, config: ReaderOption) -> Self { - let state = RefCell::new(ReadState::new(iter, config)); + let state = RwLock::new(ReadState::new(iter, config)); Reader { state } } } @@ -152,13 +154,13 @@ impl Reader { impl Reader { #[pymethod(name = "__iter__")] fn iter(this: PyRef, vm: &VirtualMachine) -> PyResult { - this.state.borrow_mut().cast_to_reader(vm)?; + this.state.write().unwrap().cast_to_reader(vm)?; this.into_pyobject(vm) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let mut state = self.state.borrow_mut(); + let mut state = self.state.write().unwrap(); state.cast_to_reader(vm)?; if let ReadState::CsvIter(ref mut reader) = &mut *state { From 75af7f6b1c9f4f1fc99e1ac97a4cb8e72d973d1a Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 12:07:35 +0300 Subject: [PATCH 08/19] Make PyBytesIO, PyStringIO ThreadSafe --- vm/src/stdlib/io.rs | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 3b2760b92d..0a0736802e 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -1,10 +1,11 @@ /* * I/O core tools. 
*/ -use std::cell::{RefCell, RefMut}; use std::fs; use std::io::{self, prelude::*, Cursor, SeekFrom}; +use std::sync::{RwLock, RwLockWriteGuard}; +use crossbeam_utils::atomic::AtomicCell; use num_traits::ToPrimitive; use crate::exceptions::PyBaseExceptionRef; @@ -18,7 +19,7 @@ use crate::obj::objiter; use crate::obj::objstr::{self, PyStringRef}; use crate::obj::objtype::{self, PyClassRef}; use crate::pyobject::{ - BufferProtocol, Either, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + BufferProtocol, Either, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, TryFromObject, }; use crate::vm::VirtualMachine; @@ -120,11 +121,14 @@ impl BufferedIO { #[derive(Debug)] struct PyStringIO { - buffer: RefCell>, + buffer: RwLock, + closed: AtomicCell, } type PyStringIORef = PyRef; +impl ThreadSafe for PyStringIO {} + impl PyValue for PyStringIO { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("io", "StringIO") @@ -132,10 +136,9 @@ impl PyValue for PyStringIO { } impl PyStringIORef { - fn buffer(&self, vm: &VirtualMachine) -> PyResult> { - let buffer = self.buffer.borrow_mut(); - if buffer.is_some() { - Ok(RefMut::map(buffer, |opt| opt.as_mut().unwrap())) + fn buffer(&self, vm: &VirtualMachine) -> PyResult> { + if !self.closed.load() { + Ok(self.buffer.write().unwrap()) } else { Err(vm.new_value_error("I/O operation on closed file.".to_owned())) } @@ -209,11 +212,11 @@ impl PyStringIORef { } fn closed(self) -> bool { - self.buffer.borrow().is_none() + self.closed.load() } fn close(self) { - self.buffer.replace(None); + self.closed.store(true); } } @@ -235,18 +238,22 @@ fn string_io_new( let input = flatten.map_or_else(Vec::new, |v| objstr::borrow_value(&v).as_bytes().to_vec()); PyStringIO { - buffer: RefCell::new(Some(BufferedIO::new(Cursor::new(input)))), + buffer: RwLock::new(BufferedIO::new(Cursor::new(input))), + closed: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } #[derive(Debug)] struct PyBytesIO { - buffer: RefCell>, + buffer: RwLock, + 
closed: AtomicCell, } type PyBytesIORef = PyRef; +impl ThreadSafe for PyBytesIO {} + impl PyValue for PyBytesIO { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("io", "BytesIO") @@ -254,10 +261,9 @@ impl PyValue for PyBytesIO { } impl PyBytesIORef { - fn buffer(&self, vm: &VirtualMachine) -> PyResult> { - let buffer = self.buffer.borrow_mut(); - if buffer.is_some() { - Ok(RefMut::map(buffer, |opt| opt.as_mut().unwrap())) + fn buffer(&self, vm: &VirtualMachine) -> PyResult> { + if !self.closed.load() { + Ok(self.buffer.write().unwrap()) } else { Err(vm.new_value_error("I/O operation on closed file.".to_owned())) } @@ -320,11 +326,11 @@ impl PyBytesIORef { } fn closed(self) -> bool { - self.buffer.borrow().is_none() + self.closed.load() } fn close(self) { - self.buffer.replace(None); + self.closed.store(true) } } @@ -339,7 +345,8 @@ fn bytes_io_new( }; PyBytesIO { - buffer: RefCell::new(Some(BufferedIO::new(Cursor::new(raw_bytes)))), + buffer: RwLock::new(BufferedIO::new(Cursor::new(raw_bytes))), + closed: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } From 4cf5178d43647b2528e25a7e9cb95e989ab0012a Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 17:12:27 +0300 Subject: [PATCH 09/19] Make PyItertools* ThreadSafe --- vm/src/stdlib/itertools.rs | 350 +++++++++++++++++++++---------------- 1 file changed, 196 insertions(+), 154 deletions(-) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index ea7767fb20..03573f1326 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -2,11 +2,11 @@ pub(crate) use decl::make_module; #[pymodule(name = "itertools")] mod decl { + use crossbeam_utils::atomic::AtomicCell; use num_bigint::BigInt; use num_traits::{One, Signed, ToPrimitive, Zero}; - use std::cell::{Cell, RefCell}; use std::iter; - use std::rc::Rc; + use std::sync::{Arc, RwLock, RwLockWriteGuard}; use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs}; use crate::obj::objbool; @@ -15,7 
+15,8 @@ mod decl { use crate::obj::objtuple::PyTuple; use crate::obj::objtype::{self, PyClassRef}; use crate::pyobject::{ - IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, + IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, + TypeProtocol, }; use crate::vm::VirtualMachine; @@ -23,9 +24,12 @@ mod decl { #[derive(Debug)] struct PyItertoolsChain { iterables: Vec, - cur: RefCell<(usize, Option)>, + cur_idx: AtomicCell, + cached_iter: RwLock>, } + impl ThreadSafe for PyItertoolsChain {} + impl PyValue for PyItertoolsChain { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "chain") @@ -38,27 +42,40 @@ mod decl { fn tp_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult> { PyItertoolsChain { iterables: args.args, - cur: RefCell::new((0, None)), + cur_idx: AtomicCell::new(0), + cached_iter: RwLock::new(None), } .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let (ref mut cur_idx, ref mut cur_iter) = *self.cur.borrow_mut(); - while *cur_idx < self.iterables.len() { - if cur_iter.is_none() { - *cur_iter = Some(get_iter(vm, &self.iterables[*cur_idx])?); + loop { + let pos = self.cur_idx.load(); + if pos >= self.iterables.len() { + break; } + let cur_iter = if self.cached_iter.read().unwrap().is_none() { + // We need to call "get_iter" outside of the lock. + let iter = get_iter(vm, &self.iterables[pos])?; + *self.cached_iter.write().unwrap() = Some(iter.clone()); + iter + } else { + if let Some(cached_iter) = (*(self.cached_iter.read().unwrap())).clone() { + cached_iter + } else { + // Someone changed cached iter to None since we checked. + continue; + } + }; - // can't be directly inside the 'match' clause, otherwise the borrows collide. - let obj = call_next(vm, cur_iter.as_ref().unwrap()); - match obj { + // We need to call "call_next" outside of the lock. 
+ match call_next(vm, &cur_iter) { Ok(ok) => return Ok(ok), Err(err) => { if objtype::isinstance(&err, &vm.ctx.exceptions.stop_iteration) { - *cur_idx += 1; - *cur_iter = None; + self.cur_idx.fetch_add(1); + *self.cached_iter.write().unwrap() = None; } else { return Err(err); } @@ -85,7 +102,8 @@ mod decl { PyItertoolsChain { iterables, - cur: RefCell::new((0, None)), + cur_idx: AtomicCell::new(0), + cached_iter: RwLock::new(None), } .into_ref_with_type(vm, cls) } @@ -98,6 +116,8 @@ mod decl { selector: PyObjectRef, } + impl ThreadSafe for PyItertoolsCompress {} + impl PyValue for PyItertoolsCompress { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "compress") @@ -145,10 +165,12 @@ mod decl { #[pyclass(name = "count")] #[derive(Debug)] struct PyItertoolsCount { - cur: RefCell, + cur: RwLock, step: BigInt, } + impl ThreadSafe for PyItertoolsCount {} + impl PyValue for PyItertoolsCount { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "count") @@ -174,7 +196,7 @@ mod decl { }; PyItertoolsCount { - cur: RefCell::new(start), + cur: RwLock::new(start), step, } .into_ref_with_type(vm, cls) @@ -182,8 +204,9 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self) -> PyResult { - let result = self.cur.borrow().clone(); - *self.cur.borrow_mut() += &self.step; + let mut cur = self.cur.write().unwrap(); + let result = cur.clone(); + *cur += &self.step; Ok(PyInt::new(result)) } @@ -196,12 +219,13 @@ mod decl { #[pyclass(name = "cycle")] #[derive(Debug)] struct PyItertoolsCycle { - iter: RefCell, - saved: RefCell>, - index: Cell, - first_pass: Cell, + iter: PyObjectRef, + saved: RwLock>, + index: AtomicCell, } + impl ThreadSafe for PyItertoolsCycle {} + impl PyValue for PyItertoolsCycle { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "cycle") @@ -219,36 +243,31 @@ mod decl { let iter = get_iter(vm, &iterable)?; PyItertoolsCycle { - iter: RefCell::new(iter.clone()), - saved: RefCell::new(Vec::new()), - index: 
Cell::new(0), - first_pass: Cell::new(false), + iter: iter.clone(), + saved: RwLock::new(Vec::new()), + index: AtomicCell::new(0), } .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let item = if let Some(item) = get_next_object(vm, &self.iter.borrow())? { - if self.first_pass.get() { - return Ok(item); - } - - self.saved.borrow_mut().push(item.clone()); + let item = if let Some(item) = get_next_object(vm, &self.iter)? { + self.saved.write().unwrap().push(item.clone()); item } else { - if self.saved.borrow().len() == 0 { + let saved = self.saved.read().unwrap(); + if saved.len() == 0 { return Err(new_stop_iteration(vm)); } - let last_index = self.index.get(); - self.index.set(self.index.get() + 1); + let last_index = self.index.fetch_add(1); - if self.index.get() >= self.saved.borrow().len() { - self.index.set(0); + if last_index >= saved.len() - 1 { + self.index.store(0); } - self.saved.borrow()[last_index].clone() + saved[last_index].clone() }; Ok(item) @@ -264,9 +283,11 @@ mod decl { #[derive(Debug)] struct PyItertoolsRepeat { object: PyObjectRef, - times: Option>, + times: Option>, } + impl ThreadSafe for PyItertoolsRepeat {} + impl PyValue for PyItertoolsRepeat { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "repeat") @@ -283,7 +304,7 @@ mod decl { vm: &VirtualMachine, ) -> PyResult> { let times = match times.into_option() { - Some(int) => Some(RefCell::new(int.as_bigint().clone())), + Some(int) => Some(RwLock::new(int.as_bigint().clone())), None => None, }; @@ -297,10 +318,11 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { if let Some(ref times) = self.times { - if *times.borrow() <= BigInt::zero() { + let mut times = times.write().unwrap(); + if *times <= BigInt::zero() { return Err(new_stop_iteration(vm)); } - *times.borrow_mut() -= 1; + *times -= 1; } Ok(self.object.clone()) @@ -314,7 +336,7 @@ mod decl { #[pymethod(name = 
"__length_hint__")] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { match self.times { - Some(ref times) => vm.new_int(times.borrow().clone()), + Some(ref times) => vm.new_int(times.read().unwrap().clone()), None => vm.new_int(0), } } @@ -327,6 +349,8 @@ mod decl { iter: PyObjectRef, } + impl ThreadSafe for PyItertoolsStarmap {} + impl PyValue for PyItertoolsStarmap { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "starmap") @@ -366,9 +390,11 @@ mod decl { struct PyItertoolsTakewhile { predicate: PyObjectRef, iterable: PyObjectRef, - stop_flag: RefCell, + stop_flag: AtomicCell, } + impl ThreadSafe for PyItertoolsTakewhile {} + impl PyValue for PyItertoolsTakewhile { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "takewhile") @@ -389,14 +415,14 @@ mod decl { PyItertoolsTakewhile { predicate, iterable: iter, - stop_flag: RefCell::new(false), + stop_flag: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if *self.stop_flag.borrow() { + if self.stop_flag.load() { return Err(new_stop_iteration(vm)); } @@ -409,7 +435,7 @@ mod decl { if verdict { Ok(obj) } else { - *self.stop_flag.borrow_mut() = true; + self.stop_flag.store(true); Err(new_stop_iteration(vm)) } } @@ -425,9 +451,11 @@ mod decl { struct PyItertoolsDropwhile { predicate: PyCallable, iterable: PyObjectRef, - start_flag: Cell, + start_flag: AtomicCell, } + impl ThreadSafe for PyItertoolsDropwhile {} + impl PyValue for PyItertoolsDropwhile { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "dropwhile") @@ -448,7 +476,7 @@ mod decl { PyItertoolsDropwhile { predicate, iterable: iter, - start_flag: Cell::new(false), + start_flag: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } @@ -458,13 +486,13 @@ mod decl { let predicate = &self.predicate; let iterable = &self.iterable; - if !self.start_flag.get() { + if !self.start_flag.load() { loop { let obj 
= call_next(vm, iterable)?; let pred = predicate.clone(); let pred_value = vm.invoke(&pred.into_object(), vec![obj.clone()])?; if !objbool::boolval(vm, pred_value)? { - self.start_flag.set(true); + self.start_flag.store(true); return Ok(obj); } } @@ -482,12 +510,14 @@ mod decl { #[derive(Debug)] struct PyItertoolsIslice { iterable: PyObjectRef, - cur: RefCell, - next: RefCell, + cur: AtomicCell, + next: AtomicCell, stop: Option, step: usize, } + impl ThreadSafe for PyItertoolsIslice {} + impl PyValue for PyItertoolsIslice { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "islice") @@ -567,8 +597,8 @@ mod decl { PyItertoolsIslice { iterable: iter, - cur: RefCell::new(0), - next: RefCell::new(start), + cur: AtomicCell::new(0), + next: AtomicCell::new(start), stop, step, } @@ -577,23 +607,23 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - while *self.cur.borrow() < *self.next.borrow() { + while self.cur.load() < self.next.load() { call_next(vm, &self.iterable)?; - *self.cur.borrow_mut() += 1; + self.cur.fetch_add(1); } if let Some(stop) = self.stop { - if *self.cur.borrow() >= stop { + if self.cur.load() >= stop { return Err(new_stop_iteration(vm)); } } let obj = call_next(vm, &self.iterable)?; - *self.cur.borrow_mut() += 1; + self.cur.fetch_add(1); // TODO is this overflow check required? attempts to copy CPython. 
- let (next, ovf) = (*self.next.borrow()).overflowing_add(self.step); - *self.next.borrow_mut() = if ovf { self.stop.unwrap() } else { next }; + let (next, ovf) = self.next.load().overflowing_add(self.step); + self.next.store(if ovf { self.stop.unwrap() } else { next }); Ok(obj) } @@ -611,6 +641,8 @@ mod decl { iterable: PyObjectRef, } + impl ThreadSafe for PyItertoolsFilterFalse {} + impl PyValue for PyItertoolsFilterFalse { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "filterfalse") @@ -665,9 +697,11 @@ mod decl { struct PyItertoolsAccumulate { iterable: PyObjectRef, binop: PyObjectRef, - acc_value: RefCell>, + acc_value: RwLock>, } + impl ThreadSafe for PyItertoolsAccumulate {} + impl PyValue for PyItertoolsAccumulate { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "accumulate") @@ -688,7 +722,7 @@ mod decl { PyItertoolsAccumulate { iterable: iter, binop: binop.unwrap_or_else(|| vm.get_none()), - acc_value: RefCell::from(Option::None), + acc_value: RwLock::new(None), } .into_ref_with_type(vm, cls) } @@ -698,7 +732,9 @@ mod decl { let iterable = &self.iterable; let obj = call_next(vm, iterable)?; - let next_acc_value = match &*self.acc_value.borrow() { + let acc_value = self.acc_value.read().unwrap().clone(); + + let next_acc_value = match acc_value { None => obj.clone(), Some(value) => { if self.binop.is(&vm.get_none()) { @@ -708,7 +744,7 @@ mod decl { } } }; - self.acc_value.replace(Option::from(next_acc_value.clone())); + *self.acc_value.write().unwrap() = Some(next_acc_value.clone()); Ok(next_acc_value) } @@ -722,33 +758,37 @@ mod decl { #[derive(Debug)] struct PyItertoolsTeeData { iterable: PyObjectRef, - values: RefCell>, + values: RwLock>, } + impl ThreadSafe for PyItertoolsTeeData {} + impl PyItertoolsTeeData { - fn new(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - Ok(Rc::new(PyItertoolsTeeData { + fn new(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult> { + Ok(Arc::new(PyItertoolsTeeData 
{ iterable: get_iter(vm, &iterable)?, - values: RefCell::new(vec![]), + values: RwLock::new(vec![]), })) } fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { - if self.values.borrow().len() == index { + if self.values.read().unwrap().len() == index { let result = call_next(vm, &self.iterable)?; - self.values.borrow_mut().push(result); + self.values.write().unwrap().push(result); } - Ok(self.values.borrow()[index].clone()) + Ok(self.values.read().unwrap()[index].clone()) } } #[pyclass(name = "tee")] #[derive(Debug)] struct PyItertoolsTee { - tee_data: Rc, - index: Cell, + tee_data: Arc, + index: AtomicCell, } + impl ThreadSafe for PyItertoolsTee {} + impl PyValue for PyItertoolsTee { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "tee") @@ -764,7 +804,7 @@ mod decl { } Ok(PyItertoolsTee { tee_data: PyItertoolsTeeData::new(it, vm)?, - index: Cell::from(0), + index: AtomicCell::new(0), } .into_ref_with_type(vm, PyItertoolsTee::class(vm))? .into_object()) @@ -800,8 +840,8 @@ mod decl { #[pymethod(name = "__copy__")] fn copy(&self, vm: &VirtualMachine) -> PyResult { Ok(PyItertoolsTee { - tee_data: Rc::clone(&self.tee_data), - index: self.index.clone(), + tee_data: Arc::clone(&self.tee_data), + index: AtomicCell::new(self.index.load()), } .into_ref_with_type(vm, Self::class(vm))? 
.into_object()) @@ -809,8 +849,8 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let value = self.tee_data.get_item(vm, self.index.get())?; - self.index.set(self.index.get() + 1); + let value = self.tee_data.get_item(vm, self.index.load())?; + self.index.fetch_add(1); Ok(value) } @@ -824,11 +864,13 @@ mod decl { #[derive(Debug)] struct PyItertoolsProduct { pools: Vec>, - idxs: RefCell>, - cur: Cell, - stop: Cell, + idxs: RwLock>, + cur: AtomicCell, + stop: AtomicCell, } + impl ThreadSafe for PyItertoolsProduct {} + impl PyValue for PyItertoolsProduct { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "product") @@ -871,9 +913,9 @@ mod decl { PyItertoolsProduct { pools, - idxs: RefCell::new(vec![0; l]), - cur: Cell::new(l - 1), - stop: Cell::new(false), + idxs: RwLock::new(vec![0; l]), + cur: AtomicCell::new(l - 1), + stop: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } @@ -881,7 +923,7 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.stop.get() { + if self.stop.load() { return Err(new_stop_iteration(vm)); } @@ -893,41 +935,36 @@ mod decl { } } + let idxs = self.idxs.write().unwrap(); + let res = PyTuple::from( pools .iter() - .zip(self.idxs.borrow().iter()) + .zip(idxs.iter()) .map(|(pool, idx)| pool[*idx].clone()) .collect::>(), ); - self.update_idxs(); - - if self.is_end() { - self.stop.set(true); - } + self.update_idxs(idxs); Ok(res.into_ref(vm).into_object()) } - fn is_end(&self) -> bool { - let cur = self.cur.get(); - self.idxs.borrow()[cur] == self.pools[cur].len() - 1 && cur == 0 - } - - fn update_idxs(&self) { - let lst_idx = &self.pools[self.cur.get()].len() - 1; + fn update_idxs(&self, mut idxs: RwLockWriteGuard<'_, Vec>) { + let cur = self.cur.load(); + let lst_idx = &self.pools[cur].len() - 1; - if self.idxs.borrow()[self.cur.get()] == lst_idx { - if self.is_end() { + if idxs[cur] == lst_idx { + if cur == 0 
{ + self.stop.store(true); return; } - self.idxs.borrow_mut()[self.cur.get()] = 0; - self.cur.set(self.cur.get() - 1); - self.update_idxs(); + idxs[cur] = 0; + self.cur.fetch_sub(1); + self.update_idxs(idxs); } else { - self.idxs.borrow_mut()[self.cur.get()] += 1; - self.cur.set(self.idxs.borrow().len() - 1); + idxs[cur] += 1; + self.cur.store(idxs.len() - 1); } } @@ -941,11 +978,13 @@ mod decl { #[derive(Debug)] struct PyItertoolsCombinations { pool: Vec, - indices: RefCell>, - r: Cell, - exhausted: Cell, + indices: RwLock>, + r: AtomicCell, + exhausted: AtomicCell, } + impl ThreadSafe for PyItertoolsCombinations {} + impl PyValue for PyItertoolsCombinations { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "combinations") @@ -974,9 +1013,9 @@ mod decl { PyItertoolsCombinations { pool, - indices: RefCell::new((0..r).collect()), - r: Cell::new(r), - exhausted: Cell::new(r > n), + indices: RwLock::new((0..r).collect()), + r: AtomicCell::new(r), + exhausted: AtomicCell::new(r > n), } .into_ref_with_type(vm, cls) } @@ -989,27 +1028,28 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.exhausted.get() { + if self.exhausted.load() { return Err(new_stop_iteration(vm)); } let n = self.pool.len(); - let r = self.r.get(); + let r = self.r.load(); if r == 0 { - self.exhausted.set(true); + self.exhausted.store(true); return Ok(vm.ctx.new_tuple(vec![])); } let res = PyTuple::from( self.indices - .borrow() + .read() + .unwrap() .iter() .map(|&i| self.pool[i].clone()) .collect::>(), ); - let mut indices = self.indices.borrow_mut(); + let mut indices = self.indices.write().unwrap(); // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). let mut idx = r as isize - 1; @@ -1020,7 +1060,7 @@ mod decl { // If no suitable index is found, then the indices are all at // their maximum value and we're done. 
if idx < 0 { - self.exhausted.set(true); + self.exhausted.store(true); } else { // Increment the current index which we know is not at its // maximum. Then move back to the right setting each index @@ -1040,11 +1080,13 @@ mod decl { #[derive(Debug)] struct PyItertoolsCombinationsWithReplacement { pool: Vec, - indices: RefCell>, - r: Cell, - exhausted: Cell, + indices: RwLock>, + r: AtomicCell, + exhausted: AtomicCell, } + impl ThreadSafe for PyItertoolsCombinationsWithReplacement {} + impl PyValue for PyItertoolsCombinationsWithReplacement { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "combinations_with_replacement") @@ -1073,9 +1115,9 @@ mod decl { PyItertoolsCombinationsWithReplacement { pool, - indices: RefCell::new(vec![0; r]), - r: Cell::new(r), - exhausted: Cell::new(n == 0 && r > 0), + indices: RwLock::new(vec![0; r]), + r: AtomicCell::new(r), + exhausted: AtomicCell::new(n == 0 && r > 0), } .into_ref_with_type(vm, cls) } @@ -1088,19 +1130,19 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.exhausted.get() { + if self.exhausted.load() { return Err(new_stop_iteration(vm)); } let n = self.pool.len(); - let r = self.r.get(); + let r = self.r.load(); if r == 0 { - self.exhausted.set(true); + self.exhausted.store(true); return Ok(vm.ctx.new_tuple(vec![])); } - let mut indices = self.indices.borrow_mut(); + let mut indices = self.indices.write().unwrap(); let res = vm .ctx @@ -1115,7 +1157,7 @@ mod decl { // If no suitable index is found, then the indices are all at // their maximum value and we're done. 
if idx < 0 { - self.exhausted.set(true); + self.exhausted.store(true); } else { let index = indices[idx as usize] + 1; @@ -1133,14 +1175,16 @@ mod decl { #[pyclass(name = "permutations")] #[derive(Debug)] struct PyItertoolsPermutations { - pool: Vec, // Collected input iterable - indices: RefCell>, // One index per element in pool - cycles: RefCell>, // One rollover counter per element in the result - result: RefCell>>, // Indexes of the most recently returned result - r: Cell, // Size of result tuple - exhausted: Cell, // Set when the iterator is exhausted + pool: Vec, // Collected input iterable + indices: RwLock>, // One index per element in pool + cycles: RwLock>, // One rollover counter per element in the result + result: RwLock>>, // Indexes of the most recently returned result + r: AtomicCell, // Size of result tuple + exhausted: AtomicCell, // Set when the iterator is exhausted } + impl ThreadSafe for PyItertoolsPermutations {} + impl PyValue for PyItertoolsPermutations { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "permutations") @@ -1179,11 +1223,11 @@ mod decl { PyItertoolsPermutations { pool, - indices: RefCell::new((0..n).collect()), - cycles: RefCell::new((0..r).map(|i| n - i).collect()), - result: RefCell::new(None), - r: Cell::new(r), - exhausted: Cell::new(r > n), + indices: RwLock::new((0..n).collect()), + cycles: RwLock::new((0..r).map(|i| n - i).collect()), + result: RwLock::new(None), + r: AtomicCell::new(r), + exhausted: AtomicCell::new(r > n), } .into_ref_with_type(vm, cls) } @@ -1196,23 +1240,23 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.exhausted.get() { + if self.exhausted.load() { return Err(new_stop_iteration(vm)); } let n = self.pool.len(); - let r = self.r.get(); + let r = self.r.load(); if n == 0 { - self.exhausted.set(true); + self.exhausted.store(true); return Ok(vm.ctx.new_tuple(vec![])); } - let result = &mut 
*self.result.borrow_mut(); + let mut result = self.result.write().unwrap(); - if let Some(ref mut result) = result { - let mut indices = self.indices.borrow_mut(); - let mut cycles = self.cycles.borrow_mut(); + if let Some(ref mut result) = *result { + let mut indices = self.indices.write().unwrap(); + let mut cycles = self.cycles.write().unwrap(); let mut sentinel = false; // Decrement rightmost cycle, moving leftward upon zero rollover @@ -1241,7 +1285,7 @@ mod decl { } } if !sentinel { - self.exhausted.set(true); + self.exhausted.store(true); return Err(new_stop_iteration(vm)); } } else { @@ -1265,9 +1309,10 @@ mod decl { struct PyItertoolsZipLongest { iterators: Vec, fillvalue: PyObjectRef, - numactive: Cell, } + impl ThreadSafe for PyItertoolsZipLongest {} + impl PyValue for PyItertoolsZipLongest { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "zip_longest") @@ -1299,12 +1344,9 @@ mod decl { .map(|iterable| get_iter(vm, &iterable)) .collect::, _>>()?; - let numactive = Cell::new(iterators.len()); - PyItertoolsZipLongest { iterators, fillvalue, - numactive, } .into_ref_with_type(vm, cls) } @@ -1315,7 +1357,7 @@ mod decl { Err(new_stop_iteration(vm)) } else { let mut result: Vec = Vec::new(); - let mut numactive = self.numactive.get(); + let mut numactive = self.iterators.len(); for idx in 0..self.iterators.len() { let next_obj = match call_next(vm, &self.iterators[idx]) { From 8466f45f2a255d8285224615af4cac3cd793010c Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 23:23:18 +0300 Subject: [PATCH 10/19] Fix clippy error --- vm/src/stdlib/itertools.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 03573f1326..742b2b9384 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -60,13 +60,11 @@ mod decl { let iter = get_iter(vm, &self.iterables[pos])?; *self.cached_iter.write().unwrap() = Some(iter.clone()); iter + } 
else if let Some(cached_iter) = (*(self.cached_iter.read().unwrap())).clone() { + cached_iter } else { - if let Some(cached_iter) = (*(self.cached_iter.read().unwrap())).clone() { - cached_iter - } else { - // Someone changed cached iter to None since we checked. - continue; - } + // Someone changed cached iter to None since we checked. + continue; }; // We need to call "call_next" outside of the lock. From 25913a613f8507c4e3a23bef72c0f2f8d044d743 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 23:29:20 +0300 Subject: [PATCH 11/19] Remove expected failure from test_exhausted_iterator --- Lib/test/test_array.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py index 6bdbfe9f0a..7cca83d783 100644 --- a/Lib/test/test_array.py +++ b/Lib/test/test_array.py @@ -341,8 +341,6 @@ def test_iterator_pickle(self): a.fromlist(data2) self.assertEqual(list(it), []) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_exhausted_iterator(self): a = array.array(self.typecode, self.example) self.assertEqual(list(a), list(self.example)) From 73d35bcdbaf6c7117b69b6656f69a0090356e158 Mon Sep 17 00:00:00 2001 From: mrmiywj Date: Sun, 10 May 2020 13:44:05 -0700 Subject: [PATCH 12/19] add support for getresuid --- vm/src/stdlib/os.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index bd2a4e6536..583d2b3503 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -1457,6 +1457,25 @@ fn os_sync(_vm: &VirtualMachine) -> PyResult<()> { Ok(()) } +// cfg from nix +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "openbsd" +))] +fn os_getresuid(vm: &VirtualMachine) -> PyResult<(u32, u32, u32)> { + let mut ruid = 0; + let mut euid = 0; + let mut suid = 0; + let ret = unsafe { libc::getresuid(&mut ruid, &mut euid, &mut suid) }; + if ret == 0 { + Ok((ruid, euid, suid)) + } else { + Err(errno_err(vm)) + } +} 
+ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -1688,6 +1707,7 @@ fn extend_module_platform_specific(vm: &VirtualMachine, module: &PyObjectRef) { ))] extend_module!(vm, module, { "setresuid" => ctx.new_function(os_setresuid), + "getresuid" => ctx.new_function(os_getresuid), }); // cfg taken from nix From 93526950a13f50b7352a3d8e488b02e7612c2976 Mon Sep 17 00:00:00 2001 From: mrmiywj Date: Sun, 10 May 2020 16:04:22 -0700 Subject: [PATCH 13/19] add_getresgid_support --- vm/src/stdlib/os.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 583d2b3503..99a926f587 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -1476,6 +1476,25 @@ fn os_getresuid(vm: &VirtualMachine) -> PyResult<(u32, u32, u32)> { } } +// cfg from nix +#[cfg(any( + target_os = "android", + target_os = "freebsd", + target_os = "linux", + target_os = "openbsd" +))] +fn os_getresgid(vm: &VirtualMachine) -> PyResult<(u32, u32, u32)> { + let mut rgid = 0; + let mut egid = 0; + let mut sgid = 0; + let ret = unsafe { libc::getresgid(&mut rgid, &mut egid, &mut sgid) }; + if ret == 0 { + Ok((rgid, egid, sgid)) + } else { + Err(errno_err(vm)) + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -1708,6 +1727,7 @@ fn extend_module_platform_specific(vm: &VirtualMachine, module: &PyObjectRef) { extend_module!(vm, module, { "setresuid" => ctx.new_function(os_setresuid), "getresuid" => ctx.new_function(os_getresuid), + "getresgid" => ctx.new_function(os_getresgid), }); // cfg taken from nix From 86dd0d6bfdb89ab5e1c9876db3293080ff8ebaea Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 13 May 2020 03:07:51 +0900 Subject: [PATCH 14/19] suppress new version of flake8 lint error for test --- tests/snippets/tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/snippets/tuple.py b/tests/snippets/tuple.py index 0f4306fa61..fd59e90609 100644 
--- a/tests/snippets/tuple.py +++ b/tests/snippets/tuple.py @@ -44,7 +44,7 @@ def __eq__(self, x): b = (55, *a) assert b == (55, 1, 2, 3, 1) -assert () is () +assert () is () # noqa a = () b = () From 174a21cba8ffc4fc8068e90c9fdc3a23a13d6afc Mon Sep 17 00:00:00 2001 From: TheAnyKey Date: Thu, 14 May 2020 19:06:00 +0000 Subject: [PATCH 15/19] implement and test Py39 string operations removeprefix and removesuffix. Added test snippets for it using an also contained extension of testutils --- tests/snippets/strings.py | 69 ++++++++++++++++++++++++++++++++++++- tests/snippets/testutils.py | 1 - vm/src/obj/objstr.rs | 16 +++++++++ 3 files changed, 84 insertions(+), 2 deletions(-) diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 7471b70050..6bab4a0863 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -1,4 +1,4 @@ -from testutils import assert_raises, AssertRaises +from testutils import assert_raises, AssertRaises, skip_if_unsupported assert "".__eq__(1) == NotImplemented assert "a" == 'a' @@ -471,3 +471,70 @@ def try_mutate_str(): assert '{:e}'.format(float('inf')) == 'inf' assert '{:e}'.format(float('-inf')) == '-inf' assert '{:E}'.format(float('inf')) == 'INF' + + +# remove*fix test +def test_removeprefix(): + s='foobarfoo' + s_ref='foobarfoo' + assert s.removeprefix('f') == s_ref[1:] + assert s.removeprefix('fo') == s_ref[2:] + assert s.removeprefix('foo') == s_ref[3:] + + assert s.removeprefix('') == s_ref + assert s.removeprefix('bar') == s_ref + assert s.removeprefix('lol') == s_ref + assert s.removeprefix('_foo') == s_ref + assert s.removeprefix('-foo') == s_ref + assert s.removeprefix('afoo') == s_ref + assert s.removeprefix('*foo') == s_ref + + assert s==s_ref, 'undefined test fail' + +def test_removeprefix_types(): + s='0123456' + s_ref='0123456' + others=[0,['012']] + found=False + for o in others: + try: + s.removeprefix(o) + except: + found=True + + assert found, f'Removeprefix accepts other type: {type(o)}: 
{o=}' + +def test_removesuffix(): + s='foobarfoo' + s_ref='foobarfoo' + assert s.removesuffix('o') == s_ref[:-1] + assert s.removesuffix('oo') == s_ref[:-2] + assert s.removesuffix('foo') == s_ref[:-3] + + assert s.removesuffix('') == s_ref + assert s.removesuffix('bar') == s_ref + assert s.removesuffix('lol') == s_ref + assert s.removesuffix('foo_') == s_ref + assert s.removesuffix('foo-') == s_ref + assert s.removesuffix('foo*') == s_ref + assert s.removesuffix('fooa') == s_ref + + assert s==s_ref, 'undefined test fail' + +def test_removesuffix_types(): + s='0123456' + s_ref='0123456' + others=[0,6,['6']] + found=False + for o in others: + try: + s.removesuffix(o) + except: + found=True + + assert found, f'Removesuffix accepts other type: {type(o)}: {o=}' + +skip_if_unsupported(3,9,test_removeprefix) +skip_if_unsupported(3,9,test_removeprefix_types) +skip_if_unsupported(3,9,test_removesuffix) +skip_if_unsupported(3,9,test_removesuffix_types) \ No newline at end of file diff --git a/tests/snippets/testutils.py b/tests/snippets/testutils.py index c779d2c898..437fa06ae3 100644 --- a/tests/snippets/testutils.py +++ b/tests/snippets/testutils.py @@ -92,4 +92,3 @@ def exec(): exec() else: assert False, f'Test cannot performed on this python version. 
{platform.python_implementation()} {platform.python_version()}' - diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index eb3a434db8..0b1747affd 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -532,6 +532,22 @@ impl PyString { ) } + #[pymethod] + fn removeprefix(&self, pref: PyStringRef) -> PyResult { + if self.value.as_str().starts_with(&pref.value) { + return Ok(self.value[pref.len()..].to_string()); + } + Ok(self.value.to_string()) + } + + #[pymethod] + fn removesuffix(&self, suff: PyStringRef) -> PyResult { + if self.value.as_str().ends_with(&suff.value) { + return Ok(self.value[..self.value.len() - suff.len()].to_string()); + } + Ok(self.value.to_string()) + } + #[pymethod] fn isalnum(&self) -> bool { !self.value.is_empty() && self.value.chars().all(char::is_alphanumeric) From 5b1fb7f5acfc14b6fb671b4afab780235ac70009 Mon Sep 17 00:00:00 2001 From: TheAnyKey Date: Sat, 9 May 2020 17:22:59 +0000 Subject: [PATCH 16/19] fixed len issue and updated test accordingly --- Lib/test/string_tests.py | 36 ++++++++++++++++++++++++++++++++++++ tests/snippets/strings.py | 30 +++++++++++++++++++++++++++++- vm/src/obj/objstr.rs | 4 ++-- 3 files changed, 67 insertions(+), 3 deletions(-) diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index 2c6e0f84fa..ef2dc93cde 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -681,6 +681,42 @@ def test_replace_overflow(self): self.checkraises(OverflowError, A2_16, "replace", "A", A2_16) self.checkraises(OverflowError, A2_16, "replace", "AA", A2_16+A2_16) + def test_removeprefix(self): + self.checkequal('am', 'spam', 'removeprefix', 'sp') + self.checkequal('spamspam', 'spamspamspam', 'removeprefix', 'spam') + self.checkequal('spam', 'spam', 'removeprefix', 'python') + self.checkequal('spam', 'spam', 'removeprefix', 'spider') + self.checkequal('spam', 'spam', 'removeprefix', 'spam and eggs') + + self.checkequal('', '', 'removeprefix', '') + self.checkequal('', '', 'removeprefix', 
'abcde') + self.checkequal('abcde', 'abcde', 'removeprefix', '') + self.checkequal('', 'abcde', 'removeprefix', 'abcde') + + self.checkraises(TypeError, 'hello', 'removeprefix') + self.checkraises(TypeError, 'hello', 'removeprefix', 42) + self.checkraises(TypeError, 'hello', 'removeprefix', 42, 'h') + self.checkraises(TypeError, 'hello', 'removeprefix', 'h', 42) + self.checkraises(TypeError, 'hello', 'removeprefix', ("he", "l")) + + def test_removesuffix(self): + self.checkequal('sp', 'spam', 'removesuffix', 'am') + self.checkequal('spamspam', 'spamspamspam', 'removesuffix', 'spam') + self.checkequal('spam', 'spam', 'removesuffix', 'python') + self.checkequal('spam', 'spam', 'removesuffix', 'blam') + self.checkequal('spam', 'spam', 'removesuffix', 'eggs and spam') + + self.checkequal('', '', 'removesuffix', '') + self.checkequal('', '', 'removesuffix', 'abcde') + self.checkequal('abcde', 'abcde', 'removesuffix', '') + self.checkequal('', 'abcde', 'removesuffix', 'abcde') + + self.checkraises(TypeError, 'hello', 'removesuffix') + self.checkraises(TypeError, 'hello', 'removesuffix', 42) + self.checkraises(TypeError, 'hello', 'removesuffix', 42, 'h') + self.checkraises(TypeError, 'hello', 'removesuffix', 'h', 42) + self.checkraises(TypeError, 'hello', 'removesuffix', ("lo", "l")) + def test_capitalize(self): self.checkequal(' hello ', ' hello ', 'capitalize') self.checkequal('Hello ', 'Hello ','capitalize') diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 6bab4a0863..1d53e30d19 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -475,7 +475,7 @@ def try_mutate_str(): # remove*fix test def test_removeprefix(): - s='foobarfoo' + s = 'foobarfoo' s_ref='foobarfoo' assert s.removeprefix('f') == s_ref[1:] assert s.removeprefix('fo') == s_ref[2:] @@ -491,6 +491,20 @@ def test_removeprefix(): assert s==s_ref, 'undefined test fail' + s_uc = '😱foobarfoo🖖' + s_ref_uc = '😱foobarfoo🖖' + assert s_uc.removeprefix('😱') ==
s_ref_uc[1:] + assert s_uc.removeprefix('😱fo') == s_ref_uc[3:] + assert s_uc.removeprefix('😱foo') == s_ref_uc[4:] + + assert s_uc.removeprefix('🖖') == s_ref_uc + assert s_uc.removeprefix('foo') == s_ref_uc + assert s_uc.removeprefix(' ') == s_ref_uc + assert s_uc.removeprefix('_😱') == s_ref_uc + assert s_uc.removeprefix(' 😱') == s_ref_uc + assert s_uc.removeprefix('-😱') == s_ref_uc + assert s_uc.removeprefix('#😱') == s_ref_uc + def test_removeprefix_types(): s='0123456' s_ref='0123456' @@ -521,6 +535,20 @@ def test_removesuffix(): assert s==s_ref, 'undefined test fail' + s_uc = '😱foobarfoo🖖' + s_ref_uc = '😱foobarfoo🖖' + assert s_uc.removesuffix('🖖') == s_ref_uc[:-1] + assert s_uc.removesuffix('oo🖖') == s_ref_uc[:-3] + assert s_uc.removesuffix('foo🖖') == s_ref_uc[:-4] + + assert s_uc.removesuffix('😱') == s_ref_uc + assert s_uc.removesuffix('foo') == s_ref_uc + assert s_uc.removesuffix(' ') == s_ref_uc + assert s_uc.removesuffix('🖖_') == s_ref_uc + assert s_uc.removesuffix('🖖 ') == s_ref_uc + assert s_uc.removesuffix('🖖-') == s_ref_uc + assert s_uc.removesuffix('🖖#') == s_ref_uc + def test_removesuffix_types(): s='0123456' s_ref='0123456' diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 0b1747affd..59ad91a794 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -535,7 +535,7 @@ impl PyString { #[pymethod] fn removeprefix(&self, pref: PyStringRef) -> PyResult { if self.value.as_str().starts_with(&pref.value) { - return Ok(self.value[pref.len()..].to_string()); + return Ok(self.value[pref.value.len()..].to_string()); } Ok(self.value.to_string()) } @@ -543,7 +543,7 @@ impl PyString { #[pymethod] fn removesuffix(&self, suff: PyStringRef) -> PyResult { if self.value.as_str().ends_with(&suff.value) { - return Ok(self.value[..self.value.len() - suff.len()].to_string()); + return Ok(self.value[..self.value.len() - suff.value.len()].to_string()); } Ok(self.value.to_string()) } From
785454df2c1c5d8795e8ce89176c7b1ee17b7269 Mon Sep 17 00:00:00 2001 From: TheAnyKey Date: Sun, 10 May 2020 17:31:29 +0000 Subject: [PATCH 17/19] removeprefix, suffix: implementation for bytes and bytes array --- Lib/test/string_tests.py | 3 +++ vm/src/obj/objbytearray.rs | 20 ++++++++++++++++++++ vm/src/obj/objbyteinner.rs | 16 ++++++++++++++++ vm/src/obj/objbytes.rs | 10 ++++++++++ 4 files changed, 49 insertions(+) diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index ef2dc93cde..2b0d8a9d79 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -681,6 +681,8 @@ def test_replace_overflow(self): self.checkraises(OverflowError, A2_16, "replace", "A", A2_16) self.checkraises(OverflowError, A2_16, "replace", "AA", A2_16+A2_16) + + # Python 3.9 def test_removeprefix(self): self.checkequal('am', 'spam', 'removeprefix', 'sp') self.checkequal('spamspam', 'spamspamspam', 'removeprefix', 'spam') @@ -699,6 +701,7 @@ def test_removeprefix(self): self.checkraises(TypeError, 'hello', 'removeprefix', 'h', 42) self.checkraises(TypeError, 'hello', 'removeprefix', ("he", "l")) + # Python 3.9 def test_removesuffix(self): self.checkequal('sp', 'spam', 'removesuffix', 'am') self.checkequal('spamspam', 'spamspamspam', 'removesuffix', 'spam') diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index 029be2d1c9..2cd346684c 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -389,6 +389,26 @@ impl PyByteArray { self.borrow_value().rstrip(chars).into() } + #[pymethod(name = "removeprefix")] + fn removeprefix(&self, prefix: PyByteInner) -> PyByteArray { + let value = self.borrow_value(); + if value.elements.starts_with(&prefix.elements) { + return value.elements[prefix.elements.len()..].to_vec().into(); + } + value.elements.to_vec().into() + } + + #[pymethod(name = "removesuffix")] + fn removesuffix(&self, suffix: PyByteInner) -> PyByteArray { + let value = self.borrow_value(); + if 
value.elements.ends_with(&suffix.elements) { + return value.elements[..value.elements.len() - suffix.elements.len()] + .to_vec() + .into(); + } + value.elements.to_vec().into() + } + #[pymethod(name = "split")] fn split(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { self.borrow_value() diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index 38f6f0b718..41fa4829ef 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -843,6 +843,22 @@ impl PyByteInner { .to_vec() } + // new in Python 3.9 + pub fn removeprefix(&self, prefix: PyByteInner) -> Vec { + if self.elements.starts_with(&prefix.elements) { + return self.elements[prefix.elements.len()..].to_vec(); + } + self.elements.to_vec() + } + + // new in Python 3.9 + pub fn removesuffix(&self, suffix: PyByteInner) -> Vec { + if self.elements.ends_with(&suffix.elements) { + return self.elements[..self.elements.len() - suffix.elements.len()].to_vec(); + } + self.elements.to_vec() + } + pub fn split( &self, options: ByteInnerSplitOptions, diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index bf76f49a2a..284939788b 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -347,6 +347,16 @@ impl PyBytes { self.inner.rstrip(chars).into() } + #[pymethod(name = "removeprefix")] + fn removeprefix(&self, prefix: PyByteInner) -> PyBytes { + self.inner.removeprefix(prefix).into() + } + + #[pymethod(name = "removesuffix")] + fn removesuffix(&self, suffix: PyByteInner) -> PyBytes { + self.inner.removesuffix(suffix).into() + } + #[pymethod(name = "split")] fn split(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { self.inner From 46bc077e55e24bb05d52af3f8c87ee38a9534b57 Mon Sep 17 00:00:00 2001 From: TheAnyKey Date: Wed, 13 May 2020 20:52:11 +0000 Subject: [PATCH 18/19] unified implementation of removeprefix and removesuffix, added pydocs --- vm/src/obj/objbytearray.rs | 38 ++++++++++++++++++++++++++------------ 
vm/src/obj/objbyteinner.rs | 20 ++++++++++++-------- vm/src/obj/objbytes.rs | 14 ++++++++++++++ vm/src/obj/objstr.rs | 34 ++++++++++++++++++++++++---------- vm/src/obj/pystr.rs | 38 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 114 insertions(+), 30 deletions(-) diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index 2cd346684c..3f06674f8d 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -389,24 +389,38 @@ impl PyByteArray { self.borrow_value().rstrip(chars).into() } + /// removeprefix($self, prefix, /) + /// + /// + /// Return a bytearray object with the given prefix string removed if present. + /// + /// If the bytearray starts with the prefix string, return string[len(prefix):] + /// Otherwise, return a copy of the original bytearray. #[pymethod(name = "removeprefix")] fn removeprefix(&self, prefix: PyByteInner) -> PyByteArray { - let value = self.borrow_value(); - if value.elements.starts_with(&prefix.elements) { - return value.elements[prefix.elements.len()..].to_vec().into(); - } - value.elements.to_vec().into() + self.borrow_value().elements[..] + .py_removeprefix(&prefix.elements, prefix.elements.len(), |s, p| { + s.starts_with(p) + }) + .to_vec() + .into() } + /// removesuffix(self, prefix, /) + /// + /// + /// Return a bytearray object with the given suffix string removed if present. + /// + /// If the bytearray ends with the suffix string, return string[:len(suffix)] + /// Otherwise, return a copy of the original bytearray. #[pymethod(name = "removesuffix")] fn removesuffix(&self, suffix: PyByteInner) -> PyByteArray { - let value = self.borrow_value(); - if value.elements.ends_with(&suffix.elements) { - return value.elements[..value.elements.len() - suffix.elements.len()] - .to_vec() - .into(); - } - value.elements.to_vec().into() + self.borrow_value().elements[..] 
+ .py_removesuffix(&suffix.elements, suffix.elements.len(), |s, p| { + s.ends_with(p) + }) + .to_vec() + .into() } #[pymethod(name = "split")] diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index 41fa4829ef..2be78f15d0 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -845,18 +845,22 @@ impl PyByteInner { // new in Python 3.9 pub fn removeprefix(&self, prefix: PyByteInner) -> Vec { - if self.elements.starts_with(&prefix.elements) { - return self.elements[prefix.elements.len()..].to_vec(); - } - self.elements.to_vec() + // self.elements.py_removeprefix(&prefix.elements, prefix.elements.len(), |s:&Self, p:&Vec| s.elements.starts_with(&p)).to_vec() + + self.elements + .py_removeprefix(&prefix.elements, prefix.elements.len(), |s, p| { + s.starts_with(p) + }) + .to_vec() } // new in Python 3.9 pub fn removesuffix(&self, suffix: PyByteInner) -> Vec { - if self.elements.ends_with(&suffix.elements) { - return self.elements[..self.elements.len() - suffix.elements.len()].to_vec(); - } - self.elements.to_vec() + self.elements + .py_removesuffix(&suffix.elements, suffix.elements.len(), |s, p| { + s.ends_with(p) + }) + .to_vec() } pub fn split( diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 284939788b..c752a2da02 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -347,11 +347,25 @@ impl PyBytes { self.inner.rstrip(chars).into() } + /// removeprefix($self, prefix, /) + /// + /// + /// Return a bytes object with the given prefix string removed if present. + /// + /// If the bytes starts with the prefix string, return string[len(prefix):] + /// Otherwise, return a copy of the original bytes. #[pymethod(name = "removeprefix")] fn removeprefix(&self, prefix: PyByteInner) -> PyBytes { self.inner.removeprefix(prefix).into() } + /// removesuffix(self, prefix, /) + /// + /// + /// Return a bytes object with the given suffix string removed if present. 
+ /// + /// If the bytes ends with the suffix string, return string[:len(suffix)] + /// Otherwise, return a copy of the original bytes. #[pymethod(name = "removesuffix")] fn removesuffix(&self, suffix: PyByteInner) -> PyBytes { self.inner.removesuffix(suffix).into() diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 59ad91a794..da6b156d37 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -532,20 +532,34 @@ impl PyString { ) } + /// removeprefix($self, prefix, /) + /// + /// + /// Return a str with the given prefix string removed if present. + /// + /// If the string starts with the prefix string, return string[len(prefix):] + /// Otherwise, return a copy of the original string. #[pymethod] - fn removeprefix(&self, pref: PyStringRef) -> PyResult { - if self.value.as_str().starts_with(&pref.value) { - return Ok(self.value[pref.value.len()..].to_string()); - } - Ok(self.value.to_string()) + fn removeprefix(&self, pref: PyStringRef) -> String { + self.value + .as_str() + .py_removeprefix(&pref.value, pref.value.len(), |s, p| s.starts_with(p)) + .to_string() } + /// removesuffix(self, prefix, /) + /// + /// + /// Return a str with the given suffix string removed if present. + /// + /// If the string ends with the suffix string, return string[:len(suffix)] + /// Otherwise, return a copy of the original string. 
#[pymethod] - fn removesuffix(&self, suff: PyStringRef) -> PyResult { - if self.value.as_str().ends_with(&suff.value) { - return Ok(self.value[..self.value.len() - suff.value.len()].to_string()); - } - Ok(self.value.to_string()) + fn removesuffix(&self, suff: PyStringRef) -> String { + self.value + .as_str() + .py_removesuffix(&suff.value, suff.value.len(), |s, p| s.ends_with(p)) + .to_string() } #[pymethod] diff --git a/vm/src/obj/pystr.rs b/vm/src/obj/pystr.rs index 43a0260128..b404792fea 100644 --- a/vm/src/obj/pystr.rs +++ b/vm/src/obj/pystr.rs @@ -3,6 +3,7 @@ use crate::obj::objint::PyIntRef; use crate::pyobject::{PyObjectRef, PyResult, TryFromObject, TypeProtocol}; use crate::vm::VirtualMachine; use num_traits::{cast::ToPrimitive, sign::Signed}; +use std::ops::Range; #[derive(FromArgs)] pub struct SplitArgs @@ -264,4 +265,41 @@ pub trait PyCommonString { fn py_rjust(&self, width: usize, fillchar: E) -> Self::Container { self.py_pad(width - self.chars_len(), 0, fillchar) } + + fn py_removeprefix( + &self, + prefix: &Self::Container, + prefix_len: usize, + is_prefix: FC, + ) -> &Self + where + FC: Fn(&Self, &Self::Container) -> bool, + { + //if self.py_starts_with(prefix) { + if is_prefix(&self, &prefix) { + return self.get_bytes(Range { + start: prefix_len, + end: self.bytes_len(), + }); + } + &self + } + + fn py_removesuffix( + &self, + suffix: &Self::Container, + suffix_len: usize, + is_suffix: FC, + ) -> &Self + where + FC: Fn(&Self, &Self::Container) -> bool, + { + if is_suffix(&self, &suffix) { + return self.get_bytes(Range { + start: 0, + end: self.bytes_len() - suffix_len, + }); + } + &self + } } From 7c66acf11b806c7d2cb76ae343598502bd002744 Mon Sep 17 00:00:00 2001 From: TheAnyKey Date: Sat, 16 May 2020 13:41:22 +0000 Subject: [PATCH 19/19] fixed findings: removed unnecessary returns and double implementation --- vm/src/obj/objbytearray.rs | 14 ++------------ vm/src/obj/objbyteinner.rs | 2 -- vm/src/obj/pystr.rs | 17 ++++++----------- 3 files changed, 
8 insertions(+), 25 deletions(-) diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index 3f06674f8d..76c7203f86 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -398,12 +398,7 @@ impl PyByteArray { /// Otherwise, return a copy of the original bytearray. #[pymethod(name = "removeprefix")] fn removeprefix(&self, prefix: PyByteInner) -> PyByteArray { - self.borrow_value().elements[..] - .py_removeprefix(&prefix.elements, prefix.elements.len(), |s, p| { - s.starts_with(p) - }) - .to_vec() - .into() + self.borrow_value().removeprefix(prefix).into() } /// removesuffix(self, prefix, /) @@ -415,12 +410,7 @@ impl PyByteArray { /// Otherwise, return a copy of the original bytearray. #[pymethod(name = "removesuffix")] fn removesuffix(&self, suffix: PyByteInner) -> PyByteArray { - self.borrow_value().elements[..] - .py_removesuffix(&suffix.elements, suffix.elements.len(), |s, p| { - s.ends_with(p) - }) - .to_vec() - .into() + self.borrow_value().removesuffix(suffix).to_vec().into() } #[pymethod(name = "split")] diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index 2be78f15d0..b20d724d84 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -845,8 +845,6 @@ impl PyByteInner { // new in Python 3.9 pub fn removeprefix(&self, prefix: PyByteInner) -> Vec { - // self.elements.py_removeprefix(&prefix.elements, prefix.elements.len(), |s:&Self, p:&Vec| s.elements.starts_with(&p)).to_vec() - self.elements .py_removeprefix(&prefix.elements, prefix.elements.len(), |s, p| { s.starts_with(p) diff --git a/vm/src/obj/pystr.rs b/vm/src/obj/pystr.rs index b404792fea..9f262dfe7f 100644 --- a/vm/src/obj/pystr.rs +++ b/vm/src/obj/pystr.rs @@ -3,7 +3,6 @@ use crate::obj::objint::PyIntRef; use crate::pyobject::{PyObjectRef, PyResult, TryFromObject, TypeProtocol}; use crate::vm::VirtualMachine; use num_traits::{cast::ToPrimitive, sign::Signed}; -use std::ops::Range; #[derive(FromArgs)] pub struct SplitArgs @@ 
-277,12 +276,10 @@ pub trait PyCommonString { { //if self.py_starts_with(prefix) { if is_prefix(&self, &prefix) { - return self.get_bytes(Range { - start: prefix_len, - end: self.bytes_len(), - }); + self.get_bytes(prefix_len..self.bytes_len()) + } else { + &self } - &self } fn py_removesuffix( @@ -295,11 +292,9 @@ pub trait PyCommonString { FC: Fn(&Self, &Self::Container) -> bool, { if is_suffix(&self, &suffix) { - return self.get_bytes(Range { - start: 0, - end: self.bytes_len() - suffix_len, - }); + self.get_bytes(0..self.bytes_len() - suffix_len) + } else { + &self } - &self } }