Skip to content

Commit

Permalink
Merge branch 'main' of github.com:mlr-org/bbotk
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Sep 9, 2024
2 parents 05b8c3b + 3c53790 commit 178e721
Show file tree
Hide file tree
Showing 19 changed files with 390 additions and 75 deletions.
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ export(callback_async)
export(callback_batch)
export(clbk)
export(clbks)
export(evaluate_queue_default)
export(is_dominated)
export(mlr_callbacks)
export(mlr_optimizers)
Expand Down
40 changes: 34 additions & 6 deletions R/CallbackAsync.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,28 @@ CallbackAsync = R6Class("CallbackAsync",
public = list(

#' @field on_optimization_begin (`function()`)\cr
#' Stage called at the beginning of the optimization in the main process.
#' Called in `Optimizer$optimize()`.
#' Stage called at the beginning of the optimization in the main process.
#' Called in `Optimizer$optimize()`.
on_optimization_begin = NULL,

#' @field on_worker_begin (`function()`)\cr
#' Stage called at the beginning of the optimization on the worker.
#' Called in the worker loop.
#' Stage called at the beginning of the optimization on the worker.
#' Called in the worker loop.
on_worker_begin = NULL,

#' @field on_optimizer_before_eval (`function()`)\cr
#' Stage called after the optimizer proposes points.
#' Called in `OptimInstance$.eval_point()`.
on_optimizer_before_eval = NULL,

#' @field on_optimizer_after_eval (`function()`)\cr
#' Stage called after points are evaluated.
#' Called in `OptimInstance$.eval_point()`.
on_optimizer_after_eval = NULL,

#' @field on_worker_end (`function()`)\cr
#' Stage called at the end of the optimization on the worker.
#' Called in the worker loop.
#' Stage called at the end of the optimization on the worker.
#' Called in the worker loop.
on_worker_end = NULL,

#' @field on_result (`function()`)\cr
Expand Down Expand Up @@ -52,6 +62,10 @@ CallbackAsync = R6Class("CallbackAsync",
#' - on_optimization_begin
#' Start Worker
#' - on_worker_begin
#' Start Optimization on Worker
#' - on_optimizer_before_eval
#' - on_optimizer_after_eval
#' End Optimization on Worker
#' - on_worker_end
#' End Worker
#' - on_result
Expand Down Expand Up @@ -81,6 +95,14 @@ CallbackAsync = R6Class("CallbackAsync",
#' Stage called at the beginning of the optimization on the worker.
#' Called in the worker loop.
#' The functions must have two arguments named `callback` and `context`.
#' @param on_optimizer_before_eval (`function()`)\cr
#' Stage called after the optimizer proposes points.
#' Called in `OptimInstance$eval_point()`.
#' The functions must have two arguments named `callback` and `context`.
#' @param on_optimizer_after_eval (`function()`)\cr
#' Stage called after points are evaluated.
#' Called in `OptimInstance$eval_point()`.
#' The functions must have two arguments named `callback` and `context`.
#' @param on_worker_end (`function()`)\cr
#' Stage called at the end of the optimization on the worker.
#' Called in the worker loop.
Expand All @@ -101,18 +123,24 @@ callback_async = function(
man = NA_character_,
on_optimization_begin = NULL,
on_worker_begin = NULL,
on_optimizer_before_eval = NULL,
on_optimizer_after_eval = NULL,
on_worker_end = NULL,
on_result = NULL,
on_optimization_end = NULL
) {
stages = discard(set_names(list(
on_optimization_begin,
on_worker_begin,
on_optimizer_before_eval,
on_optimizer_after_eval,
on_worker_end,
on_result,
on_optimization_end),
c("on_optimization_begin",
"on_worker_begin",
"on_optimizer_before_eval",
"on_optimizer_after_eval",
"on_worker_end",
"on_result",
"on_optimization_end")), is.null)
Expand Down
42 changes: 41 additions & 1 deletion R/ContextAsync.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,53 @@ ContextAsync = R6Class("ContextAsync",
active = list(

#' @field result ([data.table::data.table])\cr
#' The result of the optimization.
#' The result of the optimization.
result = function(rhs) {
if (missing(rhs)) {
get_private(self$instance)$.result
} else {
get_private(self$instance, ".result") = rhs
}
},

#' @field xs (list())\cr
#' The point to be evaluated.
xs = function(rhs) {
if (missing(rhs)) {
get_private(self$instance)$.xs
} else {
get_private(self$instance, ".xs") = rhs
}
},

#' @field xs_trafoed (list())\cr
#' The transformed point to be evaluated.
xs_trafoed = function(rhs) {
if (missing(rhs)) {
get_private(self$instance)$.xs_trafoed
} else {
get_private(self$instance, ".xs_trafoed") = rhs
}
},

#' @field extra (list())\cr
#' Additional information.
extra = function(rhs) {
if (missing(rhs)) {
get_private(self$instance)$.extra
} else {
get_private(self$instance, ".extra") = rhs
}
},

#' @field ys (list())\cr
#' The result of the evaluation.
ys = function(rhs) {
if (missing(rhs)) {
get_private(self$instance)$.ys
} else {
get_private(self$instance, ".ys") = rhs
}
}
)
)
45 changes: 44 additions & 1 deletion R/OptimInstanceAsync.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,50 @@ OptimInstanceAsync = R6Class("OptimInstanceAsync",
),

private = list(
.xs = NULL,
.xs_trafoed = NULL,
.extra = NULL,
.ys = NULL,

.eval_point = function(xs) {
# transpose point
private$.xs = xs[self$archive$cols_x]
private$.xs_trafoed = trafo_xs(private$.xs, self$search_space)
private$.extra = xs[names(xs) %nin% c(self$archive$cols_x, "x_domain")]

call_back("on_optimizer_before_eval", self$objective$callbacks, self$objective$context)

# eval
key = self$archive$push_running_point(private$.xs)
private$.ys = self$objective$eval(private$.xs_trafoed)

call_back("on_optimizer_after_eval", self$objective$callbacks, self$objective$context)

# push result
self$archive$push_result(key, private$.ys, x_domain = private$.xs_trafoed, extra = private$.extra)

return(invisible(private$.ys))
},

.eval_queue = function() {
while (!self$is_terminated && self$archive$n_queued) {
task = self$archive$pop_point()
if (!is.null(task)) {
private$.xs = task$xs

# transpose point
private$.xs_trafoed = trafo_xs(private$.xs, self$search_space)

# eval
call_back("on_optimizer_before_eval", self$objective$callbacks, self$objective$context)
private$.ys = self$objective$eval(private$.xs_trafoed)

# push reuslt
call_back("on_optimizer_after_eval", self$objective$callbacks, self$objective$context)
self$archive$push_result(task$key, private$.ys, x_domain = private$.xs_trafoed)
}
}
},

# initialize context for optimization
.initialize_context = function(optimizer) {
Expand All @@ -94,4 +138,3 @@ OptimInstanceAsync = R6Class("OptimInstanceAsync",
}
)
)

31 changes: 11 additions & 20 deletions R/OptimizerAsync.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ OptimizerAsync = R6Class("OptimizerAsync",
optimize = function(inst) {
optimize_async_default(inst, self)
}
),

private = list(
.xdt = NULL,
.ys = NULL
)
)

Expand Down Expand Up @@ -65,7 +70,13 @@ optimize_async_default = function(instance, optimizer, design = NULL, n_workers
rush = rush::RushWorker$new(instance$rush$network_id, remote = FALSE)
instance$rush = rush
instance$archive$rush = rush

call_back("on_worker_begin", instance$objective$callbacks, instance$objective$context)

# run optimizer loop
get_private(optimizer)$.optimize(instance)

call_back("on_worker_end", instance$objective$callbacks, instance$objective$context)
} else {
# run .optimize() on workers
rush = instance$rush
Expand Down Expand Up @@ -135,23 +146,3 @@ optimize_async_default = function(instance, optimizer, design = NULL, n_workers
call_back("on_optimization_end", instance$objective$callbacks, instance$objective$context)
return(instance$result)
}

#' @title Default Evaluation of the Queue
#'
#' @description
#' Used internally in `$.optimize()` of [OptimizerAsync] classes to evaluate a queue of points e.g. in [OptimizerAsyncGridSearch].
#'
#' @param instance [OptimInstanceAsync].
#'
#' @keywords internal
#' @export
evaluate_queue_default = function(instance) {
while (!instance$is_terminated && instance$archive$n_queued) {
task = instance$archive$pop_point() # FIXME: Add fields argument?
if (!is.null(task)) {
xs_trafoed = trafo_xs(task$xs, instance$search_space)
ys = instance$objective$eval(xs_trafoed)
instance$archive$push_result(task$key, ys, x_domain = xs_trafoed)
}
}
}
5 changes: 1 addition & 4 deletions R/OptimizerAsyncDesignPoints.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ OptimizerAsyncDesignPoints = R6Class("OptimizerAsyncDesignPoints",
#' @param inst ([OptimInstance]).
#' @return [data.table::data.table].
optimize = function(inst) {

# generate grid and send to workers
design = inst$search_space$assert_dt(self$param_set$values$design)

Expand All @@ -55,10 +54,8 @@ OptimizerAsyncDesignPoints = R6Class("OptimizerAsyncDesignPoints",

private = list(
.optimize = function(inst) {
archive = inst$archive

# evaluate design of points
evaluate_queue_default(inst)
get_private(inst)$.eval_queue()
}
)
)
Expand Down
5 changes: 1 addition & 4 deletions R/OptimizerAsyncGridSearch.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ OptimizerAsyncGridSearch = R6Class("OptimizerAsyncGridSearch",
#' @param inst ([OptimInstance]).
#' @return [data.table::data.table].
optimize = function(inst) {

# generate grid
pv = self$param_set$values
design = generate_design_grid(inst$search_space, resolution = pv$resolution, param_resolutions = pv$param_resolutions)$data
Expand All @@ -64,10 +63,8 @@ OptimizerAsyncGridSearch = R6Class("OptimizerAsyncGridSearch",

private = list(
.optimize = function(inst) {
archive = inst$archive

# evaluate grid points
evaluate_queue_default(inst)
get_private(inst)$.eval_queue()
}
)
)
Expand Down
14 changes: 4 additions & 10 deletions R/OptimizerAsyncRandomSearch.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,16 @@ OptimizerAsyncRandomSearch = R6Class("OptimizerAsyncRandomSearch",
search_space = inst$search_space

# usually the queue is empty but callbacks might have added points
evaluate_queue_default(inst)
get_private(inst)$.eval_queue()

while(!inst$is_terminated) {
# sample new points
sampler = SamplerUnif$new(search_space)
xdt = sampler$sample(1)$data
xss = transpose_list(xdt)
xs = xss[[1]][inst$archive$cols_x]
xs_trafoed = trafo_xs(xs, search_space)
key = inst$archive$push_running_point(xs)
xs = transpose_list(xdt)[[1]]

# eval
ys = inst$objective$eval(xs_trafoed)

# push result
inst$archive$push_result(key, ys = ys, x_domain = xs_trafoed)
# evaluate
get_private(inst)$.eval_point(xs)
}
}
)
Expand Down
8 changes: 8 additions & 0 deletions man/CallbackAsync.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions man/ContextAsync.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions man/callback_async.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 178e721

Please sign in to comment.