Skip to content

Commit

Permalink
pass x matrix directly for logistic_reg()
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Aug 26, 2024
1 parent 51b435c commit 121408e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 20 deletions.
4 changes: 2 additions & 2 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ fit_linear_reg <- function(x, y) .Call(wrap__fit_linear_reg, x, y)

predict_linear_reg <- function(model, x) .Call(wrap__predict_linear_reg, model, x)

fit_logistic_reg <- function(x, y, n_features) .Call(wrap__fit_logistic_reg, x, y, n_features)
fit_logistic_reg <- function(x, y) .Call(wrap__fit_logistic_reg, x, y)

predict_logistic_reg <- function(model, x, n_features) .Call(wrap__predict_logistic_reg, model, x, n_features)
predict_logistic_reg <- function(model, x) .Call(wrap__predict_logistic_reg, model, x)

fit_multinom_reg <- function(x, y, n_features) .Call(wrap__fit_multinom_reg, x, y, n_features)

Expand Down
4 changes: 2 additions & 2 deletions R/mod-logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
}

# TODO: check that there are not more than two classes
fit <- fit_logistic_reg(c(x), y, ncol(x))
fit <- fit_logistic_reg(x, y)

structure(
list(fit = fit, ptype = vctrs::vec_slice(x, 0)),
Expand All @@ -47,7 +47,7 @@

#' @export
predict.linfa_logistic_reg <- function(object, newdata, ...) {
predict_logistic_reg(object$fit, c(newdata), n_features = ncol(object$ptype))
predict_logistic_reg(object$fit, newdata)
}

# nocov start
Expand Down
21 changes: 5 additions & 16 deletions src/rust/src/logistic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,21 @@ impl From<FittedLogisticRegression<f64, usize>> for linfa_logistic_reg {
}

#[extendr]
fn fit_logistic_reg(x: Vec<f64>, y: Vec<i32>, n_features: i32) -> linfa_logistic_reg {
let n_features = n_features as usize;

let x = Array2::from_shape_vec((n_features, x.len() / n_features), x)
.expect("Failed to reshape x")
.t()
.to_owned();
fn fit_logistic_reg(x: ArrayView2<f64>, y: Vec<i32>) -> linfa_logistic_reg {
let x: Array2<f64> = x.to_owned();

let y = Array1::from_vec(y.into_iter().map(|v| v as usize).collect());

let dataset = Dataset::new(x, y)
.with_feature_names((0..n_features).map(|i| format!("feature_{}", i)).collect());
let dataset = Dataset::new(x, y);

let model = LogisticRegression::default().fit(&dataset).unwrap();

linfa_logistic_reg::from(model)
}

#[extendr]
fn predict_logistic_reg(model: &linfa_logistic_reg, x: Vec<f64>, n_features: i32) -> Integers {
let n_features = n_features as usize;

let x = Array2::from_shape_vec((n_features, x.len() / n_features), x)
.expect("Failed to reshape x")
.t()
.to_owned();
fn predict_logistic_reg(model: &linfa_logistic_reg, x: ArrayView2<f64>) -> Integers {
let x: Array2<f64> = x.to_owned();

let preds = model.model.predict(&x);

Expand Down

0 comments on commit 121408e

Please sign in to comment.