Skip to content

Commit 9feeee0

Browse files
authored
Merge pull request RustPython#886 from youknowone/pycomplex
Add complex.__bool__, complex.__sub__, OverflowError/float support of complex.__add__
2 parents 53dea48 + dc05459 commit 9feeee0

File tree

2 files changed

+109
-36
lines changed

2 files changed

+109
-36
lines changed

tests/snippets/builtin_complex.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from testutils import assertRaises
2+
13
# __abs__
24

35
assert abs(complex(3, 4)) == 5
@@ -27,6 +29,12 @@
2729
assert -complex(1, -1) == complex(-1, 1)
2830
assert -complex(0, 0) == complex(0, 0)
2931

32+
# __bool__
33+
34+
assert bool(complex(0, 0)) is False
35+
assert bool(complex(0, 1)) is True
36+
assert bool(complex(1, 0)) is True
37+
3038
# real
3139

3240
a = complex(3, 4)
@@ -44,3 +52,32 @@
4452
assert 1j + 1 == complex(1, 1)
4553
assert (1j + 1) + 3 == complex(4, 1)
4654
assert 3 + (1j + 1) == complex(4, 1)
55+
56+
# float and complex addition
57+
assert 1.1 + 1.2j == complex(1.1, 1.2)
58+
assert 1.3j + 1.4 == complex(1.4, 1.3)
59+
assert (1.5j + 1.6) + 3 == complex(4.6, 1.5)
60+
assert 3.5 + (1.1j + 1.2) == complex(4.7, 1.1)
61+
62+
# subtraction
63+
assert 1 - 1j == complex(1, -1)
64+
assert 1j - 1 == complex(-1, 1)
65+
assert 2j - 1j == complex(0, 1)
66+
67+
# type error addition
68+
with assertRaises(TypeError):
69+
assert 1j + 'str'
70+
with assertRaises(TypeError):
71+
assert 1j - 'str'
72+
with assertRaises(TypeError):
73+
assert 'str' + 1j
74+
with assertRaises(TypeError):
75+
assert 'str' - 1j
76+
77+
# overflow
78+
with assertRaises(OverflowError):
79+
complex(10 ** 1000, 0)
80+
with assertRaises(OverflowError):
81+
complex(0, 10 ** 1000)
82+
with assertRaises(OverflowError):
83+
complex(0, 0) + 10 ** 1000

vm/src/obj/objcomplex.rs

Lines changed: 72 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
use num_complex::Complex64;
2-
use num_traits::ToPrimitive;
2+
use num_traits::{ToPrimitive, Zero};
33

44
use crate::function::OptionalArg;
5-
use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue};
5+
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
66
use crate::vm::VirtualMachine;
77

88
use super::objfloat::{self, PyFloat};
99
use super::objint;
1010
use super::objtype::{self, PyClassRef};
1111

