Skip to content

Commit 3dee59e

Browse files
committed
Fix comparison functions output type
1 parent b084d21 commit 3dee59e

File tree

2 files changed

+93
-36
lines changed

2 files changed

+93
-36
lines changed

examples/conway.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ fn conways_game_of_life() {
2424
let c0 = &eq(&n_hood, &c0, false);
2525
let c1 = &eq(&n_hood, &c1, false);
2626
state = state * c0 + c1;
27-
win.draw_image(&normalise(&state), None);
27+
win.draw_image(&normalise(&state.cast::<f32>()), None);
2828
}
2929
}

src/arith/mod.rs

Lines changed: 92 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -433,17 +433,6 @@ macro_rules! overloaded_binary_func {
433433
///
434434
/// An Array with results of the binary operation.
435435
///
436-
/// In the case of comparison operations such as the following, the type of output
437-
/// Array is [DType::B8](./enum.DType.html). To retrieve the results of such boolean output
438-
/// to host, an array of 8-bit wide types(eg. u8, i8) should be used since ArrayFire's internal
439-
/// implementation uses char for boolean.
440-
///
441-
/// * [gt](./fn.gt.html)
442-
/// * [lt](./fn.lt.html)
443-
/// * [ge](./fn.ge.html)
444-
/// * [le](./fn.le.html)
445-
/// * [eq](./fn.eq.html)
446-
///
447436
///# Note
448437
///
449438
/// The trait `Convertable` essentially translates to a scalar native type on rust or Array.
@@ -487,55 +476,123 @@ overloaded_binary_func!("Compute remainder from two Arrays", rem, rem_helper, af
487476
overloaded_binary_func!("Compute left shift", shiftl, shiftl_helper, af_bitshiftl);
488477
overloaded_binary_func!("Compute right shift", shiftr, shiftr_helper, af_bitshiftr);
489478
overloaded_binary_func!(
479+
"Compute modulo of two Arrays",
480+
modulo,
481+
modulo_helper,
482+
af_mod
483+
);
484+
overloaded_binary_func!(
485+
"Calculate atan2 of two Arrays",
486+
atan2,
487+
atan2_helper,
488+
af_atan2
489+
);
490+
overloaded_binary_func!(
491+
"Create complex array from two Arrays",
492+
cplx2,
493+
cplx2_helper,
494+
af_cplx2
495+
);
496+
overloaded_binary_func!("Compute root", root, root_helper, af_root);
497+
overloaded_binary_func!("Computer power", pow, pow_helper, af_pow);
498+
499+
macro_rules! overloaded_compare_func {
500+
($doc_str: expr, $fn_name: ident, $help_name: ident, $ffi_name: ident) => {
501+
fn $help_name<A, B>(lhs: &Array<A>, rhs: &Array<B>, batch: bool) -> Array<bool>
502+
where
503+
A: HasAfEnum + ImplicitPromote<B>,
504+
B: HasAfEnum + ImplicitPromote<A>,
505+
{
506+
let mut temp: i64 = 0;
507+
unsafe {
508+
let err_val = $ffi_name(
509+
&mut temp as MutAfArray,
510+
lhs.get() as AfArray,
511+
rhs.get() as AfArray,
512+
batch as c_int,
513+
);
514+
HANDLE_ERROR(AfError::from(err_val));
515+
}
516+
temp.into()
517+
}
518+
519+
#[doc=$doc_str]
520+
///
521+
/// This is a comparison operation.
522+
///
523+
///# Parameters
524+
///
525+
/// - `arg1`is an argument that implements an internal trait `Convertable`.
526+
/// - `arg2`is an argument that implements an internal trait `Convertable`.
527+
/// - `batch` is an boolean that indicates if the current operation is an batch operation.
528+
///
529+
/// Both parameters `arg1` and `arg2` can be either an Array or a value of rust integral
530+
/// type.
531+
///
532+
///# Return Values
533+
///
534+
/// An Array with results of the comparison operation a.k.a an Array of boolean values.
535+
///# Note
536+
///
537+
/// The trait `Convertable` essentially translates to a scalar native type on rust or Array.
538+
pub fn $fn_name<T, U>(
539+
arg1: &T,
540+
arg2: &U,
541+
batch: bool,
542+
) -> Array<bool>
543+
where
544+
T: Convertable,
545+
U: Convertable,
546+
<T as Convertable>::OutType: HasAfEnum + ImplicitPromote<<U as Convertable>::OutType>,
547+
<U as Convertable>::OutType: HasAfEnum + ImplicitPromote<<T as Convertable>::OutType>,
548+
{
549+
let lhs = arg1.convert(); // Convert to Array<T>
550+
let rhs = arg2.convert(); // Convert to Array<T>
551+
match (lhs.is_scalar(), rhs.is_scalar()) {
552+
(true, false) => {
553+
let l = tile(&lhs, rhs.dims());
554+
$help_name(&l, &rhs, batch)
555+
}
556+
(false, true) => {
557+
let r = tile(&rhs, lhs.dims());
558+
$help_name(&lhs, &r, batch)
559+
}
560+
_ => $help_name(&lhs, &rhs, batch),
561+
}
562+
}
563+
};
564+
}
565+
566+
overloaded_compare_func!(
490567
"Perform `less than` comparison operation",
491568
lt,
492569
lt_helper,
493570
af_lt
494571
);
495-
overloaded_binary_func!(
572+
overloaded_compare_func!(
496573
"Perform `greater than` comparison operation",
497574
gt,
498575
gt_helper,
499576
af_gt
500577
);
501-
overloaded_binary_func!(
578+
overloaded_compare_func!(
502579
"Perform `less than equals` comparison operation",
503580
le,
504581
le_helper,
505582
af_le
506583
);
507-
overloaded_binary_func!(
584+
overloaded_compare_func!(
508585
"Perform `greater than equals` comparison operation",
509586
ge,
510587
ge_helper,
511588
af_ge
512589
);
513-
overloaded_binary_func!(
590+
overloaded_compare_func!(
514591
"Perform `equals` comparison operation",
515592
eq,
516593
eq_helper,
517594
af_eq
518595
);
519-
overloaded_binary_func!(
520-
"Compute modulo of two Arrays",
521-
modulo,
522-
modulo_helper,
523-
af_mod
524-
);
525-
overloaded_binary_func!(
526-
"Calculate atan2 of two Arrays",
527-
atan2,
528-
atan2_helper,
529-
af_atan2
530-
);
531-
overloaded_binary_func!(
532-
"Create complex array from two Arrays",
533-
cplx2,
534-
cplx2_helper,
535-
af_cplx2
536-
);
537-
overloaded_binary_func!("Compute root", root, root_helper, af_root);
538-
overloaded_binary_func!("Computer power", pow, pow_helper, af_pow);
539596

540597
fn clamp_helper<X, Y>(
541598
inp: &Array<X>,

0 commit comments

Comments
 (0)