Skip to content

Commit aed6e1c

Browse files
committed
Change lhs/inout parameter of assign functions to mutable ref
This is required since by default a varialble reference immutable in rust. However, arrayfire C API for assignment operations, af_assign_seq and af_assign_gen, does use lhs as output. Thus, an Array passed as lhs might get modifed some times, thus breaking the immutable ref rule. This change fixes that by converting the lhs/inout paramters of assignment functions to mutable references and removes return values completely.
1 parent 2c62150 commit aed6e1c

File tree

4 files changed

+42
-53
lines changed

4 files changed

+42
-53
lines changed

examples/acoustic_wave.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ fn acoustic_wave_simulation() {
7171
// Location of the source.
7272
let seqs = &[Seq::new(700.0, 800.0, 1.0), Seq::new(800.0, 800.0, 1.0)];
7373
// Set the pressure there.
74-
p = assign_seq(
75-
&p,
74+
assign_seq(
75+
&mut p,
7676
seqs,
7777
&index(&pulse, &[Seq::new(it as f64, it as f64, 1.0)]),
7878
);

examples/helloworld.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ fn main() {
2222

2323
let dims = Dim4::new(&[num_rows, num_cols, 1, 1]);
2424

25-
let a = randu::<f32>(dims);
25+
let mut a = randu::<f32>(dims);
2626
af_print!("Create a 5-by-3 float matrix on the GPU", a);
2727

2828
println!("Element-wise arithmetic");
@@ -67,8 +67,8 @@ fn main() {
6767
let r_dims = Dim4::new(&[3, 1, 1, 1]);
6868
let r_input: [f32; 3] = [1.0, 1.0, 1.0];
6969
let r = Array::new(&r_input, r_dims);
70-
let ur = set_row(&a, &r, num_rows - 1);
71-
af_print!("Set last row to 1's", ur);
70+
set_row(&mut a, &r, num_rows - 1);
71+
af_print!("Set last row to 1's", a);
7272

7373
let d_dims = Dim4::new(&[2, 3, 1, 1]);
7474
let d_input: [i32; 6] = [1, 2, 3, 4, 5, 6];

src/arith/mod.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,6 @@ mod op_assign {
831831
use crate::array::Array;
832832
use crate::index::{assign_gen, Indexer};
833833
use crate::seq::Seq;
834-
use std::mem;
835834
use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
836835
use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign, ShlAssign, ShrAssign};
837836

@@ -852,8 +851,7 @@ mod op_assign {
852851
idxrs.set_index(&tmp_seq, n, Some(false));
853852
}
854853
let opres = $func(self as &Array<A>, &rhs, false).cast::<A>();
855-
let tmp = assign_gen(self as &Array<A>, &idxrs, &opres);
856-
let old = mem::replace(self, tmp);
854+
assign_gen(self, &idxrs, &opres);
857855
}
858856
}
859857
};
@@ -884,8 +882,7 @@ mod op_assign {
884882
idxrs.set_index(&tmp_seq, n, Some(false));
885883
}
886884
let opres = $func(self as &Array<A>, &rhs, false).cast::<A>();
887-
let tmp = assign_gen(self as &Array<A>, &idxrs, &opres);
888-
let old = mem::replace(self, tmp);
885+
assign_gen(self, &idxrs, &opres);
889886
}
890887
}
891888
};

src/index.rs

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::util::{AfArray, AfIndex, DimT, HasAfEnum, MutAfArray, MutAfIndex};
99

1010
use std::default::Default;
1111
use std::marker::PhantomData;
12+
use std::mem;
1213

1314
#[allow(dead_code)]
1415
extern "C" {
@@ -276,7 +277,6 @@ where
276277
/// print(&a);
277278
/// print(&row(&a, 4));
278279
/// ```
279-
#[allow(dead_code)]
280280
pub fn row<T>(input: &Array<T>, row_num: u64) -> Array<T>
281281
where
282282
T: HasAfEnum,
@@ -290,20 +290,18 @@ where
290290
)
291291
}
292292

293-
#[allow(dead_code)]
294-
/// Set `row_num`^th row in `input` Array to a new Array `new_row`
295-
pub fn set_row<T>(input: &Array<T>, new_row: &Array<T>, row_num: u64) -> Array<T>
293+
/// Set `row_num`^th row in `inout` Array to a new Array `new_row`
294+
pub fn set_row<T>(inout: &mut Array<T>, new_row: &Array<T>, row_num: u64)
296295
where
297296
T: HasAfEnum,
298297
{
299298
let seqs = [
300299
Seq::new(row_num as f64, row_num as f64, 1.0),
301300
Seq::default(),
302301
];
303-
assign_seq(input, &seqs, new_row)
302+
assign_seq(inout, &seqs, new_row)
304303
}
305304

