Skip to content

Commit 2c7f3b5

Browse files
Merge pull request RustPython#888 from youknowone/math-rounding
Add trunc, ceil, floor to math module
2 parents 9feeee0 + 6afe78c commit 2c7f3b5

File tree

3 files changed

+126
-13
lines changed

3 files changed

+126
-13
lines changed

tests/snippets/math_basics.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from testutils import assertRaises
2+
13
a = 4
24

35
#print(a ** 3)
@@ -16,14 +18,3 @@
1618
assert a - 3 == 1
1719
assert -a == -4
1820
assert +a == 4
19-
20-
# import math
21-
# assert(math.exp(2) == math.exp(2.0))
22-
# assert(math.exp(True) == math.exp(1.0))
23-
#
24-
# class Conversible():
25-
# def __float__(self):
26-
# print("Converting to float now!")
27-
# return 1.1111
28-
#
29-
# assert math.log(1.1111) == math.log(Conversible())

tests/snippets/math_module.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import math
2+
from testutils import assertRaises
3+
4+
# assert(math.exp(2) == math.exp(2.0))
5+
# assert(math.exp(True) == math.exp(1.0))
6+
#
7+
# class Conversible():
8+
# def __float__(self):
9+
# print("Converting to float now!")
10+
# return 1.1111
11+
#
12+
# assert math.log(1.1111) == math.log(Conversible())
13+
14+
# roundings
15+
assert int.__trunc__
16+
assert int.__floor__
17+
assert int.__ceil__
18+
19+
# assert float.__trunc__
20+
with assertRaises(AttributeError):
21+
assert float.__floor__
22+
with assertRaises(AttributeError):
23+
assert float.__ceil__
24+
25+
assert math.trunc(2) == 2
26+
assert math.ceil(3) == 3
27+
assert math.floor(4) == 4
28+
29+
assert math.trunc(2.2) == 2
30+
assert math.ceil(3.3) == 4
31+
assert math.floor(4.4) == 4
32+
33+
class A(object):
34+
def __trunc__(self):
35+
return 2
36+
37+
def __ceil__(self):
38+
return 3
39+
40+
def __floor__(self):
41+
return 4
42+
43+
assert math.trunc(A()) == 2
44+
assert math.ceil(A()) == 3
45+
assert math.floor(A()) == 4
46+
47+
class A(object):
48+
def __trunc__(self):
49+
return 2.2
50+
51+
def __ceil__(self):
52+
return 3.3
53+
54+
def __floor__(self):
55+
return 4.4
56+
57+
assert math.trunc(A()) == 2.2
58+
assert math.ceil(A()) == 3.3
59+
assert math.floor(A()) == 4.4
60+
61+
class A(object):
62+
def __trunc__(self):
63+
return 'trunc'
64+
65+
def __ceil__(self):
66+
return 'ceil'
67+
68+
def __floor__(self):
69+
return 'floor'
70+
71+
assert math.trunc(A()) == 'trunc'
72+
assert math.ceil(A()) == 'ceil'
73+
assert math.floor(A()) == 'floor'
74+
75+
with assertRaises(TypeError):
76+
math.trunc(object())
77+
with assertRaises(TypeError):
78+
math.ceil(object())
79+
with assertRaises(TypeError):
80+
math.floor(object())

vm/src/stdlib/math.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use statrs::function::erf::{erf, erfc};
77
use statrs::function::gamma::{gamma, ln_gamma};
88

99
use crate::function::PyFuncArgs;
10-
use crate::obj::objfloat;
11-
use crate::pyobject::{PyObjectRef, PyResult};
10+
use crate::obj::{objfloat, objtype};
11+
use crate::pyobject::{PyObjectRef, PyResult, TypeProtocol};
1212
use crate::vm::VirtualMachine;
1313

1414
// Helper macro:
@@ -172,6 +172,43 @@ fn math_lgamma(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
172172
}
173173
}
174174

175+
fn try_magic_method(func_name: &str, vm: &VirtualMachine, value: &PyObjectRef) -> PyResult {
176+
if let Ok(method) = vm.get_method(value.clone(), func_name) {
177+
vm.invoke(method, vec![])
178+
} else {
179+
Err(vm.new_type_error(format!(
180+
"TypeError: type {} doesn't define {} method",
181+
value.class().name,
182+
func_name,
183+
)))
184+
}
185+
}
186+
187+
fn math_trunc(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
188+
arg_check!(vm, args, required = [(value, None)]);
189+
try_magic_method("__trunc__", vm, value)
190+
}
191+
192+
fn math_ceil(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
193+
arg_check!(vm, args, required = [(value, None)]);
194+
if objtype::isinstance(value, &vm.ctx.float_type) {
195+
let v = objfloat::get_value(value);
196+
Ok(vm.ctx.new_float(v.ceil()))
197+
} else {
198+
try_magic_method("__ceil__", vm, value)
199+
}
200+
}
201+
202+
fn math_floor(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
203+
arg_check!(vm, args, required = [(value, None)]);
204+
if objtype::isinstance(value, &vm.ctx.float_type) {
205+
let v = objfloat::get_value(value);
206+
Ok(vm.ctx.new_float(v.floor()))
207+
} else {
208+
try_magic_method("__floor__", vm, value)
209+
}
210+
}
211+
175212
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
176213
let ctx = &vm.ctx;
177214

@@ -219,6 +256,11 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
219256
"gamma" => ctx.new_rustfunc(math_gamma),
220257
"lgamma" => ctx.new_rustfunc(math_lgamma),
221258

259+
// Rounding functions:
260+
"trunc" => ctx.new_rustfunc(math_trunc),
261+
"ceil" => ctx.new_rustfunc(math_ceil),
262+
"floor" => ctx.new_rustfunc(math_floor),
263+
222264
// Constants:
223265
"pi" => ctx.new_float(std::f64::consts::PI), // 3.14159...
224266
"e" => ctx.new_float(std::f64::consts::E), // 2.71..

0 commit comments

Comments
 (0)