Skip to content

Commit 813307f

Browse files
committed
Refactor PyFloat using try_float()
1 parent a0ad436 commit 813307f

File tree

2 files changed

+89
-117
lines changed

2 files changed

+89
-117
lines changed

tests/snippets/floats.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
b = 1.3
99
c = 1.2
1010
z = 2
11+
ov = 10 ** 1000
1112

1213
assert -a == -1.2
1314

@@ -37,6 +38,20 @@
3738
assert 6 / a == 5.0
3839
assert 2.0 % z == 0.0
3940
assert z % 2.0 == 0.0
41+
assert_raises(OverflowError, lambda: a + ov)
42+
assert_raises(OverflowError, lambda: a - ov)
43+
assert_raises(OverflowError, lambda: a * ov)
44+
assert_raises(OverflowError, lambda: a / ov)
45+
assert_raises(OverflowError, lambda: a // ov)
46+
assert_raises(OverflowError, lambda: a % ov)
47+
assert_raises(OverflowError, lambda: a ** ov)
48+
assert_raises(OverflowError, lambda: ov + a)
49+
assert_raises(OverflowError, lambda: ov - a)
50+
assert_raises(OverflowError, lambda: ov * a)
51+
assert_raises(OverflowError, lambda: ov / a)
52+
assert_raises(OverflowError, lambda: ov // a)
53+
assert_raises(OverflowError, lambda: ov % a)
54+
# assert_raises(OverflowError, lambda: ov ** a)
4055

4156
assert a < 5
4257
assert a <= 5

vm/src/obj/objfloat.rs

Lines changed: 74 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,30 @@ fn try_float(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult<Option<f64>>
5151
})
5252
}
5353

