Skip to content

Commit

Permalink
Merge pull request rust-ndarray#980 from YuhanLiin/const-generics
Browse files Browse the repository at this point in the history
Const generics Improvements
  • Loading branch information
bluss authored Dec 6, 2021
2 parents d1bb045 + 75a27e5 commit 6c8b821
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 169 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- stable
- beta
- nightly
- 1.49.0 # MSRV
- 1.51.0 # MSRV

steps:
- uses: actions/checkout@v2
Expand Down
66 changes: 57 additions & 9 deletions src/arraytraits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::hash;
use std::iter::FromIterator;
use alloc::boxed::Box;
use alloc::vec::Vec;
use std::iter::IntoIterator;
use std::mem;
use std::ops::{Index, IndexMut};
use alloc::boxed::Box;
use alloc::vec::Vec;
use std::{hash, mem::size_of};
use std::{iter::FromIterator, slice};

use crate::imp_prelude::*;
use crate::iter::{Iter, IterMut};
use crate::NdIndex;

use crate::numeric_util;
use crate::{FoldWhile, Zip};
use crate::{
dimension,
iter::{Iter, IterMut},
numeric_util, FoldWhile, NdIndex, Zip,
};

#[cold]
#[inline(never)]
Expand Down Expand Up @@ -323,6 +323,30 @@ where
}
}

/// Implementation of ArrayView2::from(&S) where S is a slice to a 2D array
///
/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A
/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes).
impl<'a, A, const N: usize> From<&'a [[A; N]]> for ArrayView<'a, A, Ix2> {
/// Create a two-dimensional read-only array view of the data in `slice`
fn from(xs: &'a [[A; N]]) -> Self {
let cols = N;
let rows = xs.len();
let dim = Ix2(rows, cols);
if size_of::<A>() == 0 {
dimension::size_of_shape_checked(&dim)
.expect("Product of non-zero axis lengths must not overflow isize.");
}

// `cols * rows` is guaranteed to fit in `isize` because we checked that it fits in
// `isize::MAX`
unsafe {
let data = slice::from_raw_parts(xs.as_ptr() as *const A, cols * rows);
ArrayView::from_shape_ptr(dim, data.as_ptr())
}
}
}

/// Implementation of `ArrayView::from(&A)` where `A` is an array.
impl<'a, A, S, D> From<&'a ArrayBase<S, D>> for ArrayView<'a, A, D>
where
Expand Down Expand Up @@ -355,6 +379,30 @@ where
}
}

/// Implementation of ArrayViewMut2::from(&S) where S is a slice to a 2D array
///
/// **Panics** if the product of non-zero axis lengths overflows `isize` (This can only occur if A
/// is zero-sized because slices cannot contain more than `isize::MAX` number of bytes).
impl<'a, A, const N: usize> From<&'a mut [[A; N]]> for ArrayViewMut<'a, A, Ix2> {
/// Create a two-dimensional read-write array view of the data in `slice`
fn from(xs: &'a mut [[A; N]]) -> Self {
let cols = N;
let rows = xs.len();
let dim = Ix2(rows, cols);
if size_of::<A>() == 0 {
dimension::size_of_shape_checked(&dim)
.expect("Product of non-zero axis lengths must not overflow isize.");
}

// `cols * rows` is guaranteed to fit in `isize` because we checked that it fits in
// `isize::MAX`
unsafe {
let data = slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut A, cols * rows);
ArrayViewMut::from_shape_ptr(dim, data.as_mut_ptr())
}
}
}

/// Implementation of `ArrayViewMut::from(&mut A)` where `A` is an array.
impl<'a, A, S, D> From<&'a mut ArrayBase<S, D>> for ArrayViewMut<'a, A, D>
where
Expand Down
102 changes: 58 additions & 44 deletions src/dimension/ndindex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,50 +140,6 @@ macro_rules! ndindex_with_array {
0
}
}

// implement NdIndex<IxDyn> for Dim<[Ix; 2]> and so on
unsafe impl NdIndex<IxDyn> for Dim<[Ix; $n]> {
#[inline]
fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option<isize> {
debug_assert_eq!(strides.ndim(), $n,
"Attempted to index with {:?} in array with {} axes",
self, strides.ndim());
stride_offset_checked(dim.ix(), strides.ix(), self.ix())
}

#[inline]
fn index_unchecked(&self, strides: &IxDyn) -> isize {
debug_assert_eq!(strides.ndim(), $n,
"Attempted to index with {:?} in array with {} axes",
self, strides.ndim());
$(
stride_offset(get!(self, $index), get!(strides, $index)) +
)*
0
}
}

// implement NdIndex<IxDyn> for [Ix; 2] and so on
unsafe impl NdIndex<IxDyn> for [Ix; $n] {
#[inline]
fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option<isize> {
debug_assert_eq!(strides.ndim(), $n,
"Attempted to index with {:?} in array with {} axes",
self, strides.ndim());
stride_offset_checked(dim.ix(), strides.ix(), self)
}

#[inline]
fn index_unchecked(&self, strides: &IxDyn) -> isize {
debug_assert_eq!(strides.ndim(), $n,
"Attempted to index with {:?} in array with {} axes",
self, strides.ndim());
$(
stride_offset(self[$index], get!(strides, $index)) +
)*
0
}
}
)+
};
}
Expand All @@ -198,6 +154,64 @@ ndindex_with_array! {
[6, Ix6 0 1 2 3 4 5]
}

// implement NdIndex<IxDyn> for Dim<[Ix; 2]> and so on
unsafe impl<const N: usize> NdIndex<IxDyn> for Dim<[Ix; N]> {
#[inline]
fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option<isize> {
debug_assert_eq!(
strides.ndim(),
N,
"Attempted to index with {:?} in array with {} axes",
self,
strides.ndim()
);
stride_offset_checked(dim.ix(), strides.ix(), self.ix())
}

#[inline]
fn index_unchecked(&self, strides: &IxDyn) -> isize {
debug_assert_eq!(
strides.ndim(),
N,
"Attempted to index with {:?} in array with {} axes",
self,
strides.ndim()
);
(0..N)
.map(|i| stride_offset(get!(self, i), get!(strides, i)))
.sum()
}
}

// implement NdIndex<IxDyn> for [Ix; 2] and so on
unsafe impl<const N: usize> NdIndex<IxDyn> for [Ix; N] {
#[inline]
fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option<isize> {
debug_assert_eq!(
strides.ndim(),
N,
"Attempted to index with {:?} in array with {} axes",
self,
strides.ndim()
);
stride_offset_checked(dim.ix(), strides.ix(), self)
}

#[inline]
fn index_unchecked(&self, strides: &IxDyn) -> isize {
debug_assert_eq!(
strides.ndim(),
N,
"Attempted to index with {:?} in array with {} axes",
self,
strides.ndim()
);
(0..N)
.map(|i| stride_offset(self[i], get!(strides, i)))
.sum()
}
}

impl<'a> IntoDimension for &'a [Ix] {
type Dim = IxDyn;
fn into_dimension(self) -> Self::Dim {
Expand Down
Loading

0 comments on commit 6c8b821

Please sign in to comment.