@@ -433,17 +433,6 @@ macro_rules! overloaded_binary_func {
433
433
///
434
434
/// An Array with results of the binary operation.
435
435
///
436
- /// In the case of comparison operations such as the following, the type of output
437
- /// Array is [DType::B8](./enum.DType.html). To retrieve the results of such boolean output
438
- /// to host, an array of 8-bit wide types(eg. u8, i8) should be used since ArrayFire's internal
439
- /// implementation uses char for boolean.
440
- ///
441
- /// * [gt](./fn.gt.html)
442
- /// * [lt](./fn.lt.html)
443
- /// * [ge](./fn.ge.html)
444
- /// * [le](./fn.le.html)
445
- /// * [eq](./fn.eq.html)
446
- ///
447
436
///# Note
448
437
///
449
438
/// The trait `Convertable` essentially translates to a scalar native type on rust or Array.
@@ -487,55 +476,123 @@ overloaded_binary_func!("Compute remainder from two Arrays", rem, rem_helper, af
487
476
overloaded_binary_func ! ( "Compute left shift" , shiftl, shiftl_helper, af_bitshiftl) ;
488
477
overloaded_binary_func ! ( "Compute right shift" , shiftr, shiftr_helper, af_bitshiftr) ;
489
478
overloaded_binary_func ! (
479
+ "Compute modulo of two Arrays" ,
480
+ modulo,
481
+ modulo_helper,
482
+ af_mod
483
+ ) ;
484
+ overloaded_binary_func ! (
485
+ "Calculate atan2 of two Arrays" ,
486
+ atan2,
487
+ atan2_helper,
488
+ af_atan2
489
+ ) ;
490
+ overloaded_binary_func ! (
491
+ "Create complex array from two Arrays" ,
492
+ cplx2,
493
+ cplx2_helper,
494
+ af_cplx2
495
+ ) ;
496
+ overloaded_binary_func ! ( "Compute root" , root, root_helper, af_root) ;
497
+ overloaded_binary_func ! ( "Computer power" , pow, pow_helper, af_pow) ;
498
+
499
+ macro_rules! overloaded_compare_func {
500
+ ( $doc_str: expr, $fn_name: ident, $help_name: ident, $ffi_name: ident) => {
501
+ fn $help_name<A , B >( lhs: & Array <A >, rhs: & Array <B >, batch: bool ) -> Array <bool >
502
+ where
503
+ A : HasAfEnum + ImplicitPromote <B >,
504
+ B : HasAfEnum + ImplicitPromote <A >,
505
+ {
506
+ let mut temp: i64 = 0 ;
507
+ unsafe {
508
+ let err_val = $ffi_name(
509
+ & mut temp as MutAfArray ,
510
+ lhs. get( ) as AfArray ,
511
+ rhs. get( ) as AfArray ,
512
+ batch as c_int,
513
+ ) ;
514
+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
515
+ }
516
+ temp. into( )
517
+ }
518
+
519
+ #[ doc=$doc_str]
520
+ ///
521
+ /// This is a comparison operation.
522
+ ///
523
+ ///# Parameters
524
+ ///
525
+ /// - `arg1`is an argument that implements an internal trait `Convertable`.
526
+ /// - `arg2`is an argument that implements an internal trait `Convertable`.
527
+ /// - `batch` is an boolean that indicates if the current operation is an batch operation.
528
+ ///
529
+ /// Both parameters `arg1` and `arg2` can be either an Array or a value of rust integral
530
+ /// type.
531
+ ///
532
+ ///# Return Values
533
+ ///
534
+ /// An Array with results of the comparison operation a.k.a an Array of boolean values.
535
+ ///# Note
536
+ ///
537
+ /// The trait `Convertable` essentially translates to a scalar native type on rust or Array.
538
+ pub fn $fn_name<T , U >(
539
+ arg1: & T ,
540
+ arg2: & U ,
541
+ batch: bool ,
542
+ ) -> Array <bool >
543
+ where
544
+ T : Convertable ,
545
+ U : Convertable ,
546
+ <T as Convertable >:: OutType : HasAfEnum + ImplicitPromote <<U as Convertable >:: OutType >,
547
+ <U as Convertable >:: OutType : HasAfEnum + ImplicitPromote <<T as Convertable >:: OutType >,
548
+ {
549
+ let lhs = arg1. convert( ) ; // Convert to Array<T>
550
+ let rhs = arg2. convert( ) ; // Convert to Array<T>
551
+ match ( lhs. is_scalar( ) , rhs. is_scalar( ) ) {
552
+ ( true , false ) => {
553
+ let l = tile( & lhs, rhs. dims( ) ) ;
554
+ $help_name( & l, & rhs, batch)
555
+ }
556
+ ( false , true ) => {
557
+ let r = tile( & rhs, lhs. dims( ) ) ;
558
+ $help_name( & lhs, & r, batch)
559
+ }
560
+ _ => $help_name( & lhs, & rhs, batch) ,
561
+ }
562
+ }
563
+ } ;
564
+ }
565
+
566
+ overloaded_compare_func ! (
490
567
"Perform `less than` comparison operation" ,
491
568
lt,
492
569
lt_helper,
493
570
af_lt
494
571
) ;
495
- overloaded_binary_func ! (
572
+ overloaded_compare_func ! (
496
573
"Perform `greater than` comparison operation" ,
497
574
gt,
498
575
gt_helper,
499
576
af_gt
500
577
) ;
501
- overloaded_binary_func ! (
578
+ overloaded_compare_func ! (
502
579
"Perform `less than equals` comparison operation" ,
503
580
le,
504
581
le_helper,
505
582
af_le
506
583
) ;
507
- overloaded_binary_func ! (
584
+ overloaded_compare_func ! (
508
585
"Perform `greater than equals` comparison operation" ,
509
586
ge,
510
587
ge_helper,
511
588
af_ge
512
589
) ;
513
- overloaded_binary_func ! (
590
+ overloaded_compare_func ! (
514
591
"Perform `equals` comparison operation" ,
515
592
eq,
516
593
eq_helper,
517
594
af_eq
518
595
) ;
519
- overloaded_binary_func ! (
520
- "Compute modulo of two Arrays" ,
521
- modulo,
522
- modulo_helper,
523
- af_mod
524
- ) ;
525
- overloaded_binary_func ! (
526
- "Calculate atan2 of two Arrays" ,
527
- atan2,
528
- atan2_helper,
529
- af_atan2
530
- ) ;
531
- overloaded_binary_func ! (
532
- "Create complex array from two Arrays" ,
533
- cplx2,
534
- cplx2_helper,
535
- af_cplx2
536
- ) ;
537
- overloaded_binary_func ! ( "Compute root" , root, root_helper, af_root) ;
538
- overloaded_binary_func ! ( "Computer power" , pow, pow_helper, af_pow) ;
539
596
540
597
fn clamp_helper < X , Y > (
541
598
inp : & Array < X > ,
0 commit comments