Skip to content

Commit

Permalink
Add function for checking column-major format.
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejkula committed Jan 14, 2018
1 parent 8126097 commit 9362430
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,37 @@ fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
true
}

#[cfg(feature="blas")]
fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where S: Data,
A: 'static,
S::Elem: 'static,
{
if !same_type::<A, S::Elem>() {
return false;
}
let (m, n) = a.dim();
let s0 = a.strides()[0];
let s1 = a.strides()[1];
if !(s0 == 1 || m == 1) {
return false;
}
if s0 < 1 || s1 < 1 {
return false;
}
if (s0 > blas_index::max_value() as isize || s0 < blas_index::min_value() as isize) ||
(s1 > blas_index::max_value() as isize || s1 < blas_index::min_value() as isize)
{
return false;
}
if m > blas_index::max_value() as usize ||
n > blas_index::max_value() as usize
{
return false;
}
true
}

#[cfg(test)]
#[cfg(feature="blas")]
mod blas_tests {
Expand All @@ -713,31 +744,43 @@ mod blas_tests {
fn blas_row_major_2d_normal_matrix() {
let m: Array2<f32> = Array2::zeros((3, 5));
assert!(blas_row_major_2d::<f32, _>(&m));
assert!(!blas_column_major_2d::<f32, _>(&m));
}

#[test]
fn blas_row_major_2d_row_matrix() {
let m: Array2<f32> = Array2::zeros((1, 5));
assert!(blas_row_major_2d::<f32, _>(&m));
assert!(blas_column_major_2d::<f32, _>(&m));
}

#[test]
fn blas_row_major_2d_column_matrix() {
let m: Array2<f32> = Array2::zeros((5, 1));
assert!(blas_row_major_2d::<f32, _>(&m));
assert!(blas_column_major_2d::<f32, _>(&m));
}

#[test]
fn blas_row_major_2d_transposed_row_matrix() {
let m: Array2<f32> = Array2::zeros((1, 5));
let m_t = m.t();
assert!(blas_row_major_2d::<f32, _>(&m_t));
assert!(blas_column_major_2d::<f32, _>(&m));
}

#[test]
fn blas_row_major_2d_transposed_column_matrix() {
let m: Array2<f32> = Array2::zeros((5, 1));
let m_t = m.t();
assert!(blas_row_major_2d::<f32, _>(&m_t));
assert!(blas_column_major_2d::<f32, _>(&m));
}

#[test]
fn blas_column_major_2d_normal_matrix() {
let m: Array2<f32> = Array2::zeros((3, 5).f());
assert!(!blas_row_major_2d::<f32, _>(&m));
assert!(blas_column_major_2d::<f32, _>(&m));
}
}

0 comments on commit 9362430

Please sign in to comment.