Skip to content

Commit

Permalink
Merge pull request rust-ndarray#814 from rust-ndarray/zip-collect-drop
Browse files Browse the repository at this point in the history
Implement Zip::apply_collect for non-Copy elements too
  • Loading branch information
bluss authored Apr 29, 2020
2 parents 3ea6861 + 624fd75 commit adef586
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 7 deletions.
32 changes: 32 additions & 0 deletions benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,38 @@ fn add_2d_alloc_zip_collect(bench: &mut test::Bencher) {
});
}

#[bench]
fn vec_string_collect(bench: &mut test::Bencher) {
let v = vec![""; 10240];
bench.iter(|| {
v.iter().map(|s| s.to_owned()).collect::<Vec<_>>()
});
}

#[bench]
fn array_string_collect(bench: &mut test::Bencher) {
let v = Array::from(vec![""; 10240]);
bench.iter(|| {
Zip::from(&v).apply_collect(|s| s.to_owned())
});
}

#[bench]
fn vec_f64_collect(bench: &mut test::Bencher) {
let v = vec![1.; 10240];
bench.iter(|| {
v.iter().map(|s| s + 1.).collect::<Vec<_>>()
});
}

#[bench]
fn array_f64_collect(bench: &mut test::Bencher) {
let v = Array::from(vec![1.; 10240]);
bench.iter(|| {
Zip::from(&v).apply_collect(|s| s + 1.)
});
}


