1
1
use num_complex:: Complex64 ;
2
- use num_traits:: ToPrimitive ;
2
+ use num_traits:: { ToPrimitive , Zero } ;
3
3
4
4
use crate :: function:: OptionalArg ;
5
- use crate :: pyobject:: { PyContext , PyObjectRef , PyRef , PyResult , PyValue } ;
5
+ use crate :: pyobject:: { PyClassImpl , PyContext , PyObjectRef , PyRef , PyResult , PyValue } ;
6
6
use crate :: vm:: VirtualMachine ;
7
7
8
8
use super :: objfloat:: { self , PyFloat } ;
9
9
use super :: objint;
10
10
use super :: objtype:: { self , PyClassRef } ;
11
11
12
+ #[ pyclass( name = "complex" ) ]
12
13
#[ derive( Debug , Copy , Clone , PartialEq ) ]
13
14
pub struct PyComplex {
14
15
value : Complex64 ,
@@ -28,24 +29,14 @@ impl From<Complex64> for PyComplex {
28
29
}
29
30
30
31
pub fn init ( context : & PyContext ) {
31
- let complex_type = & context. complex_type ;
32
-
32
+ PyComplex :: extend_class ( context, & context. complex_type ) ;
33
33
let complex_doc =
34
34
"Create a complex number from a real part and an optional imaginary part.\n \n \
35
35
This is equivalent to (real + imag*1j) where imag defaults to 0.";
36
36
37
- extend_class ! ( context, complex_type, {
37
+ extend_class ! ( context, & context . complex_type, {
38
38
"__doc__" => context. new_str( complex_doc. to_string( ) ) ,
39
- "__abs__" => context. new_rustfunc( PyComplexRef :: abs) ,
40
- "__add__" => context. new_rustfunc( PyComplexRef :: add) ,
41
- "__eq__" => context. new_rustfunc( PyComplexRef :: eq) ,
42
- "__neg__" => context. new_rustfunc( PyComplexRef :: neg) ,
43
39
"__new__" => context. new_rustfunc( PyComplexRef :: new) ,
44
- "__radd__" => context. new_rustfunc( PyComplexRef :: radd) ,
45
- "__repr__" => context. new_rustfunc( PyComplexRef :: repr) ,
46
- "conjugate" => context. new_rustfunc( PyComplexRef :: conjugate) ,
47
- "imag" => context. new_property( PyComplexRef :: imag) ,
48
- "real" => context. new_property( PyComplexRef :: real)
49
40
} ) ;
50
41
}
51
42
@@ -73,49 +64,87 @@ impl PyComplexRef {
73
64
let value = Complex64 :: new ( real, imag) ;
74
65
PyComplex { value } . into_ref_with_type ( vm, cls)
75
66
}
67
+ }
76
68
77
- fn real ( self , _vm : & VirtualMachine ) -> PyFloat {
69
+ fn to_complex ( value : PyObjectRef , vm : & VirtualMachine ) -> PyResult < Option < Complex64 > > {
70
+ if objtype:: isinstance ( & value, & vm. ctx . int_type ( ) ) {
71
+ match objint:: get_value ( & value) . to_f64 ( ) {
72
+ Some ( v) => Ok ( Some ( Complex64 :: new ( v, 0.0 ) ) ) ,
73
+ None => Err ( vm. new_overflow_error ( "int too large to convert to float" . to_string ( ) ) ) ,
74
+ }
75
+ } else if objtype:: isinstance ( & value, & vm. ctx . float_type ( ) ) {
76
+ let v = objfloat:: get_value ( & value) ;
77
+ Ok ( Some ( Complex64 :: new ( v, 0.0 ) ) )
78
+ } else {
79
+ Ok ( None )
80
+ }
81
+ }
82
+
83
+ #[ pyimpl]
84
+ impl PyComplex {
85
+ #[ pyproperty( name = "real" ) ]
86
+ fn real ( & self , _vm : & VirtualMachine ) -> PyFloat {
78
87
self . value . re . into ( )
79
88
}
80
89
81
- fn imag ( self , _vm : & VirtualMachine ) -> PyFloat {
90
+ #[ pyproperty( name = "imag" ) ]
91
+ fn imag ( & self , _vm : & VirtualMachine ) -> PyFloat {
82
92
self . value . im . into ( )
83
93
}
84
94
85
- fn abs ( self , _vm : & VirtualMachine ) -> PyFloat {
95
+ #[ pymethod( name = "__abs__" ) ]
96
+ fn abs ( & self , _vm : & VirtualMachine ) -> PyFloat {
86
97
let Complex64 { im, re } = self . value ;
87
98
re. hypot ( im) . into ( )
88
99
}
89
100
90
- fn add ( self , other : PyObjectRef , vm : & VirtualMachine ) -> PyObjectRef {
101
+ #[ pymethod( name = "__add__" ) ]
102
+ fn add ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
91
103
if objtype:: isinstance ( & other, & vm. ctx . complex_type ( ) ) {
92
- vm. ctx . new_complex ( self . value + get_value ( & other) )
93
- } else if objtype:: isinstance ( & other, & vm. ctx . int_type ( ) ) {
94
- vm. ctx . new_complex ( Complex64 :: new (
95
- self . value . re + objint:: get_value ( & other) . to_f64 ( ) . unwrap ( ) ,
96
- self . value . im ,
97
- ) )
104
+ Ok ( vm. ctx . new_complex ( self . value + get_value ( & other) ) )
98
105
} else {
99
- vm . ctx . not_implemented ( )
106
+ self . radd ( other , vm )
100
107
}
101
108
}
102
109
103
- fn radd ( self , other : PyObjectRef , vm : & VirtualMachine ) -> PyObjectRef {
104
- if objtype:: isinstance ( & other, & vm. ctx . int_type ( ) ) {
105
- vm. ctx . new_complex ( Complex64 :: new (
106
- self . value . re + objint:: get_value ( & other) . to_f64 ( ) . unwrap ( ) ,
107
- self . value . im ,
108
- ) )
110
+ #[ pymethod( name = "__radd__" ) ]
111
+ fn radd ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
112
+ match to_complex ( other, vm) {
113
+ Ok ( Some ( other) ) => Ok ( vm. ctx . new_complex ( self . value + other) ) ,
114
+ Ok ( None ) => Ok ( vm. ctx . not_implemented ( ) ) ,
115
+ Err ( err) => Err ( err) ,
116
+ }
117
+ }
118
+
119
+ #[ pymethod( name = "__sub__" ) ]
120
+ fn sub ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
121
+ if objtype:: isinstance ( & other, & vm. ctx . complex_type ( ) ) {
122
+ Ok ( vm. ctx . new_complex ( self . value - get_value ( & other) ) )
109
123
} else {
110
- vm. ctx . not_implemented ( )
124
+ match to_complex ( other, vm) {
125
+ Ok ( Some ( other) ) => Ok ( vm. ctx . new_complex ( self . value - other) ) ,
126
+ Ok ( None ) => Ok ( vm. ctx . not_implemented ( ) ) ,
127
+ Err ( err) => Err ( err) ,
128
+ }
129
+ }
130
+ }
131
+
132
+ #[ pymethod( name = "__rsub__" ) ]
133
+ fn rsub ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
134
+ match to_complex ( other, vm) {
135
+ Ok ( Some ( other) ) => Ok ( vm. ctx . new_complex ( other - self . value ) ) ,
136
+ Ok ( None ) => Ok ( vm. ctx . not_implemented ( ) ) ,
137
+ Err ( err) => Err ( err) ,
111
138
}
112
139
}
113
140
114
- fn conjugate ( self , _vm : & VirtualMachine ) -> PyComplex {
141
+ #[ pymethod( name = "conjugate" ) ]
142
+ fn conjugate ( & self , _vm : & VirtualMachine ) -> PyComplex {
115
143
self . value . conj ( ) . into ( )
116
144
}
117
145
118
- fn eq ( self , other : PyObjectRef , vm : & VirtualMachine ) -> PyObjectRef {
146
+ #[ pymethod( name = "__eq__" ) ]
147
+ fn eq ( & self , other : PyObjectRef , vm : & VirtualMachine ) -> PyObjectRef {
119
148
let result = if objtype:: isinstance ( & other, & vm. ctx . complex_type ( ) ) {
120
149
self . value == get_value ( & other)
121
150
} else if objtype:: isinstance ( & other, & vm. ctx . int_type ( ) ) {
@@ -132,16 +161,23 @@ impl PyComplexRef {
132
161
vm. ctx . new_bool ( result)
133
162
}
134
163
135
- fn neg ( self , _vm : & VirtualMachine ) -> PyComplex {
164
+ #[ pymethod( name = "__neg__" ) ]
165
+ fn neg ( & self , _vm : & VirtualMachine ) -> PyComplex {
136
166
PyComplex :: from ( -self . value )
137
167
}
138
168
139
- fn repr ( self , _vm : & VirtualMachine ) -> String {
169
+ #[ pymethod( name = "__repr__" ) ]
170
+ fn repr ( & self , _vm : & VirtualMachine ) -> String {
140
171
let Complex64 { re, im } = self . value ;
141
172
if re == 0.0 {
142
173
format ! ( "{}j" , im)
143
174
} else {
144
175
format ! ( "({}+{}j)" , re, im)
145
176
}
146
177
}
178
+
179
+ #[ pymethod( name = "__bool__" ) ]
180
+ fn bool ( & self , _vm : & VirtualMachine ) -> bool {
181
+ self . value != Complex64 :: zero ( )
182
+ }
147
183
}
0 commit comments