Skip to content

Commit

Permalink
pass x matrix directly for naive_Bayes()
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Aug 26, 2024
1 parent d059d4b commit 51b435c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 24 deletions.
4 changes: 2 additions & 2 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
#' @useDynLib rinfa, .registration = TRUE
NULL

fit_naive_Bayes <- function(x, y, n_features, var_smoothing) .Call(wrap__fit_naive_Bayes, x, y, n_features, var_smoothing)
fit_naive_Bayes <- function(x, y, var_smoothing) .Call(wrap__fit_naive_Bayes, x, y, var_smoothing)

predict_naive_Bayes <- function(model, x, n_features) .Call(wrap__predict_naive_Bayes, model, x, n_features)
predict_naive_Bayes <- function(model, x) .Call(wrap__predict_naive_Bayes, model, x)

fit_linear_reg <- function(x, y) .Call(wrap__fit_linear_reg, x, y)

Expand Down
5 changes: 2 additions & 3 deletions R/mod-naive_Bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@

fit <-
fit_naive_Bayes(
c(x),
x,
y,
ncol(x),
var_smoothing = smoothness
)

Expand All @@ -49,7 +48,7 @@

#' @export
predict.linfa_naive_Bayes <- function(object, newdata, ...) {
predict_naive_Bayes(object$fit, c(newdata), n_features = ncol(object$ptype))
predict_naive_Bayes(object$fit, newdata)
}


Expand Down
24 changes: 5 additions & 19 deletions src/rust/src/bayes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,14 @@ impl From<GaussianNb<f64, usize>> for linfa_naive_Bayes {
}

#[extendr]
pub fn fit_naive_Bayes(x: Vec<f64>, y: Vec<i32>, n_features: i32, var_smoothing: f64) -> linfa_naive_Bayes {
let n_features = n_features as usize;

// Convert Vec<f64> to Array2 for x
let x = Array2::from_shape_vec((n_features, x.len() / n_features), x)
.expect("Failed to reshape x")
.t()
.to_owned();
pub fn fit_naive_Bayes(x: ArrayView2<f64>, y: Vec<i32>, var_smoothing: f64) -> linfa_naive_Bayes {
let x: Array2<f64> = x.to_owned();

// Convert Vec<i32> to Array1<usize> for y
let y = Array1::from_vec(y.into_iter().map(|v| v as usize).collect());

// Create a Dataset
let dataset = Dataset::new(x, y)
.with_feature_names((0..n_features).map(|i| format!("feature_{}", i)).collect());

let dataset = Dataset::new(x, y);
let model = GaussianNb::params()
.var_smoothing(var_smoothing)
.fit(&dataset)
Expand All @@ -44,14 +36,8 @@ pub fn fit_naive_Bayes(x: Vec<f64>, y: Vec<i32>, n_features: i32, var_smoothing:
}

#[extendr]
pub fn predict_naive_Bayes(model: &linfa_naive_Bayes, x: Vec<f64>, n_features: i32) -> Integers {
let n_features = n_features as usize;

// Convert Vec<f64> to Array2 for x
let x = Array2::from_shape_vec((n_features, x.len() / n_features), x)
.expect("Failed to reshape x")
.t()
.to_owned();
pub fn predict_naive_Bayes(model: &linfa_naive_Bayes, x: ArrayView2<f64>) -> Integers {
let x: Array2<f64> = x.to_owned();

let preds = model.model.predict(&x);

Expand Down

0 comments on commit 51b435c

Please sign in to comment.