Skip to content

Commit 89a97b6

Browse files
Merge pull request RustPython#1290 from mpajkowski/int_fix
Fix panics with int()
2 parents bd293fa + 6d618c5 commit 89a97b6

File tree

2 files changed

+132
-15
lines changed

2 files changed

+132
-15
lines changed

tests/snippets/ints.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,57 @@
9797
assert -10 // -4 == 2
9898

9999
assert int() == 0
100+
assert int(1) == 1
100101
assert int("101", 2) == 5
101102
assert int("101", base=2) == 5
102-
assert int(1) == 1
103+
104+
# implied base
105+
assert int('1', base=0) == 1
106+
assert int('123', base=0) == 123
107+
assert int('0b101', base=0) == 5
108+
assert int('0B101', base=0) == 5
109+
assert int('0o100', base=0) == 64
110+
assert int('0O100', base=0) == 64
111+
assert int('0xFF', base=0) == 255
112+
assert int('0XFF', base=0) == 255
113+
with assertRaises(ValueError):
114+
int('0xFF', base=10)
115+
with assertRaises(ValueError):
116+
int('0oFF', base=10)
117+
with assertRaises(ValueError):
118+
int('0bFF', base=10)
119+
with assertRaises(ValueError):
120+
int('0bFF', base=10)
121+
with assertRaises(ValueError):
122+
int(b"F\xc3\xb8\xc3\xb6\xbbB\xc3\xa5r")
123+
with assertRaises(ValueError):
124+
int(b"F\xc3\xb8\xc3\xb6\xbbB\xc3\xa5r")
125+
126+
# underscore
127+
assert int('0xFF_FF_FF', base=16) == 16_777_215
128+
with assertRaises(ValueError):
129+
int("_123_")
130+
with assertRaises(ValueError):
131+
int("123_")
132+
with assertRaises(ValueError):
133+
int("_123")
134+
with assertRaises(ValueError):
135+
int("1__23")
136+
137+
# signed
138+
assert int('-123') == -123
139+
assert int('+0b101', base=2) == +5
140+
141+
# trailing spaces
103142
assert int(' 1') == 1
104143
assert int('1 ') == 1
105144
assert int(' 1 ') == 1
106145
assert int('10', base=0) == 10
107146

147+
# type byte, signed, implied base
148+
assert int(b' -0XFF ', base=0) == -255
149+
150+
108151
assert int.from_bytes(b'\x00\x10', 'big') == 16
109152
assert int.from_bytes(b'\x00\x10', 'little') == 4096
110153
assert int.from_bytes(b'\xfc\x00', 'big', signed=True) == -1024
@@ -179,4 +222,4 @@ def __int__(self):
179222
assert_raises(TypeError, lambda: (0).__round__(None))
180223
assert_raises(TypeError, lambda: (1).__round__(None))
181224
assert_raises(TypeError, lambda: (0).__round__(0.0))
182-
assert_raises(TypeError, lambda: (1).__round__(0.0))
225+
assert_raises(TypeError, lambda: (1).__round__(0.0))

vm/src/obj/objint.rs

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use std::fmt;
2+
use std::str;
23

34
use num_bigint::{BigInt, Sign};
45
use num_integer::Integer;
5-
use num_traits::{One, Pow, Signed, ToPrimitive, Zero};
6+
use num_traits::{Num, One, Pow, Signed, ToPrimitive, Zero};
67

