Skip to content

Commit f682a94

Browse files
authored
Fix int shifts for zero
1 parent f4b3dc5 commit f682a94

File tree

2 files changed

+19
-28
lines changed

2 files changed

+19
-28
lines changed

Lib/test/test_long.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -933,7 +933,6 @@ def test_negative_shift_count(self):
933933
with self.assertRaises(ValueError):
934934
42 >> -(1 << 1000)
935935

936-
@unittest.expectedFailure # TODO: RUSTPYTHON
937936
def test_lshift_of_zero(self):
938937
self.assertEqual(0 << 0, 0)
939938
self.assertEqual(0 << 10, 0)
@@ -1398,4 +1397,4 @@ class myint(int):
13981397

13991398

14001399
if __name__ == "__main__":
1401-
unittest.main()
1400+
unittest.main()

vm/src/obj/objint.rs

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,20 @@ fn inner_divmod(int1: &BigInt, int2: &BigInt, vm: &VirtualMachine) -> PyResult {
177177
}
178178
}
179179

180-
fn inner_lshift(int1: &BigInt, int2: &BigInt, vm: &VirtualMachine) -> PyResult {
181-
let n_bits = get_shift_amount(int2, vm)?;
182-
Ok(vm.ctx.new_int(int1 << n_bits))
183-
}
184-
185-
fn inner_rshift(int1: &BigInt, int2: &BigInt, vm: &VirtualMachine) -> PyResult {
186-
let n_bits = get_shift_amount(int2, vm)?;
187-
Ok(vm.ctx.new_int(int1 >> n_bits))
180+
fn inner_shift<F>(int1: &BigInt, int2: &BigInt, shift_op: F, vm: &VirtualMachine) -> PyResult
181+
where
182+
F: Fn(&BigInt, usize) -> BigInt,
183+
{
184+
if int2.is_negative() {
185+
Err(vm.new_value_error("negative shift count".to_owned()))
186+
} else if int1.is_zero() {
187+
Ok(vm.ctx.new_int(0))
188+
} else {
189+
let int2 = int2.to_usize().ok_or_else(|| {
190+
vm.new_overflow_error("the number is too large to convert to int".to_owned())
191+
})?;
192+
Ok(vm.ctx.new_int(shift_op(int1, int2)))
193+
}
188194
}
189195

190196
#[inline]
@@ -379,22 +385,22 @@ impl PyInt {
379385

380386
#[pymethod(name = "__lshift__")]
381387
fn lshift(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
382-
self.general_op(other, |a, b| inner_lshift(a, b, vm), vm)
388+
self.general_op(other, |a, b| inner_shift(a, b, |a, b| a << b, vm), vm)
383389
}
384390

385391
#[pymethod(name = "__rlshift__")]
386392
fn rlshift(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
387-
self.general_op(other, |a, b| inner_lshift(b, a, vm), vm)
393+
self.general_op(other, |a, b| inner_shift(b, a, |a, b| a << b, vm), vm)
388394
}
389395

390396
#[pymethod(name = "__rshift__")]
391397
fn rshift(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
392-
self.general_op(other, |a, b| inner_rshift(a, b, vm), vm)
398+
self.general_op(other, |a, b| inner_shift(a, b, |a, b| a >> b, vm), vm)
393399
}
394400

395401
#[pymethod(name = "__rrshift__")]
396402
fn rrshift(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
397-
self.general_op(other, |a, b| inner_rshift(b, a, vm), vm)
403+
self.general_op(other, |a, b| inner_shift(b, a, |a, b| a >> b, vm), vm)
398404
}
399405

400406
#[pymethod(name = "__xor__")]
@@ -926,20 +932,6 @@ pub fn try_float(int: &BigInt, vm: &VirtualMachine) -> PyResult<f64> {
926932
.ok_or_else(|| vm.new_overflow_error("int too large to convert to float".to_owned()))
927933
}
928934

929-
fn get_shift_amount(amount: &BigInt, vm: &VirtualMachine) -> PyResult<usize> {
930-
if let Some(n_bits) = amount.to_usize() {
931-
Ok(n_bits)
932-
} else {
933-
match amount {
934-
v if *v < BigInt::zero() => Err(vm.new_value_error("negative shift count".to_owned())),
935-
v if *v > BigInt::from(usize::max_value()) => {
936-
Err(vm.new_overflow_error("the number is too large to convert to int".to_owned()))
937-
}
938-
_ => panic!("Failed converting {} to rust usize", amount),
939-
}
940-
}
941-
}
942-
943935
pub(crate) fn init(context: &PyContext) {
944936
PyInt::extend_class(context, &context.types.int_type);
945937
}

0 commit comments

Comments
 (0)