Skip to content

Commit 46945c3

Browse files
authored
Merge pull request RustPython#1897 from youknowone/fix-bytes-find
Fix find/index/count not to crash for bigint start/end arguments
2 parents 766a598 + a115005 commit 46945c3

File tree

5 files changed

+57
-67
lines changed

5 files changed

+57
-67
lines changed

Lib/test/string_tests.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,6 @@ def test_count(self):
156156
self.assertEqual(rem, 0, '%s != 0 for %s' % (rem, i))
157157
self.assertEqual(r1, r2, '%s != %s for %s' % (r1, r2, i))
158158

159-
# TODO: RUSTPYTHON
160-
@unittest.expectedFailure
161159
def test_find(self):
162160
self.checkequal(0, 'abcdefghiabc', 'find', 'abc')
163161
self.checkequal(9, 'abcdefghiabc', 'find', 'abc', 1)
@@ -215,8 +213,6 @@ def test_find(self):
215213
if loc != -1:
216214
self.assertEqual(i[loc:loc+len(j)], j)
217215

218-
# TODO: RUSTPYTHON
219-
@unittest.expectedFailure
220216
def test_rfind(self):
221217
self.checkequal(9, 'abcdefghiabc', 'rfind', 'abc')
222218
self.checkequal(12, 'abcdefghiabc', 'rfind', '')

Lib/test/test_unicode.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,7 @@ def test_count(self):
199199
self.checkequal(0, 'a' * 10, 'count', 'a\U00100304')
200200
self.checkequal(0, '\u0102' * 10, 'count', '\u0102\U00100304')
201201

202-
# TODO: RUSTPYTHON
203-
@unittest.expectedFailure
202+
@unittest.skip("TODO: RUSTPYTHON")
204203
def test_find(self):
205204
string_tests.CommonTest.test_find(self)
206205
# test implementation details of the memchr fast path
@@ -232,8 +231,7 @@ def test_find(self):
232231
self.checkequal(-1, 'a' * 100, 'find', 'a\U00100304')
233232
self.checkequal(-1, '\u0102' * 100, 'find', '\u0102\U00100304')
234233

235-
# TODO: RUSTPYTHON
236-
@unittest.expectedFailure
234+
@unittest.skip("TODO: RUSTPYTHON")
237235
def test_rfind(self):
238236
string_tests.CommonTest.test_rfind(self)
239237
# test implementation details of the memrchr fast path

