Skip to content

Commit 4374d93

Browse files
committed
Add bitwise operators. Add in-place operators. Add missing reflected operators
1 parent 4140527 commit 4374d93

File tree

1 file changed

+202
-2
lines changed

1 file changed

+202
-2
lines changed

arrayfire/array_api/_array_object.py

Lines changed: 202 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(
8787
if _no_initial_dtype:
8888
raise TypeError("Expected to receive the initial dtype due to the x being a data pointer.")
8989

90-
_type_char = dtype.typecode
90+
_type_char = dtype.typecode # type: ignore[assignment] # FIXME
9191

9292
else:
9393
raise TypeError("Passed object x is an object of unsupported class.")
@@ -138,6 +138,8 @@ def __repr__(self) -> str: # FIXME
138138
def __len__(self) -> int:
139139
return self.shape[0] if self.shape else 0 # type: ignore[return-value]
140140

141+
# Arithmetic Operators
142+
141143
def __pos__(self) -> Array:
142144
"""
143145
Return +self
@@ -148,7 +150,7 @@ def __neg__(self) -> Array:
148150
"""
149151
Return -self
150152
"""
151-
return 0 - self
153+
return 0 - self # type: ignore[no-any-return, operator] # FIXME
152154

153155
def __add__(self, other: int | float | Array, /) -> Array:
154156
# TODO discuss either we need to support complex and bool as other input type
@@ -191,10 +193,92 @@ def __pow__(self, other: int | float | bool | complex | Array, /) -> Array:
191193
"""
192194
return _process_c_function(self, other, backend.get().af_pow)
193195

196+
# Array Operators
197+
194198
def __matmul__(self, other: Array, /) -> Array:
195199
# TODO
196200
return NotImplemented
197201

202+
# Bitwise Operators
203+
204+
def __invert__(self) -> Array:
205+
"""
206+
Return ~self.
207+
"""
208+
out = Array()
209+
safe_call(backend.get().af_bitnot(ctypes.pointer(out.arr), self.arr))
210+
return out
211+
212+
def __and__(self, other: int | bool | Array, /) -> Array:
213+
"""
214+
Return self & other.
215+
"""
216+
return _process_c_function(self, other, backend.get().af_bitand)
217+
218+
def __or__(self, other: int | bool | Array, /) -> Array:
219+
"""
220+
Return self | other.
221+
"""
222+
return _process_c_function(self, other, backend.get().af_bitor)
223+
224+
def __xor__(self, other: int | bool | Array, /) -> Array:
225+
"""
226+
Return self ^ other.
227+
"""
228+
return _process_c_function(self, other, backend.get().af_bitxor)
229+
230+
def __lshift__(self, other: int | Array, /) -> Array:
231+
"""
232+
Return self << other.
233+
"""
234+
return _process_c_function(self, other, backend.get().af_bitshiftl)
235+
236+
def __rshift__(self, other: int | Array, /) -> Array:
237+
"""
238+
Return self >> other.
239+
"""
240+
return _process_c_function(self, other, backend.get().af_bitshiftr)
241+
242+
# Comparison Operators
243+
244+
def __lt__(self, other: int | float | Array, /) -> Array:
245+
"""
246+
Return self < other.
247+
"""
248+
return _process_c_function(self, other, backend.get().af_lt)
249+
250+
def __le__(self, other: int | float | Array, /) -> Array:
251+
"""
252+
Return self <= other.
253+
"""
254+
return _process_c_function(self, other, backend.get().af_le)
255+
256+
def __gt__(self, other: int | float | Array, /) -> Array:
257+
"""
258+
Return self > other.
259+
"""
260+
return _process_c_function(self, other, backend.get().af_gt)
261+
262+
def __ge__(self, other: int | float | Array, /) -> Array:
263+
"""
264+
Return self >= other.
265+
"""
266+
return _process_c_function(self, other, backend.get().af_ge)
267+
268+
def __eq__(self, other: int | float | bool | Array, /) -> Array: # type: ignore[override] # FIXME
269+
"""
270+
Return self == other.
271+
"""
272+
return _process_c_function(self, other, backend.get().af_eq)
273+
274+
def __ne__(self, other: int | float | bool | Array, /) -> Array: # type: ignore[override] # FIXME
275+
"""
276+
Return self != other.
277+
"""
278+
return _process_c_function(self, other, backend.get().af_neq)
279+
280+
# Reflected Arithmetic Operators
281+
198282
def __radd__(self, other: Array, /) -> Array:
199283
# TODO discuss either we need to support complex and bool as other input type
200284
"""
@@ -236,10 +320,125 @@ def __rpow__(self, other: Array, /) -> Array:
236320
"""
237321
return _process_c_function(other, self, backend.get().af_pow)
238322

323+
# Reflected Array Operators
324+
239325
def __rmatmul__(self, other: Array, /) -> Array:
240326
# TODO
241327
return NotImplemented
242328

329+
# Reflected Bitwise Operators
330+
331+
def __rand__(self, other: Array, /) -> Array:
332+
"""
333+
Return other & self.
334+
"""
335+
return _process_c_function(other, self, backend.get().af_bitand)
336+
337+
def __ror__(self, other: Array, /) -> Array:
338+
"""
339+
Return other & self.
340+
"""
341+
return _process_c_function(other, self, backend.get().af_bitor)
342+
343+
def __rxor__(self, other: Array, /) -> Array:
344+
"""
345+
Return other ^ self.
346+
"""
347+
return _process_c_function(other, self, backend.get().af_bitxor)
348+
349+
def __rlshift__(self, other: Array, /) -> Array:
350+
"""
351+
Return other << self.
352+
"""
353+
return _process_c_function(other, self, backend.get().af_bitshiftl)
354+
355+
def __rrshift__(self, other: Array, /) -> Array:
356+
"""
357+
Return other >> self.
358+
"""
359+
return _process_c_function(other, self, backend.get().af_bitshiftr)
360+
361+
# In-place Arithmetic Operators
362+
363+
def __iadd__(self, other: int | float | Array, /) -> Array:
364+
# TODO discuss either we need to support complex and bool as other input type
365+
"""
366+
Return self += other.
367+
"""
368+
return _process_c_function(self, other, backend.get().af_add)
369+
370+
def __isub__(self, other: int | float | bool | complex | Array, /) -> Array:
371+
"""
372+
Return self -= other.
373+
"""
374+
return _process_c_function(self, other, backend.get().af_sub)
375+
376+
def __imul__(self, other: int | float | bool | complex | Array, /) -> Array:
377+
"""
378+
Return self *= other.
379+
"""
380+
return _process_c_function(self, other, backend.get().af_mul)
381+
382+
def __itruediv__(self, other: int | float | bool | complex | Array, /) -> Array:
383+
"""
384+
Return self /= other.
385+
"""
386+
return _process_c_function(self, other, backend.get().af_div)
387+
388+
def __ifloordiv__(self, other: int | float | bool | complex | Array, /) -> Array:
389+
# TODO
390+
return NotImplemented
391+
392+
def __imod__(self, other: int | float | bool | complex | Array, /) -> Array:
393+
"""
394+
Return self %= other.
395+
"""
396+
return _process_c_function(self, other, backend.get().af_mod)
397+
398+
def __ipow__(self, other: int | float | bool | complex | Array, /) -> Array:
399+
"""
400+
Return self **= other.
401+
"""
402+
return _process_c_function(self, other, backend.get().af_pow)
403+
404+
# In-place Array Operators
405+
406+
def __imatmul__(self, other: Array, /) -> Array:
407+
# TODO
408+
return NotImplemented
409+
410+
# In-place Bitwise Operators
411+
412+
def __iand__(self, other: int | bool | Array, /) -> Array:
413+
"""
414+
Return self &= other.
415+
"""
416+
return _process_c_function(self, other, backend.get().af_bitand)
417+
418+
def __ior__(self, other: int | bool | Array, /) -> Array:
419+
"""
420+
Return self |= other.
421+
"""
422+
return _process_c_function(self, other, backend.get().af_bitor)
423+
424+
def __ixor__(self, other: int | bool | Array, /) -> Array:
425+
"""
426+
Return self ^= other.
427+
"""
428+
return _process_c_function(self, other, backend.get().af_bitxor)
429+
430+
def __ilshift__(self, other: int | Array, /) -> Array:
431+
"""
432+
Return self <<= other.
433+
"""
434+
return _process_c_function(self, other, backend.get().af_bitshiftl)
435+
436+
def __irshift__(self, other: int | Array, /) -> Array:
437+
"""
438+
Return self >>= other.
439+
"""
440+
return _process_c_function(self, other, backend.get().af_bitshiftr)
441+
243442
def __getitem__(self, key: int | slice | tuple[int | slice] | Array, /) -> Array:
244443
# TODO: API Specification - key: int | slice | ellipsis | tuple[int | slice] | Array
245444
# TODO: refactor
@@ -405,6 +604,7 @@ def _constant_array(value: int | float | bool | complex, shape: CShape, dtype: D
405604
ctypes.pointer(out.arr), ctypes.c_double(value.real), ctypes.c_double(value.imag), 4,
406605
ctypes.pointer(shape.c_array), dtype.c_api_value))
407606
elif dtype == af_int64:
607+
# TODO discuss workaround for passing float to ctypes
408608
safe_call(backend.get().af_constant_long(
409609
ctypes.pointer(out.arr), ctypes.c_longlong(value.real), 4, ctypes.pointer(shape.c_array)))
410610
elif dtype == af_uint64:

0 commit comments

Comments
 (0)