Skip to content

Commit f333c75

Browse files
committed
Added the ability to do addition between complex numbers and ints.
1 parent 38b4c10 commit f333c75

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

tests/snippets/builtin_complex.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,5 @@
4242
# int and complex addition
4343
assert 1 + 1j == complex(1, 1)
4444
assert 1j + 1 == complex(1, 1)
45+
assert (1j + 1) + 3 == complex(4, 1)
46+
assert 3 + (1j + 1) == complex(4, 1)

vm/src/obj/objcomplex.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ pub fn init(context: &PyContext) {
1717

1818
context.set_attr(&complex_type, "__abs__", context.new_rustfunc(complex_abs));
1919
context.set_attr(&complex_type, "__add__", context.new_rustfunc(complex_add));
20+
context.set_attr(
21+
&complex_type,
22+
"__radd__",
23+
context.new_rustfunc(complex_radd),
24+
);
2025
context.set_attr(&complex_type, "__eq__", context.new_rustfunc(complex_eq));
2126
context.set_attr(&complex_type, "__neg__", context.new_rustfunc(complex_neg));
2227
context.set_attr(&complex_type, "__new__", context.new_rustfunc(complex_new));
@@ -107,7 +112,29 @@ fn complex_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
107112
if objtype::isinstance(i2, &vm.ctx.complex_type()) {
108113
Ok(vm.ctx.new_complex(v1 + get_value(i2)))
109114
} else if objtype::isinstance(i2, &vm.ctx.int_type()) {
110-
Ok(vm.ctx.new_complex(Complex64::new(v1.re + objint::get_value(i2).to_f64().unwrap(), v1.im)))
115+
Ok(vm.ctx.new_complex(Complex64::new(
116+
v1.re + objint::get_value(i2).to_f64().unwrap(),
117+
v1.im,
118+
)))
119+
} else {
120+
Err(vm.new_type_error(format!("Cannot add {} and {}", i.borrow(), i2.borrow())))
121+
}
122+
}
123+
124+
fn complex_radd(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
125+
arg_check!(
126+
vm,
127+
args,
128+
required = [(i, Some(vm.ctx.complex_type())), (i2, None)]
129+
);
130+
131+
let v1 = get_value(i);
132+
133+
if objtype::isinstance(i2, &vm.ctx.int_type()) {
134+
Ok(vm.ctx.new_complex(Complex64::new(
135+
v1.re + objint::get_value(i2).to_f64().unwrap(),
136+
v1.im,
137+
)))
111138
} else {
112139
Err(vm.new_type_error(format!("Cannot add {} and {}", i.borrow(), i2.borrow())))
113140
}

vm/src/obj/objint.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use super::objfloat;
22
use super::objstr;
33
use super::objtype;
4-
use super::objcomplex;
54
use crate::format::FormatSpec;
65
use crate::pyobject::{
76
FromPyObjectRef, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult,
@@ -11,7 +10,6 @@ use crate::vm::VirtualMachine;
1110
use num_bigint::{BigInt, ToBigInt};
1211
use num_integer::Integer;
1312
use num_traits::{Pow, Signed, ToPrimitive, Zero};
14-
use num_complex::Complex64;
1513
use std::hash::{Hash, Hasher};
1614

1715
// This proxy allows for easy switching between types.
@@ -291,8 +289,6 @@ fn int_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
291289
);
292290
if objtype::isinstance(other, &vm.ctx.int_type()) {
293291
Ok(vm.ctx.new_int(get_value(zelf) + get_value(other)))
294-
} else if objtype::isinstance(other, &vm.ctx.complex_type()) {
295-
Ok(vm.ctx.new_complex(Complex64::new(get_value(zelf).to_f64().unwrap() + objcomplex::get_value(other).re, objcomplex::get_value(other).im)))
296292
} else {
297293
Ok(vm.ctx.not_implemented())
298294
}

0 commit comments

Comments
 (0)