vm/src/obj/objbyteinner.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,10 @@ impl ByteInnerNewOptions {
154154
pub struct ByteInnerFindOptions {
155155
#[pyarg(positional_only, optional = false)]
156156
sub: Either<PyByteInner, PyIntRef>,
157-
#[pyarg(positional_only, optional = true)]
158-
start: OptionalArg<Option<isize>>,
159-
#[pyarg(positional_only, optional = true)]
160-
end: OptionalArg<Option<isize>>,
157+
#[pyarg(positional_only, default = "None")]
158+
start: Option<PyIntRef>,
159+
#[pyarg(positional_only, default = "None")]
160+
end: Option<PyIntRef>,
161161
}
162162

163163
impl ByteInnerFindOptions {

vm/src/obj/objstr.rs

Lines changed: 30 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -816,63 +816,35 @@ impl PyString {
816816
}
817817

818818
#[inline]
819-
fn _find<F>(
820-
&self,
821-
sub: PyStringRef,
822-
start: OptionalArg<Option<isize>>,
823-
end: OptionalArg<Option<isize>>,
824-
find: F,
825-
) -> Option<usize>
819+
fn _find<F>(&self, args: FindArgs, find: F) -> Option<usize>
826820
where
827821
F: Fn(&str, &str) -> Option<usize>,
828822
{
829-
let range = adjust_indices(start, end, self.value.len());
823+
let (sub, range) = args.get_value(self.len());
830824
self.value.py_find(&sub.value, range, find)
831825
}
832826

833827
#[pymethod]
834-
fn find(
835-
&self,
836-
sub: PyStringRef,
837-
start: OptionalArg<Option<isize>>,
838-
end: OptionalArg<Option<isize>>,
839-
) -> isize {
840-
self._find(sub, start, end, |r, s| r.find(s))
828+
fn find(&self, args: FindArgs) -> isize {
829+
self._find(args, |r, s| r.find(s))
841830
.map_or(-1, |v| v as isize)
842831
}
843832

844833
#[pymethod]
845-
fn rfind(
846-
&self,
847-
sub: PyStringRef,
848-
start: OptionalArg<Option<isize>>,
849-
end: OptionalArg<Option<isize>>,
850-
) -> isize {
851-
self._find(sub, start, end, |r, s| r.rfind(s))
834+
fn rfind(&self, args: FindArgs) -> isize {
835+
self._find(args, |r, s| r.rfind(s))
852836
.map_or(-1, |v| v as isize)
853837
}
854838

855839
#[pymethod]
856-
fn index(
857-
&self,
858-
sub: PyStringRef,
859-
start: OptionalArg<Option<isize>>,
860-
end: OptionalArg<Option<isize>>,
861-
vm: &VirtualMachine,
862-
) -> PyResult<usize> {
863-
self._find(sub, start, end, |r, s| r.find(s))
840+
fn index(&self, args: FindArgs, vm: &VirtualMachine) -> PyResult<usize> {
841+
self._find(args, |r, s| r.find(s))
864842
.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
865843
}
866844

867845
#[pymethod]
868-
fn rindex(
869-
&self,
870-
sub: PyStringRef,
871-
start: OptionalArg<Option<isize>>,
872-
end: OptionalArg<Option<isize>>,
873-
vm: &VirtualMachine,
874-
) -> PyResult<usize> {
875-
self._find(sub, start, end, |r, s| r.rfind(s))
846+
fn rindex(&self, args: FindArgs, vm: &VirtualMachine) -> PyResult<usize> {
847+
self._find(args, |r, s| r.rfind(s))
876848
.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
877849
}
878850

@@ -947,15 +919,10 @@ impl PyString {
947919
}
948920

949921
#[pymethod]
950-
fn count(
951-
&self,
952-
sub: PyStringRef,
953-
start: OptionalArg<Option<isize>>,
954-
end: OptionalArg<Option<isize>>,
955-
) -> usize {
956-
let range = adjust_indices(start, end, self.value.len());
922+
fn count(&self, args: FindArgs) -> usize {
923+
let (needle, range) = args.get_value(self.len());
957924
self.value
958-
.py_count(&sub.value, range, |h, n| h.matches(n).count())
925+
.py_count(&needle.value, range, |h, n| h.matches(n).count())
959926
}
960927

961928
#[pymethod]
@@ -1256,6 +1223,23 @@ impl TryFromObject for std::ffi::CString {
12561223

12571224
type SplitArgs = pystr::SplitArgs<PyStringRef, str, char>;
12581225

1226+
#[derive(FromArgs)]
1227+
pub struct FindArgs {
1228+
#[pyarg(positional_only, optional = false)]
1229+
sub: PyStringRef,
1230+
#[pyarg(positional_only, default = "None")]
1231+
start: Option<PyIntRef>,
1232+
#[pyarg(positional_only, default = "None")]
1233+
end: Option<PyIntRef>,
1234+
}
1235+
1236+
impl FindArgs {
1237+
fn get_value(self, len: usize) -> (PyStringRef, std::ops::Range<usize>) {
1238+
let range = adjust_indices(self.start, self.end, len);
1239+
(self.sub, range)
1240+
}
1241+
}
1242+
12591243
pub fn init(ctx: &PyContext) {
12601244
PyString::extend_class(ctx, &ctx.types.str_type);
12611245

vm/src/obj/pystr.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use crate::function::{single_or_tuple_any, OptionalOption};
2+
use crate::obj::objint::PyIntRef;
23
use crate::pyobject::{PyObjectRef, PyResult, TryFromObject, TypeProtocol};
34
use crate::vm::VirtualMachine;
4-
use num_traits::cast::ToPrimitive;
5+
use num_traits::{cast::ToPrimitive, sign::Signed};
56

67
#[derive(FromArgs)]
78
pub struct SplitArgs<T, S, E>
@@ -58,10 +59,10 @@ impl ExpandTabsArgs {
5859
pub struct StartsEndsWithArgs {
5960
#[pyarg(positional_only, optional = false)]
6061
affix: PyObjectRef,
61-
#[pyarg(positional_only, optional = true)]
62-
start: OptionalOption<isize>,
63-
#[pyarg(positional_only, optional = true)]
64-
end: OptionalOption<isize>,
62+
#[pyarg(positional_only, default = "None")]
63+
start: Option<PyIntRef>,
64+
#[pyarg(positional_only, default = "None")]
65+
end: Option<PyIntRef>,
6566
}
6667

6768
impl StartsEndsWithArgs {
@@ -71,14 +72,25 @@ impl StartsEndsWithArgs {
7172
}
7273
}
7374

75+
fn cap_to_isize(py_int: PyIntRef) -> isize {
76+
let big = py_int.as_bigint();
77+
big.to_isize().unwrap_or_else(|| {
78+
if big.is_negative() {
79+
std::isize::MIN
80+
} else {
81+
std::isize::MAX
82+
}
83+
})
84+
}
85+
7486
// help get optional string indices
7587
pub fn adjust_indices(
76-
start: OptionalOption<isize>,
77-
end: OptionalOption<isize>,
88+
start: Option<PyIntRef>,
89+
end: Option<PyIntRef>,
7890
len: usize,
7991
) -> std::ops::Range<usize> {
80-
let mut start = start.flat_option().unwrap_or(0);
81-
let mut end = end.flat_option().unwrap_or(len as isize);
92+
let mut start = start.map_or(0, cap_to_isize);
93+
let mut end = end.map_or(len as isize, cap_to_isize);
8294
if end > len as isize {
8395
end = len as isize;
8496
} else if end < 0 {

0 commit comments

Comments
 (0)