Skip to content

Commit

Permalink
missing probability module
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Oct 26, 2018
1 parent fe24a99 commit 90e1514
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions R/mlp_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@ mlp_keras_data <-
)
)


nnet_softmax <- function(results, object) {
if (ncol(results) == 1)
results <- cbind(1 - results, results)

results <- apply(results, 1, function(x) exp(x)/sum(exp(x)))
results <- as_tibble(t(results))
names(results) <- paste0(".pred_", object$lvl)
results
}

mlp_nnet_data <-
list(
libs = "nnet",
Expand Down Expand Up @@ -103,6 +114,17 @@ mlp_nnet_data <-
type = "class"
)
),
classprob = list(
pre = NULL,
post = nnet_softmax,
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
newdata = quote(new_data),
type = "raw"
)
),
raw = list(
pre = NULL,
func = c(fun = "predict"),
Expand All @@ -114,6 +136,7 @@ mlp_nnet_data <-
)
)


# ------------------------------------------------------------------------------

# keras wrapper for feed-forward nnet
Expand Down

0 comments on commit 90e1514

Please sign in to comment.