Skip to content

Commit 3937e84

Browse files
authored
Merge pull request RustPython#1890 from RustPython/coolreader18/fix-int-tests
Fix a few int tests
2 parents 8ca22df + dcbce35 commit 3937e84

File tree

2 files changed

+41
-54
lines changed

2 files changed

+41
-54
lines changed

Lib/test/test_int.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import unittest
44
from test import support
5-
# from test.test_grammar import (VALID_UNDERSCORE_LITERALS,
6-
# INVALID_UNDERSCORE_LITERALS)
5+
from test.test_grammar import (VALID_UNDERSCORE_LITERALS,
6+
INVALID_UNDERSCORE_LITERALS)
77

88
L = [
99
('0', 0),
@@ -31,7 +31,6 @@ class IntSubclass(int):
3131

3232
class IntTestCases(unittest.TestCase):
3333

34-
@unittest.skip("TODO: RUSTPYTHON")
3534
def test_basic(self):
3635
self.assertEqual(int(314), 314)
3736
self.assertEqual(int(3.14), 3)
@@ -215,7 +214,6 @@ def test_basic(self):
215214
self.assertEqual(int('2br45qc', 35), 4294967297)
216215
self.assertEqual(int('1z141z5', 36), 4294967297)
217216

218-
@unittest.skip("TODO: RUSTPYTHON")
219217
def test_underscores(self):
220218
for lit in VALID_UNDERSCORE_LITERALS:
221219
if any(ch in lit for ch in '.eEjJ'):
@@ -481,8 +479,6 @@ def __trunc__(self):
481479
self.assertEqual(n, 1)
482480
self.assertIs(type(n), IntSubclass)
483481

484-
# TODO: RUSTPYTHON
485-
@unittest.expectedFailure
486482
def test_error_message(self):
487483
def check(s, base=None):
488484
with self.assertRaises(ValueError,

vm/src/obj/objint.rs

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ use super::objfloat;
1313
use super::objmemory::PyMemoryView;
1414
use super::objstr::{PyString, PyStringRef};
1515
use super::objtype::{self, PyClassRef};
16-
use crate::exceptions::PyBaseExceptionRef;
1716
use crate::format::FormatSpec;
1817
use crate::function::{OptionalArg, PyFuncArgs};
1918
use crate::pyhash;
@@ -724,26 +723,21 @@ struct IntToByteArgs {
724723

725724
// Casting function:
726725
pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: &BigInt) -> PyResult<BigInt> {
727-
let base_u32 = match base.to_u32() {
728-
Some(base_u32) => base_u32,
729-
None => {
730-
return Err(vm.new_value_error("int() base must be >= 2 and <= 36, or 0".to_owned()))
731-
}
726+
let base = match base.to_u32() {
727+
Some(base) if base == 0 || (2..=36).contains(&base) => base,
728+
_ => return Err(vm.new_value_error("int() base must be >= 2 and <= 36, or 0".to_owned())),
732729
};
733-
if base_u32 != 0 && (base_u32 < 2 || base_u32 > 36) {
734-
return Err(vm.new_value_error("int() base must be >= 2 and <= 36, or 0".to_owned()));
735-
}
736730

737731
let bytes_to_int = |bytes: &[u8]| {
738-
let s = std::str::from_utf8(bytes)
739-
.map_err(|e| vm.new_value_error(format!("utf8 decode error: {}", e)))?;
740-
str_to_int(vm, s, base)
732+
std::str::from_utf8(bytes)
733+
.ok()
734+
.and_then(|s| str_to_int(s, base))
741735
};
742736

743-
match_class!(match obj.clone() {
737+
let opt = match_class!(match obj.clone() {
744738
string @ PyString => {
745739
let s = string.as_str();
746-
str_to_int(vm, &s, base)
740+
str_to_int(&s, base)
747741
}
748742
bytes @ PyBytes => {
749743
let bytes = bytes.get_value();
@@ -770,36 +764,39 @@ pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: &BigInt) -> PyResult
770764
)
771765
})?;
772766
let result = vm.invoke(&method, PyFuncArgs::default())?;
773-
match result.payload::<PyInt>() {
767+
return match result.payload::<PyInt>() {
774768
Some(int_obj) => Ok(int_obj.as_bigint().clone()),
775769
None => Err(vm.new_type_error(format!(
776770
"TypeError: __int__ returned non-int (type '{}')",
777771
result.class().name
778772
))),
779-
}
773+
};
780774
}
781-
})
775+
});
776+
match opt {
777+
Some(int) => Ok(int),
778+
None => Err(vm.new_value_error(format!(
779+
"invalid literal for int() with base {}: {}",
780+
base,
781+
vm.to_repr(obj)?,
782+
))),
783+
}
782784
}
783785

