Skip to content

Commit

Permalink
Change syntax of azip! to be similar to for loops
Browse files Browse the repository at this point in the history
  • Loading branch information
jturner314 committed Sep 9, 2019
1 parent b98c844 commit 4f34f3d
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 102 deletions.
2 changes: 1 addition & 1 deletion benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ fn add_2d_zip_alloc(bench: &mut test::Bencher) {
let b = Array::<i32, _>::zeros((ADD2DSZ, ADD2DSZ));
bench.iter(|| unsafe {
let mut c = Array::uninitialized(a.dim());
azip!(a, b, mut c in { *c = a + b });
azip!((&a in &a, &b in &b, c in &mut c) *c = a + b);
c
});
}
Expand Down
4 changes: 2 additions & 2 deletions benches/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn chunk2x2_iter_sum(bench: &mut Bencher) {
let chunksz = (2, 2);
let mut sum = Array::zeros(a.exact_chunks(chunksz).raw_dim());
bench.iter(|| {
azip!(ref a (a.exact_chunks(chunksz)), mut sum in {
azip!((a in a.exact_chunks(chunksz), sum in &mut sum) {
*sum = a.iter().sum::<f32>();
});
});
Expand All @@ -24,7 +24,7 @@ fn chunk2x2_sum(bench: &mut Bencher) {
let chunksz = (2, 2);
let mut sum = Array::zeros(a.exact_chunks(chunksz).raw_dim());
bench.iter(|| {
azip!(ref a (a.exact_chunks(chunksz)), mut sum in {
azip!((a in a.exact_chunks(chunksz), sum in &mut sum) {
*sum = a.sum();
});
});
Expand Down
4 changes: 2 additions & 2 deletions benches/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ fn sum_3_azip(bench: &mut Bencher) {
let c = vec![1; ZIPSZ];
bench.iter(|| {
let mut s = 0;
azip!(a, b, c in {
azip!((&a in &a, &b in &b, &c in &c) {
s += a + b + c;
});
s
Expand All @@ -182,7 +182,7 @@ fn vector_sum_3_azip(bench: &mut Bencher) {
let b = vec![1.; ZIPSZ];
let mut c = vec![1.; ZIPSZ];
bench.iter(|| {
azip!(a, b, mut c in {
azip!((&a in &a, &b in &b, c in &mut c) {
*c += a + b;
});
});
Expand Down
2 changes: 1 addition & 1 deletion benches/par_rayon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fn add(bench: &mut Bencher) {
let c = Array2::<f64>::zeros((ADDN, ADDN));
let d = Array2::<f64>::zeros((ADDN, ADDN));
bench.iter(|| {
azip!(mut a, b, c, d in {
azip!((a in &mut a, &b in &b, &c in &c, &d in &d) {
*a += b.exp() + c.exp() + d.exp();
});
});
Expand Down
8 changes: 4 additions & 4 deletions examples/zip_many.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@ fn main() {

{
let a = a.view_mut().reversed_axes();
azip!(mut a (a), b (b.t()) in { *a = b });
azip!((a in a, &b in b.t()) *a = b);
}
assert_eq!(a, b);

azip!(mut a, b, c in { *a = b + c; });
azip!((a in &mut a, &b in &b, &c in c) *a = b + c);
assert_eq!(a, &b + &c);

// sum of each row
let ax = Axis(0);
let mut sums = Array::zeros(a.len_of(ax));
azip!(mut sums, ref a (a.axis_iter(ax)) in { *sums = a.sum() });
azip!((s in &mut sums, a in a.axis_iter(ax)) *s = a.sum());

// sum of each chunk
let chunk_sz = (2, 2);
let nchunks = (n / chunk_sz.0, n / chunk_sz.1);
let mut sums = Array::zeros(nchunks);
azip!(mut sums, ref a (a.exact_chunks(chunk_sz)) in { *sums = a.sum() });
azip!((s in &mut sums, a in a.exact_chunks(chunk_sz)) *s = a.sum());

// Let's imagine we split to parallelize
{
Expand Down
6 changes: 3 additions & 3 deletions src/doc/ndarray_for_numpy_users/coord_transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
//! let bunge = Array2::<f64>::ones((3, nelems));
//!
//! let mut rmat = Array::zeros((3, 3, nelems).f());
//! azip!(mut rmat (rmat.axis_iter_mut(Axis(2))), ref bunge (bunge.axis_iter(Axis(1))) in {
//! azip!((mut rmat in rmat.axis_iter_mut(Axis(2)), bunge in bunge.axis_iter(Axis(1))) {
//! let s1 = bunge[0].sin();
//! let c1 = bunge[0].cos();
//! let s2 = bunge[1].sin();
Expand All @@ -129,8 +129,8 @@
//! let eye2d = Array2::<f64>::eye(3);
//!
//! let mut rotated = Array3::<f64>::zeros((3, 3, nelems).f());
//! azip!(mut rotated (rotated.axis_iter_mut(Axis(2)), rmat (rmat.axis_iter(Axis(2)))) in {
//! rotated.assign({ &rmat.dot(&eye2d) });
//! azip!((mut rotated in rotated.axis_iter_mut(Axis(2)), rmat in rmat.axis_iter(Axis(2))) {
//! rotated.assign(&rmat.dot(&eye2d));
//! });
//! }
//! ```
2 changes: 1 addition & 1 deletion src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ where
let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
for (i, subview) in self.axis_iter(axis).enumerate() {
let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
azip!(mut mean, mut sum_sq, x (subview) in {
azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
let delta = x - *mean;
*mean = *mean + delta / count;
*sum_sq = (x - *mean).mul_add(delta, *sum_sq);
Expand Down
124 changes: 45 additions & 79 deletions src/zip/zipmacro.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#[macro_export]
/// Array zip macro: lock step function application across several arrays and
/// producers.
///
Expand All @@ -7,41 +6,32 @@
/// This example:
///
/// ```rust,ignore
/// azip!(mut a, b, c in { *a = b + c })
/// azip!((a in &mut a, &b in &b, &c in &c) *a = b + c);
/// ```
///
/// Is equivalent to:
///
/// ```rust,ignore
/// Zip::from(&mut a).and(&b).and(&c).apply(|a, &b, &c| {
/// *a = b + c;
/// *a = b + c
/// });
///
/// ```
///
/// Explanation of the shorthand for captures:
///
/// + `mut a`: the producer is `&mut a` and the variable pattern is `mut a`.
/// + `b`: the producer is `&b` and the variable pattern is `&b` (same for `c`).
/// The syntax is either
///
/// The syntax is `azip!(` *[* `index` *pattern* `,`*] capture [*`,` *capture [*`,` *...] ]* `in {` *expression* `})`
/// where the captures are a sequence of pattern-like items that indicate which
/// arrays are used for the zip. The *expression* is evaluated elementwise,
/// with the value of an element from each producer in their respective variable.
/// `azip!((` *pat* `in` *expr* `,` *[* *pat* `in` *expr* `,` ... *]* `)` *body_expr* `)`
///
/// More capture rules:
/// or, to use `Zip::indexed` instead of `Zip::from`,
///
/// + `ref c`: the producer is `&c` and the variable pattern is `c`.
/// + `mut a (expr)`: the producer is `expr` and the variable pattern is `mut a`.
/// + `b (expr)`: the producer is `expr` and the variable pattern is `&b`.
/// + `ref c (expr)`: the producer is `expr` and the variable pattern is `c`.
/// `azip!((index` *pat* `,` *pat* `in` *expr* `,` *[* *pat* `in` *expr* `,` ... *]* `)` *body_expr* `)`
///
/// Special rule:
///
/// + `index i`: Use `Zip::indexed` instead. `i` is a pattern -- it can be
/// a single variable name or something else that pattern matches the index.
/// This rule must be the first if it is used, and it must be followed by
/// at least one other rule.
/// The *expr* are expressions whose types must implement `IntoNdProducer`, the
/// *pat* are the patterns of the parameters to the closure called by
/// `Zip::apply`, and *body_expr* is the body of the closure called by
/// `Zip::apply`. You can think of each *pat* `in` *expr* as being analogous to
/// the `pat in expr` of a normal loop `for pat in expr { statements }`: a
/// pattern, followed by `in`, followed by an expression that implements
/// `IntoNdProducer` (analogous to `IntoIterator` for a `for` loop).
///
/// **Panics** if any of the arrays are not of the same shape.
///
Expand All @@ -68,12 +58,12 @@
///
/// // Example 1: Compute a simple ternary operation:
/// // elementwise addition of b and c, stored in a
/// azip!(mut a, b, c in { *a = b + c });
/// azip!((a in &mut a, &b in &b, &c in &c) *a = b + c);
///
/// assert_eq!(a, &b + &c);
///
/// // Example 2: azip!() with index
/// azip!(index (i, j), b, c in {
/// azip!((index (i, j), &b in &b, &c in &c) {
/// a[[i, j]] = b - c;
/// });
///
Expand All @@ -87,80 +77,56 @@
/// assert_eq!(a, &b * &c);
///
///
/// // Since this function borrows its inputs, captures must use the x (x) pattern
/// // to avoid the macro's default rule that autorefs the producer.
/// // Since this function borrows its inputs, the `IntoNdProducer`
/// // expressions don't need to explicitly include `&mut` or `&`.
/// fn borrow_multiply(a: &mut M, b: &M, c: &M) {
/// azip!(mut a (a), b (b), c (c) in { *a = b * c });
/// azip!((a in a, &b in b, &c in c) *a = b * c);
/// }
///
///
/// // Example 4: using azip!() with a `ref` rule
/// // Example 4: using azip!() without dereference in pattern.
/// //
/// // Create a new array `totals` with one entry per row of `a`.
/// // Use azip to traverse the rows of `a` and assign to the corresponding
/// // entry in `totals` with the sum across each row.
/// //
/// // The row is an array view; use the 'ref' rule on the row, to avoid the
/// // default which is to dereference the produced item.
/// let mut totals = Array1::zeros(a.nrows());
///
/// azip!(mut totals, ref row (a.genrows()) in {
/// *totals = row.sum();
/// });
/// // The row is an array view; it doesn't need to be dereferenced.
/// let mut totals = Array1::zeros(a.rows());
/// azip!((totals in &mut totals, row in a.genrows()) *totals = row.sum());
///
/// // Check the result against the built in `.sum_axis()` along axis 1.
/// assert_eq!(totals, a.sum_axis(Axis(1)));
/// }
///
/// ```
#[macro_export]
macro_rules! azip {
// Build Zip Rule (index)
(@parse [index => $a:expr, $($aa:expr,)*] $t1:tt in $t2:tt) => {
$crate::azip!(@finish ($crate::Zip::indexed($a)) [$($aa,)*] $t1 in $t2)
};
// Build Zip Rule (no index)
(@parse [$a:expr, $($aa:expr,)*] $t1:tt in $t2:tt) => {
$crate::azip!(@finish ($crate::Zip::from($a)) [$($aa,)*] $t1 in $t2)
};
// Build Finish Rule (both)
(@finish ($z:expr) [$($aa:expr,)*] [$($p:pat,)+] in { $($t:tt)*}) => {
#[allow(unused_mut)]
($z)
$(
.and($aa)
)*
.apply(|$($p),+| {
$($t)*
})
};
// parsing stack: [expressions] [patterns] (one per operand)
// index uses empty [] -- must be first
(@parse [] [] index $i:pat, $($t:tt)*) => {
$crate::azip!(@parse [index =>] [$i,] $($t)*);
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] mut $x:ident ($e:expr) $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* $e,] [$($pats)* mut $x,] $($t)*);
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] mut $x:ident $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* &mut $x,] [$($pats)* mut $x,] $($t)*);
// Indexed with a single producer and no trailing comma.
((index $index:pat, $first_pat:pat in $first_prod:expr) $body:expr) => {
$crate::Zip::indexed($first_prod).apply(|$index, $first_pat| $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] , $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)*] [$($pats)*] $($t)*);
// Indexed with more than one producer and no trailing comma.
((index $index:pat, $first_pat:pat in $first_prod:expr, $($pat:pat in $prod:expr),*) $body:expr) => {
$crate::Zip::indexed($first_prod)
$(.and($prod))*
.apply(|$index, $first_pat, $($pat),*| $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] ref $x:ident ($e:expr) $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* $e,] [$($pats)* $x,] $($t)*);
// Indexed with trailing comma.
((index $index:pat, $($pat:pat in $prod:expr),+,) $body:expr) => {
azip!((index $index, $($pat in $prod),+) $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] ref $x:ident $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* &$x,] [$($pats)* $x,] $($t)*);
// Unindexed with a single producer and no trailing comma.
(($first_pat:pat in $first_prod:expr) $body:expr) => {
$crate::Zip::from($first_prod).apply(|$first_pat| $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] $x:ident ($e:expr) $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* $e,] [$($pats)* &$x,] $($t)*);
// Unindexed with more than one producer and no trailing comma.
(($first_pat:pat in $first_prod:expr, $($pat:pat in $prod:expr),*) $body:expr) => {
$crate::Zip::from($first_prod)
$(.and($prod))*
.apply(|$first_pat, $($pat),*| $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] $x:ident $($t:tt)*) => {
$crate::azip!(@parse [$($exprs)* &$x,] [$($pats)* &$x,] $($t)*);
// Unindexed with trailing comma.
(($($pat:pat in $prod:expr),+,) $body:expr) => {
azip!(($($pat in $prod),+) $body)
};
(@parse [$($exprs:tt)*] [$($pats:tt)*] $($t:tt)*) => { };
($($t:tt)*) => {
$crate::azip!(@parse [] [] $($t)*);
}
}
18 changes: 9 additions & 9 deletions tests/azip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ use std::mem::swap;
fn test_azip1() {
let mut a = Array::zeros(62);
let mut x = 0;
azip!(mut a in { *a = x; x += 1; });
azip!((a in &mut a) { *a = x; x += 1; });
assert_equal(cloned(&a), 0..a.len());
}

#[test]
fn test_azip2() {
let mut a = Array::zeros((5, 7));
let b = Array::from_shape_fn(a.dim(), |(i, j)| 1. / (i + 2 * j) as f32);
azip!(mut a, b in { *a = b; });
azip!((a in &mut a, &b in &b) *a = b);
assert_eq!(a, b);
}

Expand All @@ -35,7 +35,7 @@ fn test_azip2_1() {
let mut a = Array::zeros((5, 7));
let b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j) as f32);
let b = b.slice(s![..;-1, 3..]);
azip!(mut a, b in { *a = b; });
azip!((a in &mut a, &b in &b) *a = b);
assert_eq!(a, b);
}

Expand All @@ -44,7 +44,7 @@ fn test_azip2_3() {
let mut b = Array::from_shape_fn((5, 10), |(i, j)| 1. / (i + 2 * j) as f32);
let mut c = Array::from_shape_fn((5, 10), |(i, j)| f32::exp((i + j) as f32));
let a = b.clone();
azip!(mut b, mut c in { swap(b, c) });
azip!((b in &mut b, c in &mut c) swap(b, c));
assert_eq!(a, c);
assert!(a != b);
}
Expand All @@ -58,7 +58,7 @@ fn test_azip2_sum() {
for i in 0..2 {
let ax = Axis(i);
let mut b = Array::zeros(c.len_of(ax));
azip!(mut b, ref c (c.axis_iter(ax)) in { *b = c.sum() });
azip!((b in &mut b, c in c.axis_iter(ax)) *b = c.sum());
assert_abs_diff_eq!(b, c.sum_axis(Axis(1 - i)), epsilon = 1e-6);
}
}
Expand All @@ -75,7 +75,7 @@ fn test_azip3_slices() {
*elt = i as f32;
}

azip!(mut a (&mut a[..]), b (&b[..]), mut c (&mut c[..]) in {
azip!((a in &mut a[..], b in &b[..], c in &mut c[..]) {
*a += b / 10.;
*c = a.sin();
});
Expand Down Expand Up @@ -115,7 +115,7 @@ fn test_zip_dim_mismatch_1() {
let mut d = a.raw_dim();
d[0] += 1;
let b = Array::from_shape_fn(d, |(i, j)| 1. / (i + 2 * j) as f32);
azip!(mut a, b in { *a = b; });
azip!((a in &mut a, &b in &b) *a = b);
}

// Test that Zip handles memory layout correctly for
Expand All @@ -136,7 +136,7 @@ fn test_contiguous_but_not_c_or_f() {
let correct_012 = a[[0, 1, 2]] + b[[0, 1, 2]];

let mut ans = Array::zeros(a.dim().f());
azip!(mut ans, a, b in { *ans = a + b });
azip!((ans in &mut ans, &a in &a, &b in &b) *ans = a + b);
println!("{:?}", a);
println!("{:?}", b);
println!("{:?}", ans);
Expand Down Expand Up @@ -200,7 +200,7 @@ fn test_indices_2() {
}

let mut count = 0;
azip!(index i, a1 in {
azip!((index i, &a1 in &a1) {
count += 1;
assert_eq!(a1, i);
});
Expand Down

0 comments on commit 4f34f3d

Please sign in to comment.