12+
#[pyclass(name = "complex")]
1213
#[derive(Debug, Copy, Clone, PartialEq)]
1314
pub struct PyComplex {
1415
value: Complex64,
@@ -28,24 +29,14 @@ impl From<Complex64> for PyComplex {
2829
}
2930

3031
pub fn init(context: &PyContext) {
31-
let complex_type = &context.complex_type;
32-
32+
PyComplex::extend_class(context, &context.complex_type);
3333
let complex_doc =
3434
"Create a complex number from a real part and an optional imaginary part.\n\n\
3535
This is equivalent to (real + imag*1j) where imag defaults to 0.";
3636

37-
extend_class!(context, complex_type, {
37+
extend_class!(context, &context.complex_type, {
3838
"__doc__" => context.new_str(complex_doc.to_string()),
39-
"__abs__" => context.new_rustfunc(PyComplexRef::abs),
40-
"__add__" => context.new_rustfunc(PyComplexRef::add),
41-
"__eq__" => context.new_rustfunc(PyComplexRef::eq),
42-
"__neg__" => context.new_rustfunc(PyComplexRef::neg),
4339
"__new__" => context.new_rustfunc(PyComplexRef::new),
44-
"__radd__" => context.new_rustfunc(PyComplexRef::radd),
45-
"__repr__" => context.new_rustfunc(PyComplexRef::repr),
46-
"conjugate" => context.new_rustfunc(PyComplexRef::conjugate),
47-
"imag" => context.new_property(PyComplexRef::imag),
48-
"real" => context.new_property(PyComplexRef::real)
4940
});
5041
}
5142

@@ -73,49 +64,87 @@ impl PyComplexRef {
7364
let value = Complex64::new(real, imag);
7465
PyComplex { value }.into_ref_with_type(vm, cls)
7566
}
67+
}
7668

77-
fn real(self, _vm: &VirtualMachine) -> PyFloat {
69+
fn to_complex(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<Option<Complex64>> {
70+
if objtype::isinstance(&value, &vm.ctx.int_type()) {
71+
match objint::get_value(&value).to_f64() {
72+
Some(v) => Ok(Some(Complex64::new(v, 0.0))),
73+
None => Err(vm.new_overflow_error("int too large to convert to float".to_string())),
74+
}
75+
} else if objtype::isinstance(&value, &vm.ctx.float_type()) {
76+
let v = objfloat::get_value(&value);
77+
Ok(Some(Complex64::new(v, 0.0)))
78+
} else {
79+
Ok(None)
80+
}
81+
}
82+
83+
#[pyimpl]
84+
impl PyComplex {
85+
#[pyproperty(name = "real")]
86+
fn real(&self, _vm: &VirtualMachine) -> PyFloat {
7887
self.value.re.into()
7988
}
8089

81-
fn imag(self, _vm: &VirtualMachine) -> PyFloat {
90+
#[pyproperty(name = "imag")]
91+
fn imag(&self, _vm: &VirtualMachine) -> PyFloat {
8292
self.value.im.into()
8393
}
8494

85-
fn abs(self, _vm: &VirtualMachine) -> PyFloat {
95+
#[pymethod(name = "__abs__")]
96+
fn abs(&self, _vm: &VirtualMachine) -> PyFloat {
8697
let Complex64 { im, re } = self.value;
8798
re.hypot(im).into()
8899
}
89100

90-
fn add(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
101+
#[pymethod(name = "__add__")]
102+
fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
91103
if objtype::isinstance(&other, &vm.ctx.complex_type()) {
92-
vm.ctx.new_complex(self.value + get_value(&other))
93-
} else if objtype::isinstance(&other, &vm.ctx.int_type()) {
94-
vm.ctx.new_complex(Complex64::new(
95-
self.value.re + objint::get_value(&other).to_f64().unwrap(),
96-
self.value.im,
97-
))
104+
Ok(vm.ctx.new_complex(self.value + get_value(&other)))
98105
} else {
99-
vm.ctx.not_implemented()
106+
self.radd(other, vm)
100107
}
101108
}
102109

103-
fn radd(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
104-
if objtype::isinstance(&other, &vm.ctx.int_type()) {
105-
vm.ctx.new_complex(Complex64::new(
106-
self.value.re + objint::get_value(&other).to_f64().unwrap(),
107-
self.value.im,
108-
))
110+
#[pymethod(name = "__radd__")]
111+
fn radd(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
112+
match to_complex(other, vm) {
113+
Ok(Some(other)) => Ok(vm.ctx.new_complex(self.value + other)),
114+
Ok(None) => Ok(vm.ctx.not_implemented()),
115+
Err(err) => Err(err),
116+
}
117+
}
118+
119+
#[pymethod(name = "__sub__")]
120+
fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
121+
if objtype::isinstance(&other, &vm.ctx.complex_type()) {
122+
Ok(vm.ctx.new_complex(self.value - get_value(&other)))
109123
} else {
110-
vm.ctx.not_implemented()
124+
match to_complex(other, vm) {
125+
Ok(Some(other)) => Ok(vm.ctx.new_complex(self.value - other)),
126+
Ok(None) => Ok(vm.ctx.not_implemented()),
127+
Err(err) => Err(err),
128+
}
129+
}
130+
}
131+
132+
#[pymethod(name = "__rsub__")]
133+
fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
134+
match to_complex(other, vm) {
135+
Ok(Some(other)) => Ok(vm.ctx.new_complex(other - self.value)),
136+
Ok(None) => Ok(vm.ctx.not_implemented()),
137+
Err(err) => Err(err),
111138
}
112139
}
113140

114-
fn conjugate(self, _vm: &VirtualMachine) -> PyComplex {
141+
#[pymethod(name = "conjugate")]
142+
fn conjugate(&self, _vm: &VirtualMachine) -> PyComplex {
115143
self.value.conj().into()
116144
}
117145

118-
fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
146+
#[pymethod(name = "__eq__")]
147+
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
119148
let result = if objtype::isinstance(&other, &vm.ctx.complex_type()) {
120149
self.value == get_value(&other)
121150
} else if objtype::isinstance(&other, &vm.ctx.int_type()) {
@@ -132,16 +161,23 @@ impl PyComplexRef {
132161
vm.ctx.new_bool(result)
133162
}
134163

135-
fn neg(self, _vm: &VirtualMachine) -> PyComplex {
164+
#[pymethod(name = "__neg__")]
165+
fn neg(&self, _vm: &VirtualMachine) -> PyComplex {
136166
PyComplex::from(-self.value)
137167
}
138168

139-
fn repr(self, _vm: &VirtualMachine) -> String {
169+
#[pymethod(name = "__repr__")]
170+
fn repr(&self, _vm: &VirtualMachine) -> String {
140171
let Complex64 { re, im } = self.value;
141172
if re == 0.0 {
142173
format!("{}j", im)
143174
} else {
144175
format!("({}+{}j)", re, im)
145176
}
146177
}
178+
179+
#[pymethod(name = "__bool__")]
180+
fn bool(&self, _vm: &VirtualMachine) -> bool {
181+
self.value != Complex64::zero()
182+
}
147183
}

0 commit comments

Comments
 (0)