54-
fn mod_(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult {
54+
fn inner_div(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult<f64> {
5555
if v2 != 0.0 {
56-
Ok(vm.ctx.new_float(v1 % v2))
56+
Ok(v1 / v2)
57+
} else {
58+
Err(vm.new_zero_division_error("float division by zero".to_string()))
59+
}
60+
}
61+
62+
fn inner_mod(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult<f64> {
63+
if v2 != 0.0 {
64+
Ok(v1 % v2)
5765
} else {
5866
Err(vm.new_zero_division_error("float mod by zero".to_string()))
5967
}
6068
}
6169

70+
fn inner_floordiv(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult<f64> {
71+
if v2 != 0.0 {
72+
Ok((v1 / v2).floor())
73+
} else {
74+
Err(vm.new_zero_division_error("float floordiv by zero".to_string()))
75+
}
76+
}
77+
6278
#[pyimpl]
6379
impl PyFloat {
6480
#[pymethod(name = "__eq__")]
@@ -139,20 +155,15 @@ impl PyFloat {
139155
}
140156

141157
#[pymethod(name = "__add__")]
142-
fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
143-
let v1 = self.value;
144-
if objtype::isinstance(&other, &vm.ctx.float_type()) {
145-
vm.ctx.new_float(v1 + get_value(&other))
146-
} else if objtype::isinstance(&other, &vm.ctx.int_type()) {
147-
vm.ctx
148-
.new_float(v1 + objint::get_value(&other).to_f64().unwrap())
149-
} else {
150-
vm.ctx.not_implemented()
151-
}
158+
fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
159+
try_float(&other, vm)?.map_or_else(
160+
|| Ok(vm.ctx.not_implemented()),
161+
|other| (self.value + other).into_pyobject(vm),
162+
)
152163
}
153164

154165
#[pymethod(name = "__radd__")]
155-
fn radd(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
166+
fn radd(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
156167
self.add(other, vm)
157168
}
158169

@@ -163,45 +174,39 @@ impl PyFloat {
163174

164175
#[pymethod(name = "__divmod__")]
165176
fn divmod(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
166-
if objtype::isinstance(&other, &vm.ctx.float_type())
167-
|| objtype::isinstance(&other, &vm.ctx.int_type())
168-
{
169-
let r1 = self.floordiv(other.clone(), vm)?;
170-
let r2 = self.mod_(other, vm)?;
171-
Ok(vm.ctx.new_tuple(vec![r1, r2]))
172-
} else {
173-
Ok(vm.ctx.not_implemented())
174-
}
177+
try_float(&other, vm)?.map_or_else(
178+
|| Ok(vm.ctx.not_implemented()),
179+
|other| {
180+
let r1 = inner_floordiv(self.value, other, vm)?;
181+
let r2 = inner_mod(self.value, other, vm)?;
182+
Ok(vm
183+
.ctx
184+
.new_tuple(vec![vm.ctx.new_float(r1), vm.ctx.new_float(r2)]))
185+
},
186+
)
175187
}
176188

177189
#[pymethod(name = "__floordiv__")]
178190
fn floordiv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
179-
let v1 = self.value;
180-
let v2 = if objtype::isinstance(&other, &vm.ctx.float_type) {
181-
get_value(&other)
182-
} else if objtype::isinstance(&other, &vm.ctx.int_type) {
183-
objint::get_float_value(&other, vm)?
184-
} else {
185-
return Ok(vm.ctx.not_implemented());
186-
};
191+
try_float(&other, vm)?.map_or_else(
192+
|| Ok(vm.ctx.not_implemented()),
193+
|other| inner_floordiv(self.value, other, vm)?.into_pyobject(vm),
194+
)
195+
}
187196

188-
if v2 != 0.0 {
189-
Ok(vm.ctx.new_float((v1 / v2).floor()))
190-
} else {
191-
Err(vm.new_zero_division_error("float floordiv by zero".to_string()))
192-
}
197+
#[pymethod(name = "__rfloordiv__")]
198+
fn rfloordiv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
199+
try_float(&other, vm)?.map_or_else(
200+
|| Ok(vm.ctx.not_implemented()),
201+
|other| inner_floordiv(other, self.value, vm)?.into_pyobject(vm),
202+
)
193203
}
194204

195205
fn new_float(cls: PyClassRef, arg: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyFloatRef> {
196206
let value = if objtype::isinstance(&arg, &vm.ctx.float_type()) {
197207
get_value(&arg)
198208
} else if objtype::isinstance(&arg, &vm.ctx.int_type()) {
199-
match objint::get_float_value(&arg, vm) {
200-
Ok(f) => f,
201-
Err(e) => {
202-
return Err(e);
203-
}
204-
}
209+
objint::get_float_value(&arg, vm)?
205210
} else if objtype::isinstance(&arg, &vm.ctx.str_type()) {
206211
match lexical::try_parse(objstr::get_value(&arg)) {
207212
Ok(f) => f,
@@ -232,28 +237,18 @@ impl PyFloat {
232237

233238
#[pymethod(name = "__mod__")]
234239
fn mod_(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
235-
let v1 = self.value;
236-
let v2 = if objtype::isinstance(&other, &vm.ctx.float_type) {
237-
get_value(&other)
238-
} else if objtype::isinstance(&other, &vm.ctx.int_type) {
239-
objint::get_float_value(&other, vm)?
240-
} else {
241-
return Ok(vm.ctx.not_implemented());
242-
};
243-
244-
mod_(v1, v2, vm)
240+
try_float(&other, vm)?.map_or_else(
241+
|| Ok(vm.ctx.not_implemented()),
242+
|other| inner_mod(self.value, other, vm)?.into_pyobject(vm),
243+
)
245244
}
246245

247246
#[pymethod(name = "__rmod__")]
248247
fn rmod(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
249-
let v2 = self.value;
250-
let v1 = if objtype::isinstance(&other, &vm.ctx.int_type) {
251-
objint::get_float_value(&other, vm)?
252-
} else {
253-
return Ok(vm.ctx.not_implemented());
254-
};
255-
256-
mod_(v1, v2, vm)
248+
try_float(&other, vm)?.map_or_else(
249+
|| Ok(vm.ctx.not_implemented()),
250+
|other| inner_mod(other, self.value, vm)?.into_pyobject(vm),
251+
)
257252
}
258253

259254
#[pymethod(name = "__neg__")]
@@ -271,30 +266,18 @@ impl PyFloat {
271266

272267
#[pymethod(name = "__sub__")]
273268
fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
274-
let v1 = self.value;
275-
if objtype::isinstance(&other, &vm.ctx.float_type()) {
276-
Ok(vm.ctx.new_float(v1 - get_value(&other)))
277-
} else if objtype::isinstance(&other, &vm.ctx.int_type()) {
278-
Ok(vm
279-
.ctx
280-
.new_float(v1 - objint::get_value(&other).to_f64().unwrap()))
281-
} else {
282-
Ok(vm.ctx.not_implemented())
283-
}
269+
try_float(&other, vm)?.map_or_else(
270+
|| Ok(vm.ctx.not_implemented()),
271+
|other| (self.value - other).into_pyobject(vm),
272+
)
284273
}
285274

286275
#[pymethod(name = "__rsub__")]
287276
fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
288-
let v1 = self.value;
289-
if objtype::isinstance(&other, &vm.ctx.float_type()) {
290-
Ok(vm.ctx.new_float(get_value(&other) - v1))
291-
} else if objtype::isinstance(&other, &vm.ctx.int_type()) {
292-
Ok(vm
293-
.ctx
294-
.new_float(objint::get_value(&other).to_f64().unwrap() - v1))
295-
} else {
296-
Ok(vm.ctx.not_implemented())
297-
}
277+
try_float(&other, vm)?.map_or_else(
278+
|| Ok(vm.ctx.not_implemented()),
279+
|other| (other - self.value).into_pyobject(vm),
280+
)
298281
}
299282

300283
#[pymethod(name = "__repr__")]
@@ -304,52 +287,26 @@ impl PyFloat {
304287

305288
#[pymethod(name = "__truediv__")]
306289
fn truediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
307-
let v1 = self.value;
308-
let v2 = if objtype::isinstance(&other, &vm.ctx.float_type) {
309-
get_value(&other)
310-
} else if objtype::isinstance(&other, &vm.ctx.int_type) {
311-
objint::get_float_value(&other, vm)?
312-
} else {
313-
return Ok(vm.ctx.not_implemented());
314-
};
315-
316-
if v2 != 0.0 {
317-
Ok(vm.ctx.new_float(v1 / v2))
318-
} else {
319-
Err(vm.new_zero_division_error("float division by zero".to_string()))
320-
}
290+
try_float(&other, vm)?.map_or_else(
291+
|| Ok(vm.ctx.not_implemented()),
292+
|other| inner_div(self.value, other, vm)?.into_pyobject(vm),
293+
)
321294
}
322295

323296
#[pymethod(name = "__rtruediv__")]
324297
fn rtruediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
325-
let v1 = self.value;
326-
let v2 = if objtype::isinstance(&other, &vm.ctx.float_type) {
327-
get_value(&other)
328-
} else if objtype::isinstance(&other, &vm.ctx.int_type) {
329-
objint::get_float_value(&other, vm)?
330-
} else {
331-
return Ok(vm.ctx.not_implemented());
332-
};
333-
334-
if v1 != 0.0 {
335-
Ok(vm.ctx.new_float(v2 / v1))
336-
} else {
337-
Err(vm.new_zero_division_error("float division by zero".to_string()))
338-
}
298+
try_float(&other, vm)?.map_or_else(
299+
|| Ok(vm.ctx.not_implemented()),
300+
|other| inner_div(other, self.value, vm)?.into_pyobject(vm),
301+
)
339302
}
340303

341304
#[pymethod(name = "__mul__")]
342305
fn mul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
343-
let v1 = self.value;
344-
if objtype::isinstance(&other, &vm.ctx.float_type) {
345-
Ok(vm.ctx.new_float(v1 * get_value(&other)))
346-
} else if objtype::isinstance(&other, &vm.ctx.int_type) {
347-
Ok(vm
348-
.ctx
349-
.new_float(v1 * objint::get_value(&other).to_f64().unwrap()))
350-
} else {
351-
Ok(vm.ctx.not_implemented())
352-
}
306+
try_float(&other, vm)?.map_or_else(
307+
|| Ok(vm.ctx.not_implemented()),
308+
|other| (self.value * other).into_pyobject(vm),
309+
)
353310
}
354311

355312
#[pymethod(name = "__rmul__")]

0 commit comments

Comments
 (0)