784-
fn str_to_int(vm: &VirtualMachine, literal: &str, base: &BigInt) -> PyResult<BigInt> {
785-
let mut buf = validate_literal(vm, literal, base)?;
786+
fn str_to_int(literal: &str, mut base: u32) -> Option<BigInt> {
787+
let mut buf = validate_literal(literal)?.to_owned();
786788
let is_signed = buf.starts_with('+') || buf.starts_with('-');
787789
let radix_range = if is_signed { 1..3 } else { 0..2 };
788790
let radix_candidate = buf.get(radix_range.clone());
789791

790-
let mut base_u32 = match base.to_u32() {
791-
Some(base_u32) => base_u32,
792-
None => return Err(invalid_literal(vm, literal, base)),
793-
};
794-
795792
// try to find base
796793
if let Some(radix_candidate) = radix_candidate {
797794
if let Some(matched_radix) = detect_base(&radix_candidate) {
798-
if base_u32 == 0 || base_u32 == matched_radix {
795+
if base == 0 || base == matched_radix {
799796
/* If base is 0 or equal radix number, it means radix is validate
800797
* So change base to radix number and remove radix from literal
801798
*/
802-
base_u32 = matched_radix;
799+
base = matched_radix;
803800
buf.drain(radix_range);
804801

805802
/* first underscore with radix is validate
@@ -808,49 +805,50 @@ fn str_to_int(vm: &VirtualMachine, literal: &str, base: &BigInt) -> PyResult<Big
808805
if buf.starts_with('_') {
809806
buf.remove(0);
810807
}
811-
} else if (matched_radix == 2 && base_u32 < 12)
812-
|| (matched_radix == 8 && base_u32 < 25)
813-
|| (matched_radix == 16 && base_u32 < 34)
808+
} else if (matched_radix == 2 && base < 12)
809+
|| (matched_radix == 8 && base < 25)
810+
|| (matched_radix == 16 && base < 34)
814811
{
815-
return Err(invalid_literal(vm, literal, base));
812+
return None;
816813
}
817814
}
818815
}
819816

820817
// base still not found, try to use default
821-
if base_u32 == 0 {
818+
if base == 0 {
822819
if buf.starts_with('0') {
823-
return Err(invalid_literal(vm, literal, base));
820+
if buf.chars().all(|c| matches!(c, '+' | '-' | '0' | '_')) {
821+
return Some(BigInt::zero());
822+
}
823+
return None;
824824
}
825825

826-
base_u32 = 10;
826+
base = 10;
827827
}
828828

829-
BigInt::from_str_radix(&buf, base_u32).map_err(|_err| invalid_literal(vm, literal, base))
829+
BigInt::from_str_radix(&buf, base).ok()
830830
}
831831

832-
fn validate_literal(vm: &VirtualMachine, literal: &str, base: &BigInt) -> PyResult<String> {
832+
fn validate_literal(literal: &str) -> Option<&str> {
833833
let trimmed = literal.trim();
834834
if trimmed.starts_with('_') || trimmed.ends_with('_') {
835-
return Err(invalid_literal(vm, literal, base));
835+
return None;
836836
}
837837

838-
let mut buf = String::with_capacity(trimmed.len());
839838
let mut last_tok = None;
840839
for c in trimmed.chars() {
841840
if !(c.is_ascii_alphanumeric() || c == '_' || c == '+' || c == '-') {
842-
return Err(invalid_literal(vm, literal, base));
841+
return None;
843842
}
844843

845844
if c == '_' && Some(c) == last_tok {
846-
return Err(invalid_literal(vm, literal, base));
845+
return None;
847846
}
848847

849848
last_tok = Some(c);
850-
buf.push(c);
851849
}
852850

853-
Ok(buf)
851+
Some(trimmed)
854852
}
855853

856854
fn detect_base(literal: &str) -> Option<u32> {
@@ -862,13 +860,6 @@ fn detect_base(literal: &str) -> Option<u32> {
862860
}
863861
}
864862

865-
fn invalid_literal(vm: &VirtualMachine, literal: &str, base: &BigInt) -> PyBaseExceptionRef {
866-
vm.new_value_error(format!(
867-
"invalid literal for int() with base {}: '{}'",
868-
base, literal
869-
))
870-
}
871-
872863
// Retrieve inner int value:
873864
pub fn get_value(obj: &PyObjectRef) -> &BigInt {
874865
&obj.payload::<PyInt>().unwrap().value

0 commit comments

Comments
 (0)