306-
#[allow(dead_code)]
307305
/// Get an Array with all rows from `first` to `last` in the `input` Array
308306
pub fn rows<T>(input: &Array<T>, first: u64, last: u64) -> Array<T>
309307
where
@@ -315,14 +313,13 @@ where
315313
)
316314
}
317315

318-
#[allow(dead_code)]
319-
/// Set rows from `first` to `last` in `input` Array with rows from Array `new_rows`
320-
pub fn set_rows<T>(input: &Array<T>, new_rows: &Array<T>, first: u64, last: u64) -> Array<T>
316+
/// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows`
317+
pub fn set_rows<T>(inout: &mut Array<T>, new_rows: &Array<T>, first: u64, last: u64)
321318
where
322319
T: HasAfEnum,
323320
{
324321
let seqs = [Seq::new(first as f64, last as f64, 1.0), Seq::default()];
325-
assign_seq(input, &seqs, new_rows)
322+
assign_seq(inout, &seqs, new_rows)
326323
}
327324

328325
/// Extract `col_num` col from `input` Array
@@ -337,7 +334,6 @@ where
337334
/// println!("Grab last col of the random matrix");
338335
/// print(&col(&a, 4));
339336
/// ```
340-
#[allow(dead_code)]
341337
pub fn col<T>(input: &Array<T>, col_num: u64) -> Array<T>
342338
where
343339
T: HasAfEnum,
@@ -351,20 +347,18 @@ where
351347
)
352348
}
353349

354-
#[allow(dead_code)]
355-
/// Set `col_num`^th col in `input` Array to a new Array `new_col`
356-
pub fn set_col<T>(input: &Array<T>, new_col: &Array<T>, col_num: u64) -> Array<T>
350+
/// Set `col_num`^th col in `inout` Array to a new Array `new_col`
351+
pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: u64)
357352
where
358353
T: HasAfEnum,
359354
{
360355
let seqs = [
361356
Seq::default(),
362357
Seq::new(col_num as f64, col_num as f64, 1.0),
363358
];
364-
assign_seq(input, &seqs, new_col)
359+
assign_seq(inout, &seqs, new_col)
365360
}
366361

367-
#[allow(dead_code)]
368362
/// Get all cols from `first` to `last` in the `input` Array
369363
pub fn cols<T>(input: &Array<T>, first: u64, last: u64) -> Array<T>
370364
where
@@ -376,20 +370,18 @@ where
376370
)
377371
}
378372

379-
#[allow(dead_code)]
380-
/// Set cols from `first` to `last` in `input` Array with cols from Array `new_cols`
381-
pub fn set_cols<T>(input: &Array<T>, new_cols: &Array<T>, first: u64, last: u64) -> Array<T>
373+
/// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols`
374+
pub fn set_cols<T>(inout: &mut Array<T>, new_cols: &Array<T>, first: u64, last: u64)
382375
where
383376
T: HasAfEnum,
384377
{
385378
let seqs = [Seq::default(), Seq::new(first as f64, last as f64, 1.0)];
386-
assign_seq(input, &seqs, new_cols)
379+
assign_seq(inout, &seqs, new_cols)
387380
}
388381

389-
#[allow(dead_code)]
390382
/// Get `slice_num`^th slice from `input` Array
391383
///
392-
/// Note. Slices indicate that the indexing is along 3rd dimension
384+
/// Slices indicate that the indexing is along 3rd dimension
393385
pub fn slice<T>(input: &Array<T>, slice_num: u64) -> Array<T>
394386
where
395387
T: HasAfEnum,
@@ -402,11 +394,10 @@ where
402394
index(input, &seqs)
403395
}
404396

405-
#[allow(dead_code)]
406-
/// Set slice `slice_num` in `input` Array to a new Array `new_slice`
397+
/// Set slice `slice_num` in `inout` Array to a new Array `new_slice`
407398
///
408399
/// Slices indicate that the indexing is along 3rd dimension
409-
pub fn set_slice<T>(input: &Array<T>, new_slice: &Array<T>, slice_num: u64) -> Array<T>
400+
pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: u64)
410401
where
411402
T: HasAfEnum,
412403
{
@@ -415,10 +406,9 @@ where
415406
Seq::default(),
416407
Seq::new(slice_num as f64, slice_num as f64, 1.0),
417408
];
418-
assign_seq(input, &seqs, new_slice)
409+
assign_seq(inout, &seqs, new_slice)
419410
}
420411

421-
#[allow(dead_code)]
422412
/// Get slices from `first` to `last` in `input` Array
423413
///
424414
/// Slices indicate that the indexing is along 3rd dimension
@@ -434,11 +424,10 @@ where
434424
index(input, &seqs)
435425
}
436426

