Skip to content

Commit 1fff67d

Browse files
committed
Correct float rounding algorithm
1 parent 5063627 commit 1fff67d

File tree

3 files changed

+71
-44
lines changed

3 files changed

+71
-44
lines changed

Lib/test/test_float.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -832,8 +832,6 @@ def test_inf_nan(self):
832832
self.assertRaises(TypeError, round, NAN, "ceci n'est pas un integer")
833833
self.assertRaises(TypeError, round, -0.0, 1j)
834834

835-
# TODO: RUSTPYTHON
836-
@unittest.expectedFailure
837835
def test_large_n(self):
838836
for n in [324, 325, 400, 2**31-1, 2**31, 2**32, 2**100]:
839837
self.assertEqual(round(123.456, n), 123.456)
@@ -846,17 +844,13 @@ def test_large_n(self):
846844
self.assertEqual(round(1e150, 309), 1e150)
847845
self.assertEqual(round(1.4e-315, 315), 1e-315)
848846

849-
# TODO: RUSTPYTHON
850-
@unittest.expectedFailure
851847
def test_small_n(self):
852848
for n in [-308, -309, -400, 1-2**31, -2**31, -2**31-1, -2**100]:
853849
self.assertEqual(round(123.456, n), 0.0)
854850
self.assertEqual(round(-123.456, n), -0.0)
855851
self.assertEqual(round(1e300, n), 0.0)
856852
self.assertEqual(round(1e-320, n), 0.0)
857853

858-
# TODO: RUSTPYTHON
859-
@unittest.expectedFailure
860854
def test_overflow(self):
861855
self.assertRaises(OverflowError, round, 1.6e308, -308)
862856
self.assertRaises(OverflowError, round, -1.7e308, -308)

common/src/float_ops.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use num_bigint::{BigInt, ToBigInt};
22
use num_traits::{Float, Signed, ToPrimitive, Zero};
3+
use std::f64;
34

45
pub fn ufrexp(value: f64) -> (f64, i32) {
56
if 0.0 == value {
@@ -366,6 +367,69 @@ pub fn ulp(x: f64) -> f64 {
366367
}
367368
}
368369

370+
pub fn round_float_digits(x: f64, ndigits: i32) -> Option<f64> {
371+
let float = if ndigits.is_zero() {
372+
let fract = x.fract();
373+
if (fract.abs() - 0.5).abs() < f64::EPSILON {
374+
if x.trunc() % 2.0 == 0.0 {
375+
x - fract
376+
} else {
377+
x + fract
378+
}
379+
} else {
380+
x.round()
381+
}
382+
} else {
383+
const NDIGITS_MAX: i32 =
384+
((f64::MANTISSA_DIGITS as i32 - f64::MIN_EXP) as f64 * f64::consts::LOG10_2) as i32;
385+
const NDIGITS_MIN: i32 = -(((f64::MAX_EXP + 1) as f64 * f64::consts::LOG10_2) as i32);
386+
if ndigits > NDIGITS_MAX {
387+
x
388+
} else if ndigits < NDIGITS_MIN {
389+
0.0f64.copysign(x)
390+
} else {
391+
let (y, pow1, pow2) = if ndigits >= 0 {
392+
// according to cpython: pow1 and pow2 are each safe from overflow, but
393+
// pow1*pow2 ~= pow(10.0, ndigits) might overflow
394+
let (pow1, pow2) = if ndigits > 22 {
395+
(10.0.powf((ndigits - 22) as f64), 1e22)
396+
} else {
397+
(10.0.powf(ndigits as f64), 1.0)
398+
};
399+
let y = (x * pow1) * pow2;
400+
if !y.is_finite() {
401+
return Some(x);
402+
}
403+
(y, pow1, Some(pow2))
404+
} else {
405+
let pow1 = 10.0.powf((-ndigits) as f64);
406+
(x / pow1, pow1, None)
407+
};
408+
let z = y.round();
409+
#[allow(clippy::float_cmp)]
410+
let z = if (y - z).abs() == 0.5 {
411+
2.0 * (y / 2.0).round()
412+
} else {
413+
z
414+
};
415+
let z = if let Some(pow2) = pow2 {
416+
// ndigits >= 0
417+
(z / pow2) / pow1
418+
} else {
419+
z * pow1
420+
};
421+
422+
if !z.is_finite() {
423+
// overflow
424+
return None;
425+
}
426+
427+
z
428+
}
429+
};
430+
Some(float)
431+
}
432+
369433
#[test]
370434
fn test_to_hex() {
371435
use rand::Rng;

vm/src/builtins/float.rs

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
use num_bigint::{BigInt, ToBigInt};
22
use num_complex::Complex64;
33
use num_rational::Ratio;
4-
use num_traits::{pow, Signed, ToPrimitive, Zero};
5-
use std::f64;
4+
use num_traits::{Signed, ToPrimitive, Zero};
65

76
use super::bytes::PyBytes;
87
use super::int::{self, PyInt, PyIntRef};
@@ -395,43 +394,13 @@ impl PyFloat {
395394
let ndigits = ndigits.flatten();
396395
let value = if let Some(ndigits) = ndigits {
397396
let ndigits = ndigits.borrow_value();
398-
let float = if ndigits.is_zero() {
399-
let fract = self.value.fract();
400-
if (fract.abs() - 0.5).abs() < f64::EPSILON {
401-
if self.value.trunc() % 2.0 == 0.0 {
402-
self.value - fract
403-
} else {
404-
self.value + fract
405-
}
406-
} else {
407-
self.value.round()
408-
}
409-
} else {
410-
let ndigits = match ndigits.to_isize() {
411-
Some(n) => n,
412-
None if ndigits.is_positive() => isize::MAX,
413-
None => isize::MIN,
414-
};
415-
const NDIGITS_MAX: isize = ((f64::MANTISSA_DIGITS as i32 - f64::MIN_EXP) as f64
416-
* f64::consts::LOG10_2) as isize;
417-
const NDIGITS_MIN: isize =
418-
-(((f64::MAX_EXP + 1) as f64 * f64::consts::LOG10_2) as isize);
419-
if ndigits > NDIGITS_MAX {
420-
self.value
421-
} else if ndigits > NDIGITS_MIN {
422-
0.0f64.copysign(self.value)
423-
} else if ndigits >= 0 {
424-
(self.value * pow(10.0, ndigits as usize)).round() / pow(10.0, ndigits as usize)
425-
} else {
426-
let result = (self.value / pow(10.0, (-ndigits) as usize)).round()
427-
* pow(10.0, (-ndigits) as usize);
428-
if result.is_nan() {
429-
0.0
430-
} else {
431-
result
432-
}
433-
}
397+
let ndigits = match ndigits.to_i32() {
398+
Some(n) => n,
399+
None if ndigits.is_positive() => i32::MAX,
400+
None => i32::MIN,
434401
};
402+
let float = float_ops::round_float_digits(self.value, ndigits)
403+
.ok_or_else(|| vm.new_overflow_error("overflow ocurred during round".to_owned()))?;
435404
vm.ctx.new_float(float)
436405
} else {
437406
let fract = self.value.fract();

0 commit comments

Comments
 (0)