Skip to content

Commit

Permalink
fill missing columns in xss with NAs in xdt (mlr-org#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc authored May 30, 2021
1 parent 438843e commit 32e5371
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
4 changes: 2 additions & 2 deletions R/ObjectiveRFunDt.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ ObjectiveRFunDt = R6Class("ObjectiveRFunDt",

#' @description
#' Evaluates multiple input values received as a list, converted to a `data.table()` on the
#' objective function.
#' objective function. Missing columns in xss are filled with `NA`s in `xdt`.
#'
#' @param xss (`list()`)\cr
#' A list of lists that contains multiple x values, e.g.
Expand All @@ -45,7 +45,7 @@ ObjectiveRFunDt = R6Class("ObjectiveRFunDt",
#' `data.table(y = 1:2)` or `data.table(y1 = 1:2, y2 = 3:4)`.
eval_many = function(xss) {
if (self$check_values) lapply(xss, self$domain$assert)
res = private$.fun(rbindlist(xss, use.names = TRUE))
res = private$.fun(rbindlist(xss, use.names = TRUE, fill = TRUE))
if (self$check_values) self$codomain$assert_dt(res[, self$codomain$ids(), with = FALSE])
return(res)
},
Expand Down
2 changes: 1 addition & 1 deletion man/ObjectiveRFunDt.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions tests/testthat/test_Objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,24 @@ test_that("ObjectiveRFunDt works with a list containing elements with different
res = rfun_dt$eval_many(list(list(x = 1, z = 2), list(x = 1, z = 2)))
expect_equal(res, data.table(y = c(1, 1)))
})

test_that("ObjectiveRFunDt works with deps #141", {
FUN = function(xdt) {
pmap_dtr(xdt, function(x1, x2) {
data.table(y = if(is.na(x2)) x1 else x2)
})
}
domain = ps(x1 = p_int(), x2 = p_int())
domain$add_dep("x2", "x1", CondEqual$new(-1))
codomain = ps(y = p_dbl(tags = "minimize"))
rfun_dt = ObjectiveRFunDt$new(fun = FUN, domain = domain, codomain = codomain)

design = Design$new(
domain,
data.table(x1 = c(-1, 1), x2 = c(2, 2)),
remove_dupl = FALSE
)
xss = design$transpose(trafo = TRUE, filter_na = TRUE)
res = rfun_dt$eval_many(xss)
expect_equal(res, data.table(y = c(2, 1)))
})

0 comments on commit 32e5371

Please sign in to comment.