437-
#[allow(dead_code)]
438-
/// Set `first` to `last` slices of `input` Array to a new Array `new_slices`
427+
/// Set `first` to `last` slices of `inout` Array to a new Array `new_slices`
439428
///
440429
/// Slices indicate that the indexing is along 3rd dimension
441-
pub fn set_slices<T>(input: &Array<T>, new_slices: &Array<T>, first: u64, last: u64) -> Array<T>
430+
pub fn set_slices<T>(inout: &mut Array<T>, new_slices: &Array<T>, first: u64, last: u64)
442431
where
443432
T: HasAfEnum,
444433
{
@@ -447,7 +436,7 @@ where
447436
Seq::default(),
448437
Seq::new(first as f64, last as f64, 1.0),
449438
];
450-
assign_seq(input, &seqs, new_slices)
439+
assign_seq(inout, &seqs, new_slices)
451440
}
452441

453442
/// Lookup(hash) an Array using another Array
@@ -480,25 +469,26 @@ where
480469
///
481470
/// ```rust
482471
/// use arrayfire::{constant, Dim4, Seq, assign_seq, print};
483-
/// let a = constant(2.0 as f32, Dim4::new(&[5, 3, 1, 1]));
484-
/// let b = constant(1.0 as f32, Dim4::new(&[3, 3, 1, 1]));
485-
/// let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()];
486-
/// let sub = assign_seq(&a, seqs, &b);
472+
/// let mut a = constant(2.0 as f32, Dim4::new(&[5, 3, 1, 1]));
487473
/// print(&a);
488474
/// // 2.0 2.0 2.0
489475
/// // 2.0 2.0 2.0
490476
/// // 2.0 2.0 2.0
491477
/// // 2.0 2.0 2.0
492478
/// // 2.0 2.0 2.0
493479
///
494-
/// print(&sub);
480+
/// let b = constant(1.0 as f32, Dim4::new(&[3, 3, 1, 1]));
481+
/// let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()];
482+
/// assign_seq(&mut a, seqs, &b);
483+
///
484+
/// print(&a);
495485
/// // 2.0 2.0 2.0
496486
/// // 1.0 1.0 1.0
497487
/// // 1.0 1.0 1.0
498488
/// // 1.0 1.0 1.0
499489
/// // 2.0 2.0 2.0
500490
/// ```
501-
pub fn assign_seq<T: Copy, I>(lhs: &Array<I>, seqs: &[Seq<T>], rhs: &Array<I>) -> Array<I>
491+
pub fn assign_seq<T: Copy, I>(lhs: &mut Array<I>, seqs: &[Seq<T>], rhs: &Array<I>)
502492
where
503493
c_double: From<T>,
504494
I: HasAfEnum,
@@ -516,7 +506,8 @@ where
516506
);
517507
HANDLE_ERROR(AfError::from(err_val));
518508
}
519-
temp.into()
509+
let modified = temp.into();
510+
let _old_arr = mem::replace(lhs, modified);
520511
}
521512

522513
/// Index an Array using any combination of Array's and Sequence's
@@ -574,7 +565,7 @@ where
574565
/// let values: [f32; 3] = [1.0, 2.0, 3.0];
575566
/// let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
576567
/// let seq4gen = Seq::new(0.0, 2.0, 1.0);
577-
/// let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
568+
/// let mut a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
578569
/// // [5 3 1 1]
579570
/// // 0.0000 0.2190 0.3835
580571
/// // 0.1315 0.0470 0.5194
@@ -588,16 +579,16 @@ where
588579
/// idxrs.set_index(&indices, 0, None); // 2nd parameter is indexing dimension
589580
/// idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd parameter indicates batch operation
590581
///
591-
/// let sub2 = assign_gen(&a, &idxrs, &b);
592-
/// println!("a(indices, seq(0, 2, 1))"); print(&sub2);
582+
/// assign_gen(&mut a, &idxrs, &b);
583+
/// println!("a(indices, seq(0, 2, 1))"); print(&a);
593584
/// // [5 3 1 1]
594585
/// // 0.0000 0.2190 0.3835
595586
/// // 2.0000 2.0000 2.0000
596587
/// // 2.0000 2.0000 2.0000
597588
/// // 2.0000 2.0000 2.0000
598589
/// // 0.5328 0.9347 0.0535
599590
/// ```
600-
pub fn assign_gen<T>(lhs: &Array<T>, indices: &Indexer, rhs: &Array<T>) -> Array<T>
591+
pub fn assign_gen<T>(lhs: &mut Array<T>, indices: &Indexer, rhs: &Array<T>)
601592
where
602593
T: HasAfEnum,
603594
{
@@ -612,7 +603,8 @@ where
612603
);
613604
HANDLE_ERROR(AfError::from(err_val));
614605
}
615-
temp.into()
606+
let modified = temp.into();
607+
let _old_arr = mem::replace(lhs, modified);
616608
}
617609

618610
#[repr(C)]

0 commit comments

Comments
 (0)