Skip to content

Commit 5f8dce0

Browse files
committed
Alias ctypes data types to custom data types
1 parent a79ccc5 commit 5f8dce0

22 files changed

+506
-490
lines changed

arrayfire/algorithm.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,29 @@
1616

1717
def _parallel_dim(a, dim, c_func):
1818
out = Array()
19-
safe_call(c_func(ct.pointer(out.arr), a.arr, ct.c_int(dim)))
19+
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim)))
2020
return out
2121

2222
def _reduce_all(a, c_func):
23-
real = ct.c_double(0)
24-
imag = ct.c_double(0)
23+
real = c_double_t(0)
24+
imag = c_double_t(0)
2525

26-
safe_call(c_func(ct.pointer(real), ct.pointer(imag), a.arr))
26+
safe_call(c_func(c_pointer(real), c_pointer(imag), a.arr))
2727

2828
real = real.value
2929
imag = imag.value
3030
return real if imag == 0 else real + imag * 1j
3131

3232
def _nan_parallel_dim(a, dim, c_func, nan_val):
3333
out = Array()
34-
safe_call(c_func(ct.pointer(out.arr), a.arr, ct.c_int(dim), ct.c_double(nan_val)))
34+
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim), c_double_t(nan_val)))
3535
return out
3636

3737
def _nan_reduce_all(a, c_func, nan_val):
38-
real = ct.c_double(0)
39-
imag = ct.c_double(0)
38+
real = c_double_t(0)
39+
imag = c_double_t(0)
4040

41-
safe_call(c_func(ct.pointer(real), ct.pointer(imag), a.arr, ct.c_double(nan_val)))
41+
safe_call(c_func(c_pointer(real), c_pointer(imag), a.arr, c_double_t(nan_val)))
4242

4343
real = real.value
4444
imag = imag.value
@@ -235,13 +235,13 @@ def imin(a, dim=None):
235235
if dim is not None:
236236
out = Array()
237237
idx = Array()
238-
safe_call(backend.get().af_imin(ct.pointer(out.arr), ct.pointer(idx.arr), a.arr, ct.c_int(dim)))
238+
safe_call(backend.get().af_imin(c_pointer(out.arr), c_pointer(idx.arr), a.arr, c_int_t(dim)))
239239
return out,idx
240240
else:
241-
real = ct.c_double(0)
242-
imag = ct.c_double(0)
243-
idx = ct.c_uint(0)
244-
safe_call(backend.get().af_imin_all(ct.pointer(real), ct.pointer(imag), ct.pointer(idx), a.arr))
241+
real = c_double_t(0)
242+
imag = c_double_t(0)
243+
idx = c_uint_t(0)
244+
safe_call(backend.get().af_imin_all(c_pointer(real), c_pointer(imag), c_pointer(idx), a.arr))
245245
real = real.value
246246
imag = imag.value
247247
val = real if imag == 0 else real + imag * 1j
@@ -268,13 +268,13 @@ def imax(a, dim=None):
268268
if dim is not None:
269269
out = Array()
270270
idx = Array()
271-
safe_call(backend.get().af_imax(ct.pointer(out.arr), ct.pointer(idx.arr), a.arr, ct.c_int(dim)))
271+
safe_call(backend.get().af_imax(c_pointer(out.arr), c_pointer(idx.arr), a.arr, c_int_t(dim)))
272272
return out,idx
273273
else:
274-
real = ct.c_double(0)
275-
imag = ct.c_double(0)
276-
idx = ct.c_uint(0)
277-
safe_call(backend.get().af_imax_all(ct.pointer(real), ct.pointer(imag), ct.pointer(idx), a.arr))
274+
real = c_double_t(0)
275+
imag = c_double_t(0)
276+
idx = c_uint_t(0)
277+
safe_call(backend.get().af_imax_all(c_pointer(real), c_pointer(imag), c_pointer(idx), a.arr))
278278
real = real.value
279279
imag = imag.value
280280
val = real if imag == 0 else real + imag * 1j
@@ -327,7 +327,7 @@ def scan(a, dim=0, op=BINARYOP.ADD, inclusive_scan=True):
327327
- will contain scan of input.
328328
"""
329329
out = Array()
330-
safe_call(backend.get().af_scan(ct.pointer(out.arr), a.arr, dim, op.value, inclusive_scan))
330+
safe_call(backend.get().af_scan(c_pointer(out.arr), a.arr, dim, op.value, inclusive_scan))
331331
return out
332332

333333
def scan_by_key(key, a, dim=0, op=BINARYOP.ADD, inclusive_scan=True):
@@ -361,7 +361,7 @@ def scan_by_key(key, a, dim=0, op=BINARYOP.ADD, inclusive_scan=True):
361361
- will contain scan of input.
362362
"""
363363
out = Array()
364-
safe_call(backend.get().af_scan_by_key(ct.pointer(out.arr), key.arr, a.arr, dim, op.value, inclusive_scan))
364+
safe_call(backend.get().af_scan_by_key(c_pointer(out.arr), key.arr, a.arr, dim, op.value, inclusive_scan))
365365
return out
366366

367367
def where(a):
@@ -379,7 +379,7 @@ def where(a):
379379
Linear indices for non zero elements.
380380
"""
381381
out = Array()
382-
safe_call(backend.get().af_where(ct.pointer(out.arr), a.arr))
382+
safe_call(backend.get().af_where(c_pointer(out.arr), a.arr))
383383
return out
384384

385385
def diff1(a, dim=0):
@@ -441,7 +441,7 @@ def sort(a, dim=0, is_ascending=True):
441441
Currently `dim` is only supported for 0.
442442
"""
443443
out = Array()
444-
safe_call(backend.get().af_sort(ct.pointer(out.arr), a.arr, ct.c_uint(dim), ct.c_bool(is_ascending)))
444+
safe_call(backend.get().af_sort(c_pointer(out.arr), a.arr, c_uint_t(dim), c_bool_t(is_ascending)))
445445
return out
446446