#[bench]
fn add_2d_assign_ops(bench: &mut test::Bencher) {
Expand Down
26 changes: 19 additions & 7 deletions src/zip/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#[macro_use]
mod zipmacro;
mod partial_array;

use std::mem::MaybeUninit;

Expand All @@ -20,6 +21,8 @@ use crate::NdIndex;
use crate::indexes::{indices, Indices};
use crate::layout::{CORDER, FORDER};

use partial_array::PartialArray;

/// Return if the expression is a break value.
macro_rules! fold_while {
($e:expr) => {
Expand Down Expand Up @@ -195,6 +198,7 @@ pub trait NdProducer {
fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
where
Self: Sized;

private_decl! {}
}

Expand Down Expand Up @@ -1070,16 +1074,24 @@ macro_rules! map_impl {
/// inputs.
///
/// If all inputs are c- or f-order respectively, that is preserved in the output.
///
/// Restricted to functions that produce copyable results for technical reasons; other
/// cases are not yet implemented.
pub fn apply_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D>
where R: Copy,
{
// To support non-Copy elements, implementation of dropping partial array (on
// panic) is needed
// Make uninit result
let mut output = self.uninitalized_for_current_layout::<R>();
self.apply_assign_into(&mut output, f);
if !std::mem::needs_drop::<R>() {
// For elements with no drop glue, just overwrite into the array
self.apply_assign_into(&mut output, f);
} else {
// For generic elements, use a proxy that counts the number of filled elements,
// and can drop the right number of elements on unwinding
unsafe {
PartialArray::scope(output.view_mut(), move |partial| {
debug_assert_eq!(partial.layout().tendency() >= 0, self.layout_tendency >= 0);
self.apply_assign_into(partial, f);
});
}
}

unsafe {
output.assume_init()
}
Expand Down
144 changes: 144 additions & 0 deletions src/zip/partial_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright 2020 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::imp_prelude::*;
use crate::{
AssignElem,
Layout,
NdProducer,
Zip,
FoldWhile,
};

use std::cell::Cell;
use std::mem;
use std::mem::MaybeUninit;
use std::ptr;

/// An assignable element reference that increments a counter when assigned
pub(crate) struct ProxyElem<'a, 'b, A> {
item: &'a mut MaybeUninit<A>,
filled: &'b Cell<usize>
}

impl<'a, 'b, A> AssignElem<A> for ProxyElem<'a, 'b, A> {
fn assign_elem(self, item: A) {
self.filled.set(self.filled.get() + 1);
*self.item = MaybeUninit::new(item);
}
}

/// Handles progress of assigning to a part of an array, for elements that need
/// to be dropped on unwinding. See Self::scope.
pub(crate) struct PartialArray<'a, 'b, A, D>
where D: Dimension
{
data: ArrayViewMut<'a, MaybeUninit<A>, D>,
filled: &'b Cell<usize>,
}

impl<'a, 'b, A, D> PartialArray<'a, 'b, A, D>
where D: Dimension
{
/// Create a temporary PartialArray that wraps the array view `data`;
/// if the end of the scope is reached, the partial array is marked complete;
/// if execution unwinds at any time before them, the elements written until then
/// are dropped.
///
/// Safety: the caller *must* ensure that elements will be written in `data`'s preferred order.
/// PartialArray can not handle arbitrary writes, only in the memory order.
pub(crate) unsafe fn scope(data: ArrayViewMut<'a, MaybeUninit<A>, D>,
scope_fn: impl FnOnce(&mut PartialArray<A, D>))
{
let filled = Cell::new(0);
let mut partial = PartialArray::new(data, &filled);
scope_fn(&mut partial);
filled.set(0); // mark complete
}

unsafe fn new(data: ArrayViewMut<'a, MaybeUninit<A>, D>,
filled: &'b Cell<usize>) -> Self
{
debug_assert_eq!(filled.get(), 0);
Self { data, filled }
}
}

impl<'a, 'b, A, D> Drop for PartialArray<'a, 'b, A, D>
where D: Dimension
{
fn drop(&mut self) {
if !mem::needs_drop::<A>() {
return;
}

let mut count = self.filled.get();
if count == 0 {
return;
}

Zip::from(self).fold_while((), move |(), elt| {
if count > 0 {
count -= 1;
unsafe {
ptr::drop_in_place::<A>(elt.item.as_mut_ptr());
}
FoldWhile::Continue(())
} else {
FoldWhile::Done(())
}
});
}
}

impl<'a: 'c, 'b: 'c, 'c, A, D: Dimension> NdProducer for &'c mut PartialArray<'a, 'b, A, D> {
// This just wraps ArrayViewMut as NdProducer and maps the item
type Item = ProxyElem<'a, 'b, A>;
type Dim = D;
type Ptr = *mut MaybeUninit<A>;
type Stride = isize;

private_impl! {}
fn raw_dim(&self) -> Self::Dim {
self.data.raw_dim()
}

fn equal_dim(&self, dim: &Self::Dim) -> bool {
self.data.equal_dim(dim)
}

fn as_ptr(&self) -> Self::Ptr {
NdProducer::as_ptr(&self.data)
}

fn layout(&self) -> Layout {
self.data.layout()
}

unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
ProxyElem { filled: self.filled, item: &mut *ptr }
}

unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
self.data.uget_ptr(i)
}

fn stride_of(&self, axis: Axis) -> Self::Stride {
self.data.stride_of(axis)
}

#[inline(always)]
fn contiguous_stride(&self) -> Self::Stride {
self.data.contiguous_stride()
}

fn split_at(self, _axis: Axis, _index: usize) -> (Self, Self) {
unimplemented!();
}
}

72 changes: 72 additions & 0 deletions tests/azip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,78 @@ fn test_zip_assign_into_cell() {
assert_abs_diff_eq!(a2, &b + &c, epsilon = 1e-6);
}

#[test]
fn test_zip_collect_drop() {
use std::cell::RefCell;
use std::panic;

struct Recorddrop<'a>((usize, usize), &'a RefCell<Vec<(usize, usize)>>);

impl<'a> Drop for Recorddrop<'a> {
fn drop(&mut self) {
self.1.borrow_mut().push(self.0);
}
}

#[derive(Copy, Clone)]
enum Config {
CC,
CF,
FF,
}

impl Config {
fn a_is_f(self) -> bool {
match self {
Config::CC | Config::CF => false,
_ => true,
}
}
fn b_is_f(self) -> bool {
match self {
Config::CC => false,
_ => true,
}
}
}

let test_collect_panic = |config: Config, will_panic: bool, slice: bool| {
let mut inserts = RefCell::new(Vec::new());
let mut drops = RefCell::new(Vec::new());

let mut a = Array::from_shape_fn((5, 10).set_f(config.a_is_f()), |idx| idx);
let mut b = Array::from_shape_fn((5, 10).set_f(config.b_is_f()), |_| 0);
if slice {
a = a.slice_move(s![.., ..-1]);
b = b.slice_move(s![.., ..-1]);
}

let _result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
Zip::from(&a).and(&b).apply_collect(|&elt, _| {
if elt.0 > 3 && will_panic {
panic!();
}
inserts.borrow_mut().push(elt);
Recorddrop(elt, &drops)
});
}));

println!("{:?}", inserts.get_mut());
println!("{:?}", drops.get_mut());

assert_eq!(inserts.get_mut().len(), drops.get_mut().len(), "Incorrect number of drops");
assert_eq!(inserts.get_mut(), drops.get_mut(), "Incorrect order of drops");
};

for &should_panic in &[true, false] {
for &should_slice in &[false, true] {
test_collect_panic(Config::CC, should_panic, should_slice);
test_collect_panic(Config::CF, should_panic, should_slice);
test_collect_panic(Config::FF, should_panic, should_slice);
}
}
}


#[test]
fn test_azip_syntax_trailing_comma() {
Expand Down

0 comments on commit adef586

Please sign in to comment.