Skip to content

Commit

Permalink
pass matrix directly to prediction function
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Aug 26, 2024
1 parent d48414a commit d059d4b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 13 deletions.
2 changes: 1 addition & 1 deletion R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ predict_naive_Bayes <- function(model, x, n_features) .Call(wrap__predict_naive_

fit_linear_reg <- function(x, y) .Call(wrap__fit_linear_reg, x, y)

predict_linear_reg <- function(model, x, n_features) .Call(wrap__predict_linear_reg, model, x, n_features)
predict_linear_reg <- function(model, x) .Call(wrap__predict_linear_reg, model, x)

fit_logistic_reg <- function(x, y, n_features) .Call(wrap__fit_logistic_reg, x, y, n_features)

Expand Down
2 changes: 1 addition & 1 deletion R/mod-linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

#' @export
predict.linfa_linear_reg <- function(object, newdata, ...) {
predict_linear_reg(object$fit, c(newdata), n_features = ncol(object$ptype))
predict_linear_reg(object$fit, newdata)
}

# nocov start
Expand Down
16 changes: 5 additions & 11 deletions src/rust/src/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,20 @@ impl From<FittedLinearRegression<f64>> for linfa_linear_reg {
#[extendr]
fn fit_linear_reg(x: ArrayView2<f64>, y: ArrayView1<f64>) -> linfa_linear_reg {
// Convert inputs to linfa-happy formats
let x_owned: Array2<f64> = x.to_owned();
let y_owned: Array1<f64> = y.to_owned();
let x: Array2<f64> = x.to_owned();
let y: Array1<f64> = y.to_owned();

// Create a Dataset
let dataset = Dataset::new(x_owned, y_owned);
let dataset = Dataset::new(x, y);

let model = LinearRegression::default().fit(&dataset).unwrap();

linfa_linear_reg::from(model)
}

#[extendr]
fn predict_linear_reg(model: &linfa_linear_reg, x: Vec<f64>, n_features: i32) -> Doubles {
let n_features = n_features as usize;

// Convert Vec<f64> to Array2 for x
let x = Array2::from_shape_vec((n_features, x.len() / n_features), x)
.expect("Failed to reshape x")
.t()
.to_owned();
fn predict_linear_reg(model: &linfa_linear_reg, x: ArrayView2<f64>) -> Doubles {
let x: Array2<f64> = x.to_owned();

let preds = model.model.predict(&x);
let preds = preds.into_raw_vec();
Expand Down

0 comments on commit d059d4b

Please sign in to comment.