Skip to content

Commit 6afe78c

Browse files
committed
Add math.ceil / math.floor
1 parent c3714c2 commit 6afe78c

File tree

3 files changed

+111
-38
lines changed

3 files changed

+111
-38
lines changed

tests/snippets/math_basics.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import math
21
from testutils import assertRaises
32

43
a = 4
@@ -19,34 +18,3 @@
1918
assert a - 3 == 1
2019
assert -a == -4
2120
assert +a == 4
22-
23-
# assert(math.exp(2) == math.exp(2.0))
24-
# assert(math.exp(True) == math.exp(1.0))
25-
#
26-
# class Conversible():
27-
# def __float__(self):
28-
# print("Converting to float now!")
29-
# return 1.1111
30-
#
31-
# assert math.log(1.1111) == math.log(Conversible())
32-
33-
# roundings
34-
assert math.trunc(1) == 1
35-
36-
class A(object):
37-
def __trunc__(self):
38-
return 2
39-
40-
assert math.trunc(A()) == 2
41-
42-
class A(object):
43-
def __trunc__(self):
44-
return 2.0
45-
46-
assert math.trunc(A()) == 2.0
47-
48-
class A(object):
49-
def __trunc__(self):
50-
return 'str'
51-
52-
assert math.trunc(A()) == 'str'

tests/snippets/math_module.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import math
2+
from testutils import assertRaises
3+
4+
# assert(math.exp(2) == math.exp(2.0))
5+
# assert(math.exp(True) == math.exp(1.0))
6+
#
7+
# class Conversible():
8+
# def __float__(self):
9+
# print("Converting to float now!")
10+
# return 1.1111
11+
#
12+
# assert math.log(1.1111) == math.log(Conversible())
13+
14+
# roundings
15+
assert int.__trunc__
16+
assert int.__floor__
17+
assert int.__ceil__
18+
19+
# assert float.__trunc__
20+
with assertRaises(AttributeError):
21+
assert float.__floor__
22+
with assertRaises(AttributeError):
23+
assert float.__ceil__
24+
25+
assert math.trunc(2) == 2
26+
assert math.ceil(3) == 3
27+
assert math.floor(4) == 4
28+
29+
assert math.trunc(2.2) == 2
30+
assert math.ceil(3.3) == 4
31+
assert math.floor(4.4) == 4
32+
33+
class A(object):
34+
def __trunc__(self):
35+
return 2
36+
37+
def __ceil__(self):
38+
return 3
39+
40+
def __floor__(self):
41+
return 4
42+
43+
assert math.trunc(A()) == 2
44+
assert math.ceil(A()) == 3
45+
assert math.floor(A()) == 4
46+
47+
class A(object):
48+
def __trunc__(self):
49+
return 2.2
50+
51+
def __ceil__(self):
52+
return 3.3
53+
54+
def __floor__(self):
55+
return 4.4
56+
57+
assert math.trunc(A()) == 2.2
58+
assert math.ceil(A()) == 3.3
59+
assert math.floor(A()) == 4.4
60+
61+
class A(object):
62+
def __trunc__(self):
63+
return 'trunc'
64+
65+
def __ceil__(self):
66+
return 'ceil'
67+
68+
def __floor__(self):
69+
return 'floor'
70+
71+
assert math.trunc(A()) == 'trunc'
72+
assert math.ceil(A()) == 'ceil'
73+
assert math.floor(A()) == 'floor'
74+
75+
with assertRaises(TypeError):
76+
math.trunc(object())
77+
with assertRaises(TypeError):
78+
math.ceil(object())
79+
with assertRaises(TypeError):
80+
math.floor(object())

vm/src/stdlib/math.rs

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use statrs::function::erf::{erf, erfc};
77
use statrs::function::gamma::{gamma, ln_gamma};
88

99
use crate::function::PyFuncArgs;
10-
use crate::obj::objfloat;
10+
use crate::obj::{objfloat, objtype};
1111
use crate::pyobject::{PyObjectRef, PyResult, TypeProtocol};
1212
use crate::vm::VirtualMachine;
1313

@@ -172,20 +172,43 @@ fn math_lgamma(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
172172
}
173173
}
174174

175-
fn math_trunc(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
176-
arg_check!(vm, args, required = [(value, None)]);
177-
const MAGIC_NAME: &str = "__trunc__";
178-
if let Ok(method) = vm.get_method(value.clone(), MAGIC_NAME) {
175+
fn try_magic_method(func_name: &str, vm: &VirtualMachine, value: &PyObjectRef) -> PyResult {
176+
if let Ok(method) = vm.get_method(value.clone(), func_name) {
179177
vm.invoke(method, vec![])
180178
} else {
181179
Err(vm.new_type_error(format!(
182180
"TypeError: type {} doesn't define {} method",
183181
value.class().name,
184-
MAGIC_NAME,
182+
func_name,
185183
)))
186184
}
187185
}
188186

187+
fn math_trunc(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
188+
arg_check!(vm, args, required = [(value, None)]);
189+
try_magic_method("__trunc__", vm, value)
190+
}
191+
192+
fn math_ceil(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
193+
arg_check!(vm, args, required = [(value, None)]);
194+
if objtype::isinstance(value, &vm.ctx.float_type) {
195+
let v = objfloat::get_value(value);
196+
Ok(vm.ctx.new_float(v.ceil()))
197+
} else {
198+
try_magic_method("__ceil__", vm, value)
199+
}
200+
}
201+
202+
fn math_floor(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
203+
arg_check!(vm, args, required = [(value, None)]);
204+
if objtype::isinstance(value, &vm.ctx.float_type) {
205+
let v = objfloat::get_value(value);
206+
Ok(vm.ctx.new_float(v.floor()))
207+
} else {
208+
try_magic_method("__floor__", vm, value)
209+
}
210+
}
211+
189212
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
190213
let ctx = &vm.ctx;
191214

@@ -235,6 +258,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
235258

236259
// Rounding functions:
237260
"trunc" => ctx.new_rustfunc(math_trunc),
261+
"ceil" => ctx.new_rustfunc(math_ceil),
262+
"floor" => ctx.new_rustfunc(math_floor),
238263

239264
// Constants:
240265
"pi" => ctx.new_float(std::f64::consts::PI), // 3.14159...

0 commit comments

Comments
 (0)