447447
def sort_index(a, dim=0, is_ascending=True):
@@ -469,8 +469,8 @@ def sort_index(a, dim=0, is_ascending=True):
469469
"""
470470
out = Array()
471471
idx = Array()
472-
safe_call(backend.get().af_sort_index(ct.pointer(out.arr), ct.pointer(idx.arr), a.arr,
473-
ct.c_uint(dim), ct.c_bool(is_ascending)))
472+
safe_call(backend.get().af_sort_index(c_pointer(out.arr), c_pointer(idx.arr), a.arr,
473+
c_uint_t(dim), c_bool_t(is_ascending)))
474474
return out,idx
475475

476476
def sort_by_key(iv, ik, dim=0, is_ascending=True):
@@ -500,8 +500,8 @@ def sort_by_key(iv, ik, dim=0, is_ascending=True):
500500
"""
501501
ov = Array()
502502
ok = Array()
503-
safe_call(backend.get().af_sort_by_key(ct.pointer(ov.arr), ct.pointer(ok.arr),
504-
iv.arr, ik.arr, ct.c_uint(dim), ct.c_bool(is_ascending)))
503+
safe_call(backend.get().af_sort_by_key(c_pointer(ov.arr), c_pointer(ok.arr),
504+
iv.arr, ik.arr, c_uint_t(dim), c_bool_t(is_ascending)))
505505
return ov,ok
506506

507507
def set_unique(a, is_sorted=False):
@@ -521,7 +521,7 @@ def set_unique(a, is_sorted=False):
521521
an array containing the unique values from `a`
522522
"""
523523
out = Array()
524-
safe_call(backend.get().af_set_unique(ct.pointer(out.arr), a.arr, ct.c_bool(is_sorted)))
524+
safe_call(backend.get().af_set_unique(c_pointer(out.arr), a.arr, c_bool_t(is_sorted)))
525525
return out
526526

527527
def set_union(a, b, is_unique=False):
@@ -543,7 +543,7 @@ def set_union(a, b, is_unique=False):
543543
an array values after performing the union of `a` and `b`.
544544
"""
545545
out = Array()
546-
safe_call(backend.get().af_set_union(ct.pointer(out.arr), a.arr, b.arr, ct.c_bool(is_unique)))
546+
safe_call(backend.get().af_set_union(c_pointer(out.arr), a.arr, b.arr, c_bool_t(is_unique)))
547547
return out
548548

549549
def set_intersect(a, b, is_unique=False):
@@ -565,5 +565,5 @@ def set_intersect(a, b, is_unique=False):
565565
an array values after performing the intersect of `a` and `b`.
566566
"""
567567
out = Array()
568-
safe_call(backend.get().af_set_intersect(ct.pointer(out.arr), a.arr, b.arr, ct.c_bool(is_unique)))
568+
safe_call(backend.get().af_set_intersect(c_pointer(out.arr), a.arr, b.arr, c_bool_t(is_unique)))
569569
return out

arrayfire/arith.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,27 @@ def _arith_binary_func(lhs, rhs, c_func):
2626
raise TypeError("Atleast one input needs to be of type arrayfire.array")
2727

2828
elif (is_left_array and is_right_array):
29-
safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, _bcast_var.get()))
29+
safe_call(c_func(c_pointer(out.arr), lhs.arr, rhs.arr, _bcast_var.get()))
3030

3131
elif (_is_number(rhs)):
3232
ldims = dim4_to_tuple(lhs.dims())
3333
rty = implicit_dtype(rhs, lhs.type())
3434
other = Array()
3535
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], rty)
36-
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, _bcast_var.get()))
36+
safe_call(c_func(c_pointer(out.arr), lhs.arr, other.arr, _bcast_var.get()))
3737

3838
else:
3939
rdims = dim4_to_tuple(rhs.dims())
4040
lty = implicit_dtype(lhs, rhs.type())
4141
other = Array()
4242
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], lty)
43-
safe_call(c_func(ct.pointer(out.arr), other.arr, rhs.arr, _bcast_var.get()))
43+
safe_call(c_func(c_pointer(out.arr), other.arr, rhs.arr, _bcast_var.get()))
4444

4545
return out
4646

4747
def _arith_unary_func(a, c_func):
4848
out = Array()
49-
safe_call(c_func(ct.pointer(out.arr), a.arr))
49+
safe_call(c_func(c_pointer(out.arr), a.arr))
5050
return out
5151

5252
def cast(a, dtype):
@@ -75,7 +75,7 @@ def cast(a, dtype):
7575
array containing the values from `a` after converting to `dtype`.
7676
"""
7777
out=Array()
78-
safe_call(backend.get().af_cast(ct.pointer(out.arr), a.arr, dtype.value))
78+
safe_call(backend.get().af_cast(c_pointer(out.arr), a.arr, dtype.value))
7979
return out
8080

8181
def minof(lhs, rhs):
@@ -160,7 +160,7 @@ def clamp(val, low, high):
160160
else:
161161
high_arr = high.arr
162162

163-
safe_call(backend.get().af_clamp(ct.pointer(out.arr), val.arr, low_arr, high_arr, _bcast_var.get()))
163+
safe_call(backend.get().af_clamp(c_pointer(out.arr), val.arr, low_arr, high_arr, _bcast_var.get()))
164164

165165
return out
166166

0 commit comments

Comments
 (0)