Skip to content

Commit d960ca3

Browse files
committed
fix int – float equality for big ints, inf and nan
1 parent a6d6f0f commit d960ca3

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

tests/snippets/int_float_equality.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# 10**308 cannot be represented exactly in f64, thus it is not equal to 1e308 float
2+
assert not (10**308 == 1e308)
3+
# but the 1e308 float can be converted to big int and then it still should be equal to itself
4+
assert int(1e308) == 1e308
5+
6+
# and the equalities should be the same when operands switch sides
7+
assert not (1e308 == 10**308)
8+
assert 1e308 == int(1e308)
9+
10+
# floats that cannot be converted to big ints shouldn’t crash the vm
11+
import math
12+
assert not (10**500 == math.inf)
13+
assert not (math.inf == 10**500)
14+
assert not (10**500 == math.nan)
15+
assert not (math.nan == 10**500)

vm/src/obj/objfloat.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use super::super::pyobject::{
44
use super::super::vm::VirtualMachine;
55
use super::objint;
66
use super::objtype;
7+
use num_bigint::ToBigInt;
78
use num_traits::ToPrimitive;
89

910
fn float_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -71,8 +72,13 @@ fn float_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
7172
let other = get_value(other);
7273
zelf == other
7374
} else if objtype::isinstance(other, &vm.ctx.int_type()) {
74-
let other = objint::get_value(other).to_f64().unwrap();
75-
zelf == other
75+
let other_int = objint::get_value(other);
76+
77+
if let (Some(zelf_int), Some(other_float)) = (zelf.to_bigint(), other_int.to_f64()) {
78+
zelf == other_float && zelf_int == other_int
79+
} else {
80+
false
81+
}
7682
} else {
7783
false
7884
};

vm/src/obj/objint.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,19 @@ fn int_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
110110
args,
111111
required = [(zelf, Some(vm.ctx.int_type())), (other, None)]
112112
);
113+
114+
let zelf = BigInt::from_pyobj(zelf);
113115
let result = if objtype::isinstance(other, &vm.ctx.int_type()) {
114-
let zelf = BigInt::from_pyobj(zelf);
115116
let other = BigInt::from_pyobj(other);
116117
zelf == other
117118
} else if objtype::isinstance(other, &vm.ctx.float_type()) {
118-
let zelf = BigInt::from_pyobj(zelf).to_f64().unwrap();
119-
let other = objfloat::get_value(other);
120-
zelf == other
119+
let other_float = objfloat::get_value(other);
120+
121+
if let (Some(zelf_float), Some(other_int)) = (zelf.to_f64(), other_float.to_bigint()) {
122+
zelf_float == other_float && zelf == other_int
123+
} else {
124+
false
125+
}
121126
} else {
122127
false
123128
};

0 commit comments

Comments
 (0)