Skip to content

Commit

Permalink
benchmark naive bayes and SVMs (closes #4)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Aug 2, 2024
1 parent 625b9b7 commit 3a56e35
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 9 deletions.
157 changes: 148 additions & 9 deletions vignettes/articles/benchmarks.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ rinfa supports a number of model types:
* Multinomial regression
* Decision trees
* Naive bayes
* Support vector machines

For each of those model types, we'll benchmark the time to fit the model to a dataset of varying size, and compare the time-to-fit across each modeling engine.

Expand Down Expand Up @@ -140,7 +141,7 @@ for (engine in engines) {
```{r save-linear_reg_timings, include = FALSE, eval = eval_fits}
save(
linear_reg_timings,
file = "vignettes/articles/timings/linear_reg_timings.rda"
file = "timings/linear_reg_timings.rda"
)
```

Expand Down Expand Up @@ -213,14 +214,14 @@ for (engine in engines) {
save(
logistic_reg_timings,
file = "vignettes/articles/timings/logistic_reg_timings.rda"
file = "timings/logistic_reg_timings.rda"
)
```

```{r save-logistic_reg_timings, include = FALSE, eval = eval_fits}
save(
logistic_reg_timings,
file = "vignettes/articles/timings/logistic_reg_timings"
file = "timings/logistic_reg_timings"
)
```

Expand Down Expand Up @@ -294,7 +295,7 @@ for (engine in engines) {
```{r save-multinom_reg_timings, include = FALSE, eval = eval_fits}
save(
multinom_reg_timings,
file = "vignettes/articles/timings/multinom_reg_timings.rda"
file = "timings/multinom_reg_timings.rda"
)
```

Expand Down Expand Up @@ -367,19 +368,86 @@ for (engine in engines) {
save(
decision_tree_timings,
file = "vignettes/articles/timings/decision_tree_timings.rda"
file = "timings/decision_tree_timings.rda"
)
```

```{r save-decision_tree_timings, include = FALSE, eval = eval_fits}
```{r}
decision_tree_timings %>%
ggplot() +
aes(x = n_row, y = timing, colour = engine, group = engine) +
geom_line() +
scale_y_log10() +
scale_x_log10() +
labs(x = "# rows", y = "Log(Fit time, seconds)")
```

## Naive bayes

```{r naive_Bayes_timings, include = FALSE, message = FALSE, warning = FALSE, eval = eval_fits}
library(discrim)
engines <- unique(get_model_env()[["naive_Bayes"]]$engine)
engines
naive_Bayes_timings <-
data.frame(
engine = character(),
n_row = integer(),
timing = numeric()
)
x <- lapply(engines, function(engine) {
pkgs <- required_pkgs(naive_Bayes(engine = engine))
lapply(pkgs, require, character.only = TRUE)
})
for (engine in engines) {
spec <- naive_Bayes(engine = engine, mode = "classification")
if (engine == "glmnet") {
spec <- spec %>% set_args(penalty = 0)
}
if (engine %in% c("spark", "keras")) {
next
}
for (n_row in n_rows) {
print(paste0("Engine: ", engine, " # Rows: ", n_row))
longest_fit <- naive_Bayes_timings[naive_Bayes_timings$engine == engine,]
longest_fit <- max(longest_fit$timing)
if (longest_fit > 600) next
set.seed(1)
d <- sim_classification(n_row)
fit_encoding <- get_fit("naive_Bayes")
fit_encoding <- fit_encoding[fit_encoding$engine == engine, "value"]
fit_encoding <- fit_encoding$value[[1]]$interface
if (!identical(fit_encoding, "matrix")) {
timing <- system.time(fit(spec, class ~ ., d))
} else {
x <- as.matrix(d[colnames(d) != "class"])
timing <- system.time(fit_xy(spec, x = x, y = d$class))
}
naive_Bayes_timings <-
bind_rows(
naive_Bayes_timings,
data.frame(engine = engine, n_row = n_row, timing = timing[["elapsed"]])
)
}
}
save(
decision_tree_timings,
file = "vignettes/articles/timings/multinom_reg_timings.rda"
naive_Bayes_timings,
file = "timings/naive_Bayes_timings.rda"
)
```

```{r}
decision_tree_timings %>%
naive_Bayes_timings %>%
ggplot() +
aes(x = n_row, y = timing, colour = engine, group = engine) +
geom_line() +
Expand All @@ -388,3 +456,74 @@ decision_tree_timings %>%
labs(x = "# rows", y = "Log(Fit time, seconds)")
```

## Support vector machines

```{r svm_linear_timings, include = FALSE, message = FALSE, warning = FALSE, eval = eval_fits}
engines <- unique(get_model_env()[["svm_linear"]]$engine)
engines
svm_linear_timings <-
data.frame(
engine = character(),
n_row = integer(),
timing = numeric()
)
x <- lapply(engines, function(engine) {
pkgs <- required_pkgs(svm_linear(engine = engine))
lapply(pkgs, require, character.only = TRUE)
})
for (engine in engines) {
spec <- svm_linear(engine = engine, mode = "classification")
if (engine == "glmnet") {
spec <- spec %>% set_args(penalty = 0)
}
if (engine %in% c("spark", "keras")) {
next
}
for (n_row in n_rows) {
print(paste0("Engine: ", engine, " # Rows: ", n_row))
longest_fit <- svm_linear_timings[svm_linear_timings$engine == engine,]
longest_fit <- max(longest_fit$timing)
if (longest_fit > 600) next
set.seed(1)
d <- sim_classification(n_row)
fit_encoding <- get_fit("svm_linear")
fit_encoding <- fit_encoding[fit_encoding$engine == engine, "value"]
fit_encoding <- fit_encoding$value[[1]]$interface
if (!identical(fit_encoding, "matrix")) {
timing <- system.time(fit(spec, class ~ ., d))
} else {
x <- as.matrix(d[colnames(d) != "class"])
timing <- system.time(fit_xy(spec, x = x, y = d$class))
}
svm_linear_timings <-
bind_rows(
svm_linear_timings,
data.frame(engine = engine, n_row = n_row, timing = timing[["elapsed"]])
)
}
}
save(
svm_linear_timings,
file = "timings/svm_linear_timings.rda"
)
```

```{r}
svm_linear_timings %>%
ggplot() +
aes(x = n_row, y = timing, colour = engine, group = engine) +
geom_line() +
scale_y_log10() +
scale_x_log10() +
labs(x = "# rows", y = "Log(Fit time, seconds)")
```
Binary file added vignettes/articles/timings/naive_Bayes_timings.rda
Binary file not shown.
Binary file added vignettes/articles/timings/svm_linear_timings.rda
Binary file not shown.

0 comments on commit 3a56e35

Please sign in to comment.