@@ -177,14 +177,20 @@ fn inner_divmod(int1: &BigInt, int2: &BigInt, vm: &VirtualMachine) -> PyResult {
177
177
}
178
178
}
179
179
180
- fn inner_lshift ( int1 : & BigInt , int2 : & BigInt , vm : & VirtualMachine ) -> PyResult {
181
- let n_bits = get_shift_amount ( int2, vm) ?;
182
- Ok ( vm. ctx . new_int ( int1 << n_bits) )
183
- }
184
-
185
- fn inner_rshift ( int1 : & BigInt , int2 : & BigInt , vm : & VirtualMachine ) -> PyResult {
186
- let n_bits = get_shift_amount ( int2, vm) ?;
187
- Ok ( vm. ctx . new_int ( int1 >> n_bits) )
180
+ fn inner_shift < F > ( int1 : & BigInt , int2 : & BigInt , shift_op : F , vm : & VirtualMachine ) -> PyResult
181
+ where
182
+ F : Fn ( & BigInt , usize ) -> BigInt ,
183
+ {
184
+ if int2. is_negative ( ) {
185
+ Err ( vm. new_value_error ( "negative shift count" . to_owned ( ) ) )
186
+ } else if int1. is_zero ( ) {
187
+ Ok ( vm. ctx . new_int ( 0 ) )
188
+ } else {
189
+ let int2 = int2. to_usize ( ) . ok_or_else ( || {
190
+ vm. new_overflow_error ( "the number is too large to convert to int" . to_owned ( ) )
191
+ } ) ?;
192
+ Ok ( vm. ctx . new_int ( shift_op ( int1, int2) ) )
193
+ }
188
194
}
189
195
190
196
#[ inline]
@@ -379,22 +385,22 @@ impl PyInt {
379
385
380
386
#[ pymethod( name = "__lshift__" ) ]
381
387
fn lshift ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
382
- self . general_op ( other, |a, b| inner_lshift ( a, b, vm) , vm)
388
+ self . general_op ( other, |a, b| inner_shift ( a, b , |a , b| a << b, vm) , vm)
383
389
}
384
390
385
391
#[ pymethod( name = "__rlshift__" ) ]
386
392
fn rlshift ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
387
- self . general_op ( other, |a, b| inner_lshift ( b, a, vm) , vm)
393
+ self . general_op ( other, |a, b| inner_shift ( b, a, |a , b| a << b , vm) , vm)
388
394
}
389
395
390
396
#[ pymethod( name = "__rshift__" ) ]
391
397
fn rshift ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
392
- self . general_op ( other, |a, b| inner_rshift ( a, b, vm) , vm)
398
+ self . general_op ( other, |a, b| inner_shift ( a, b , |a , b| a >> b, vm) , vm)
393
399
}
394
400
395
401
#[ pymethod( name = "__rrshift__" ) ]
396
402
fn rrshift ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
397
- self . general_op ( other, |a, b| inner_rshift ( b, a, vm) , vm)
403
+ self . general_op ( other, |a, b| inner_shift ( b, a, |a , b| a >> b , vm) , vm)
398
404
}
399
405
400
406
#[ pymethod( name = "__xor__" ) ]
@@ -926,20 +932,6 @@ pub fn try_float(int: &BigInt, vm: &VirtualMachine) -> PyResult<f64> {
926
932
. ok_or_else ( || vm. new_overflow_error ( "int too large to convert to float" . to_owned ( ) ) )
927
933
}
928
934
929
- fn get_shift_amount ( amount : & BigInt , vm : & VirtualMachine ) -> PyResult < usize > {
930
- if let Some ( n_bits) = amount. to_usize ( ) {
931
- Ok ( n_bits)
932
- } else {
933
- match amount {
934
- v if * v < BigInt :: zero ( ) => Err ( vm. new_value_error ( "negative shift count" . to_owned ( ) ) ) ,
935
- v if * v > BigInt :: from ( usize:: max_value ( ) ) => {
936
- Err ( vm. new_overflow_error ( "the number is too large to convert to int" . to_owned ( ) ) )
937
- }
938
- _ => panic ! ( "Failed converting {} to rust usize" , amount) ,
939
- }
940
- }
941
- }
942
-
943
935
pub ( crate ) fn init ( context : & PyContext ) {
944
936
PyInt :: extend_class ( context, & context. types . int_type ) ;
945
937
}
0 commit comments