78
use crate::format::FormatSpec;
89
use crate::function::{KwArgs, OptionalArg, PyFuncArgs};
@@ -713,7 +714,9 @@ impl IntOptions {
713714
fn get_int_value(self, vm: &VirtualMachine) -> PyResult<BigInt> {
714715
if let OptionalArg::Present(val) = self.val_options {
715716
let base = if let OptionalArg::Present(base) = self.base {
716-
if !objtype::isinstance(&val, &vm.ctx.str_type()) {
717+
if !(objtype::isinstance(&val, &vm.ctx.str_type())
718+
|| objtype::isinstance(&val, &vm.ctx.bytes_type()))
719+
{
717720
return Err(vm.new_type_error(
718721
"int() can't convert non-string with explicit base".to_string(),
719722
));
@@ -736,21 +739,22 @@ fn int_new(cls: PyClassRef, options: IntOptions, vm: &VirtualMachine) -> PyResul
736739
}
737740

738741
// Casting function:
739-
pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, mut base: u32) -> PyResult<BigInt> {
740-
if base == 0 {
741-
base = 10
742-
} else if base < 2 || base > 36 {
742+
pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: u32) -> PyResult<BigInt> {
743+
if base != 0 && (base < 2 || base > 36) {
743744
return Err(vm.new_value_error("int() base must be >= 2 and <= 36, or 0".to_string()));
744745
}
745746

746747
match_class!(obj.clone(),
747-
s @ PyString => {
748-
i32::from_str_radix(s.as_str().trim(), base)
749-
.map(BigInt::from)
750-
.map_err(|_|vm.new_value_error(format!(
751-
"invalid literal for int() with base {}: '{}'",
752-
base, s
753-
)))
748+
string @ PyString => {
749+
let s = string.value.as_str().trim();
750+
str_to_int(vm, s, base)
751+
},
752+
bytes @ PyBytes => {
753+
let bytes = bytes.get_value();
754+
let s = std::str::from_utf8(bytes)
755+
.map(|s| s.trim())
756+
.map_err(|e| vm.new_value_error(format!("utf8 decode error: {}", e)))?;
757+
str_to_int(vm, s, base)
754758
},
755759
obj => {
756760
let method = vm.get_method_or_type_error(obj.clone(), "__int__", || {
@@ -766,6 +770,76 @@ pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, mut base: u32) -> PyResult
766770
)
767771
}
768772

773+
fn str_to_int(vm: &VirtualMachine, literal: &str, mut base: u32) -> PyResult<BigInt> {
774+
let mut buf = validate_literal(vm, literal, base)?;
775+
let is_signed = buf.starts_with('+') || buf.starts_with('-');
776+
let radix_range = if is_signed { 1..3 } else { 0..2 };
777+
let radix_candidate = buf.get(radix_range.clone());
778+
779+
// try to find base
780+
if let Some(radix_candidate) = radix_candidate {
781+
if let Some(matched_radix) = detect_base(&radix_candidate) {
782+
if base != 0 && base != matched_radix {
783+
return Err(invalid_literal(vm, literal, base));
784+
} else {
785+
base = matched_radix;
786+
}
787+
788+
buf.drain(radix_range);
789+
}
790+
}
791+
792+
// base still not found, try to use default
793+
if base == 0 {
794+
if buf.starts_with('0') {
795+
return Err(invalid_literal(vm, literal, base));
796+
}
797+
798+
base = 10;
799+
}
800+
801+
BigInt::from_str_radix(&buf, base).map_err(|_err| invalid_literal(vm, literal, base))
802+
}
803+
804+
fn validate_literal(vm: &VirtualMachine, literal: &str, base: u32) -> PyResult<String> {
805+
if literal.starts_with('_') || literal.ends_with('_') {
806+
return Err(invalid_literal(vm, literal, base));
807+
}
808+
809+
let mut buf = String::with_capacity(literal.len());
810+
let mut last_tok = None;
811+
for c in literal.chars() {
812+
if !(c.is_ascii_alphanumeric() || c == '_' || c == '+' || c == '-') {
813+
return Err(invalid_literal(vm, literal, base));
814+
}
815+
816+
if c == '_' && Some(c) == last_tok {
817+
return Err(invalid_literal(vm, literal, base));
818+
}
819+
820+
last_tok = Some(c);
821+
buf.push(c);
822+
}
823+
824+
Ok(buf)
825+
}
826+
827+
fn detect_base(literal: &str) -> Option<u32> {
828+
match literal {
829+
"0x" | "0X" => Some(16),
830+
"0o" | "0O" => Some(8),
831+
"0b" | "0B" => Some(2),
832+
_ => None,
833+
}
834+
}
835+
836+
fn invalid_literal(vm: &VirtualMachine, literal: &str, base: u32) -> PyObjectRef {
837+
vm.new_value_error(format!(
838+
"invalid literal for int() with base {}: '{}'",
839+
base, literal
840+
))
841+
}
842+
769843
// Retrieve inner int value:
770844
pub fn get_value(obj: &PyObjectRef) -> &BigInt {
771845
&get_py_int(obj).value

0 commit comments

Comments
 (0)