Skip to content

Commit

Permalink
Handle degenerate cases. Now passing unit tests. Fixes #1.
Browse files Browse the repository at this point in the history
  • Loading branch information
AEBilgrau committed May 16, 2019
1 parent 423e8f8 commit e5a709f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 10 deletions.
66 changes: 58 additions & 8 deletions src/corFamily.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,26 @@
Rcpp::NumericMatrix corRcpp(Rcpp::NumericMatrix & X) {

const int m = X.ncol();
const int n = X.nrow();

// Centering the matrix
X = centerNumericMatrix(X);

Rcpp::NumericMatrix cor(m, m);

// Degenerate case
if (n == 0) {
std::fill(cor.begin(), cor.end(), Rcpp::NumericVector::get_na());
return cor;
}

// Compute 1 over the sample standard deviation
Rcpp::NumericVector inv_sqrt_ss(m);
for (int i = 0; i < m; ++i) {
inv_sqrt_ss(i) = 1/sqrt(Rcpp::sum(X(Rcpp::_, i)*X(Rcpp::_, i)));
}

// Computing the correlation matrix
Rcpp::NumericMatrix cor(m, m);
for (int i = 0; i < m; ++i) {
for (int j = 0; j <= i; ++j) {
cor(i, j) = Rcpp::sum(X(Rcpp::_,i)*X(Rcpp::_,j)) *
Expand All @@ -77,11 +85,19 @@ Rcpp::NumericMatrix xcorRcpp(Rcpp::NumericMatrix & X,

const int m_X = X.ncol();
const int m_Y = Y.ncol();
const int n = X.nrow();

// Centering the matrices
X = centerNumericMatrix(X);
Y = centerNumericMatrix(Y);

Rcpp::NumericMatrix cor(m_X, m_Y);

// Degenerate case
if (n == 0) {
std::fill(cor.begin(), cor.end(), Rcpp::NumericVector::get_na());
return cor;
}

// Compute 1 over square root the sum of squares
Rcpp::NumericVector inv_sqrt_ss_X(m_X);
Expand All @@ -94,7 +110,6 @@ Rcpp::NumericMatrix xcorRcpp(Rcpp::NumericMatrix & X,
}

// Computing the cross-correlation matrix
Rcpp::NumericMatrix cor(m_X, m_Y);
for (int i = 0; i < m_X; ++i) {
for (int j = 0; j < m_Y; ++j) {
cor(i, j) = Rcpp::sum(X(Rcpp::_, i)*Y(Rcpp::_, j)) *
Expand All @@ -112,18 +127,39 @@ Rcpp::NumericMatrix xcorRcpp(Rcpp::NumericMatrix & X,
//' @export
// [[Rcpp::export]]
arma::mat corArma(const arma::mat & X) {
arma::mat cor = arma::cor(X, 0);
cor.diag() /= cor.diag(); // Ensure 1 in the diagonal (if not NaN)
return cor;
arma::mat out(X.n_cols, X.n_cols);

// Degenerate cases
if (X.n_cols == 0) {
return out;
} else if (X.n_rows == 0 || X.n_rows == 1) {
out.fill(Rcpp::NumericVector::get_na());
} else {
out = arma::cor(X, 0);
}

out.diag() /= out.diag(); // Ensure 1 in the diagonal (if not NaN)
return out;
}

// Cross-correlation implementation in armadillo
//' @rdname corFamily
//' @export
// [[Rcpp::export]]
arma::mat xcorArma(const arma::mat & X,
const arma::mat & Y) {
return arma::cor(X, Y, 0);
arma::mat xcorArma(const arma::mat& X,
const arma::mat& Y) {
arma::mat out(X.n_cols, Y.n_cols);

// Degenerate case first
if (X.n_cols == 0 || Y.n_cols == 0) {
return out;
} else if (X.n_rows == 0 || X.n_rows == 1 || Y.n_rows == 0 || Y.n_rows == 1) {
out.fill(Rcpp::NumericVector::get_na());
} else {
out = arma::cor(X, Y, 0);
}

return out;
}


Expand All @@ -135,6 +171,12 @@ arma::mat xcorArma(const arma::mat & X,
// [[Rcpp::export]]
Eigen::MatrixXd corEigen(Eigen::Map<Eigen::MatrixXd> & X) {

// Handle degenerate cases
if (X.rows() == 0 && X.cols() > 0) {
return Eigen::MatrixXd::Constant(X.cols(), X.cols(),
Rcpp::NumericVector::get_na());
}

// Computing degrees of freedom
// n - 1 is the unbiased estimate whereas n is the MLE
const int df = X.rows() - 1; // Subtract 1 by default
Expand All @@ -159,6 +201,14 @@ Eigen::MatrixXd corEigen(Eigen::Map<Eigen::MatrixXd> & X) {
Eigen::MatrixXd xcorEigen(Eigen::Map<Eigen::MatrixXd> & X,
Eigen::Map<Eigen::MatrixXd> & Y) {

// Handle degenerate cases
if (X.cols() == 0 || Y.cols() == 0) {
return Eigen::MatrixXd::Constant(0, 0, 0);
} else if (X.rows() == 0) { // && X.cols() > 0 && Y.cols() > 0 implicit
return Eigen::MatrixXd::Constant(X.cols(), Y.cols(),
Rcpp::NumericVector::get_na());
}

// Computing degrees of freedom
// n - 1 is the unbiased estimate whereas n is the MLE
const int df = X.rows() - 1; // Subtract 1 by default
Expand Down
4 changes: 2 additions & 2 deletions src/covFamily.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ Rcpp::NumericMatrix xcovRcpp(Rcpp::NumericMatrix & X,
}


// Covariance "implementation"" in Armadillio
// Covariance implementation in Armadillio
//' @rdname covFamily
//' @export
// [[Rcpp::export]]
Expand All @@ -130,7 +130,7 @@ arma::mat covArma(const arma::mat& X,
}


// Cross-covariance "implementation"" in Armadillio
// Cross-covariance implementation in Armadillio
//' @rdname covFamily
//' @export
// [[Rcpp::export]]
Expand Down

0 comments on commit e5a709f

Please sign in to comment.