Skip to content

Commit

Permalink
run: Set all lower/upper bounds
Browse files Browse the repository at this point in the history
Only setting finite values means you can't clear a bound by setting to
NA.
  • Loading branch information
lentinj committed Jul 15, 2024
1 parent 217484e commit ff00f55
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
4 changes: 2 additions & 2 deletions R/init_val.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ g3_init_val <- function (
stopifnot(is.character(name_spec) && length(name_spec) == 1)
stopifnot(is.numeric(value) || is.null(value))
stopifnot(is.numeric(spread) || is.null(spread))
stopifnot(is.numeric(lower) || is.null(lower))
stopifnot(is.numeric(upper) || is.null(upper))
stopifnot(is.numeric(lower) || is.na(lower) || is.null(lower))
stopifnot(is.numeric(upper) || is.na(upper) || is.null(upper))
stopifnot(is.logical(optimise) || is.null(optimise))
stopifnot(identical(parscale, 'auto') || is.numeric(parscale) || is.null(parscale))
stopifnot(is.logical(random) || is.null(random))
Expand Down
4 changes: 2 additions & 2 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ update_data_bounds <- function (model_data, param_tmpl) {
# User didn't supply extra parameters, nothing to do
} else if (is.data.frame(param_tmpl)) {
for (param_type in c('lower', 'upper')) {
for (i in which(is.finite(param_tmpl[[param_type]])) ) {
for (i in seq_len(nrow(param_tmpl))) {
data_var <- cpp_escape_varname(paste0(param_tmpl[i, 'switch'], '__', param_type))
if (!exists(data_var, envir = model_data)) next

data_val <- param_tmpl[i, param_type]
model_data[[data_var]] <- if (is.finite(data_val)) data_val else NaN
model_data[[data_var]] <- if (is.na(data_val)) NaN else data_val
}
}
} else {
Expand Down
19 changes: 19 additions & 0 deletions tests/test-likelihood_bounds.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ actions <- list(g3a_time(2000, 2000), list("555" = g3_formula({
nll <- nll + g3_param('pa')
nll <- nll + g3_param('pb')
nll <- nll + g3_param('pc')
nll <- nll + g3_param('pd', value = 0, lower = -10, upper = 10)
})))
model_code <- g3_to_tmb(c(actions, list( g3l_bounds_penalty(actions) )))
fn <- g3_to_r(c(actions, list( g3l_bounds_penalty(actions) )))
Expand Down Expand Up @@ -140,6 +141,24 @@ if (nzchar(Sys.getenv('G3_TEST_TMB'))) {
obj.fn$fn(),
2e12,
tolerance=1e1), "nll: TMB version, 2 parameters outside bounds")

suppressWarnings(attr(model_code, 'parameter_template') |>
g3_init_val('pd', 100) |>
identity() -> params.in)
obj.fn <- g3_tmb_adfun(model_code, params.in)
ok(ut_cmp_equal(
obj.fn$fn(),
2e12,
tolerance = 1e1 ), "nll: Outside initial bounds")

suppressWarnings(attr(model_code, 'parameter_template') |>
g3_init_val('pd', 100, lower = NA, upper = NA) |>
identity() -> params.in)
obj.fn <- g3_tmb_adfun(model_code, params.in)
ok(ut_cmp_equal(
obj.fn$fn(),
sum(unlist(params.in$value)) ), "nll: Outside initial bounds, but we cleared them")

} else {
writeLines("# skip: not compiling TMB model")
}

0 comments on commit ff00f55

Please sign in to comment.