Skip to content

Commit

Permalink
cache more in across_setup() (tidyverse#4982)
Browse files Browse the repository at this point in the history
* cache more in across_setup()

* use match.call() in across_setup() and c_across_setup()

* enquo() later

* inline local_column() in across_apply()

* Standardize on non-dotted arguments

* Extract out `key` into its own variable

* Use simpler and faster `sys.call()` rather than `match.call()`

* Set `cur_column()` info once per column

* Formatting tweaks

* `enquo()` later in `c_across_setup()`

* Test that cache key depends on all inputs, not just the `.cols`

* use new_tibble() instead of as_tibble() (tidyverse#4997)

* use new_tibble() instead of as_tibble()

* perform name repair in across_setup()

* internally use vctrs::vec_recycle() on across() results

* use vctrs::vec_recycle_common()

* + tests for across() tidy recycling behaviour

* from @DavisVaughan review

Co-authored-by: DavisVaughan <[email protected]>
  • Loading branch information
romainfrancois and DavisVaughan committed Mar 18, 2020
1 parent a0390fd commit 3e3850b
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 54 deletions.
137 changes: 86 additions & 51 deletions R/across.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,87 +76,66 @@
#' )
#' @export
across <- function(.cols = everything(), .fns = NULL, ..., .names = NULL) {
  # Column selection, normalisation of `.fns` and computation of the output
  # names all happen in `across_setup()`, which stores its result in the data
  # mask cache under `key`. The key is built from the whole call, so two
  # `across()` calls that differ in any argument get distinct cache entries.
  key <- key_deparse(sys.call())
  setup <- across_setup({{ .cols }}, fns = .fns, names = .names, key = key)

  vars <- setup$vars
  fns <- setup$fns
  names <- setup$names

  mask <- peek_mask()
  data <- mask$current_cols(vars)

  # No functions supplied: hand back the selected columns untouched, renamed
  # only when `.names` was given (the setup step already expanded the names).
  if (is.null(fns)) {
    nrow <- length(mask$current_rows())
    data <- new_tibble(data, nrow = nrow)

    if (is.null(names)) {
      return(data)
    } else {
      return(set_names(data, names))
    }
  }

  n_cols <- length(data)
  n_fns <- length(fns)

  seq_n_cols <- seq_len(n_cols)
  seq_fns <- seq_len(n_fns)

  # One output slot per (column, function) pair, filled in column-major order.
  k <- 1L
  out <- vector("list", n_cols * n_fns)

  # Reset `cur_column()` info on exit
  old_var <- context_peek_bare("column")
  on.exit(context_poke("column", old_var), add = TRUE)

  # Loop in such an order that all functions are applied
  # to a single column before moving on to the next column
  for (i in seq_n_cols) {
    var <- vars[[i]]
    col <- data[[i]]

    # Set the current-column context once per column rather than once per
    # function call.
    context_poke("column", var)

    for (j in seq_fns) {
      fn <- fns[[j]]
      out[[k]] <- fn(col, ...)
      k <- k + 1L
    }
  }

  # Bring every result to a common size before assembling the tibble;
  # `vec_size_common()` is expected to error on incompatible sizes.
  size <- vec_size_common(!!!out)
  out <- vec_recycle_common(!!!out, .size = size)
  names(out) <- names
  new_tibble(out, nrow = size)
}

#' @export
#' @rdname across
c_across <- function(.cols = everything()) {
vars <- across_select({{ .cols }})
key <- key_deparse(sys.call())
vars <- c_across_setup({{ .cols }}, key = key)

mask <- peek_mask()

Expand All @@ -166,26 +145,82 @@ c_across <- function(.cols = everything()) {
vec_c(!!!.cols)
}

# TODO: The usage of a cache in `across_setup()` and `c_across_setup()` is a
# stopgap solution, and this idea should not be used anywhere else. This should
# be replaced by the next version of hybrid evaluation, which should offer a way
# for any function to do any required "set up" work (like the `eval_select()`
# call) a single time per top-level call, rather than once per group.
# Resolve everything `across()` needs that does not depend on the current
# group: the selected column names, the list of functions, and the output
# names. The result is a list with elements `vars`, `fns` and `names`, stored
# in the data mask cache under `key` and returned directly on a cache hit.
#
# `cols` is a tidyselect specification (captured lazily with `enquo()`),
# `fns` is NULL, a function, a formula, or a list of functions/formulas,
# `names` is NULL or a glue specification using `{col}` and `{fn}`.
across_setup <- function(cols, fns, names, key) {
  mask <- peek_mask()

  # Cache hit: skip the selection and naming work entirely.
  value <- mask$across_cache_get(key)
  if (!is.null(value)) {
    return(value)
  }

  # Only capture the selection once we know we have to evaluate it.
  cols <- enquo(cols)
  across_cols <- mask$across_cols()

  vars <- tidyselect::eval_select(expr(!!cols), across_cols)
  vars <- names(vars)

  # Without functions the columns are returned as they are; names are only
  # expanded (and checked for uniqueness) when a `.names` spec was supplied.
  if (is.null(fns)) {
    if (!is.null(names)) {
      names <- vec_as_names(
        glue(names, col = vars, fn = "1"),
        repair = "check_unique"
      )
    }

    value <- list(vars = vars, fns = fns, names = names)
    mask$across_cache_add(key, value)

    return(value)
  }

  # apply `.names` smart default
  if (is.function(fns) || is_formula(fns)) {
    names <- names %||% "{col}"
    fns <- list("1" = fns)
  } else {
    names <- names %||% "{col}_{fn}"
  }

  if (!is.list(fns)) {
    abort(
      "`.fns` must be NULL, a function, a formula, or a list of functions/formulas",
      class = "dplyr_error_across"
    )
  }

  # handle formulas
  fns <- map(fns, as_function)

  # make sure fns has names, use number to replace unnamed
  if (is.null(names(fns))) {
    names_fns <- seq_along(fns)
  } else {
    names_fns <- names(fns)
    empties <- which(names_fns == "")
    if (length(empties)) {
      names_fns[empties] <- empties
    }
  }

  # One name per (column, function) pair, in the same column-major order that
  # `across()` uses when filling its results.
  names <- vec_as_names(
    glue(
      names,
      col = rep(vars, each = length(fns)),
      fn = rep(names_fns, length(vars))
    ),
    repair = "check_unique"
  )

  value <- list(vars = vars, fns = fns, names = names)
  mask$across_cache_add(key, value)

  value
}

c_across_setup <- function(cols, key) {
mask <- peek_mask()

value <- mask$across_cache_get(key)
if (!is.null(value)) {
return(value)
}

cols <- enquo(cols)
across_cols <- mask$across_cols()

vars <- tidyselect::eval_select(expr(!!cols), across_cols)
Expand Down
5 changes: 4 additions & 1 deletion R/context.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,11 @@ context_poke <- function(name, value) {
context_env[[name]] <- value
old
}
# Return whatever is stored under `name` in `context_env`, with no validation.
# Presumably NULL when nothing has been set: `context_peek()` relies on `%||%`
# to fall through to its error in that case.
context_peek_bare <- function(name) {
context_env[[name]]
}
# Return the context value stored under `name`, or abort with a message saying
# that `fun` must only be used inside `location` when no value is available.
context_peek <- function(name, fun, location = "dplyr verbs") {
  context_peek_bare(name) %||%
    abort(glue("{fun} must only be used inside {location}"))
}
context_local <- function(name, value, frame = caller_env()) {
Expand Down
4 changes: 2 additions & 2 deletions R/data-mask.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ DataMask <- R6Class("DataMask",
cols
},

# Fetch the single cache entry stored under `key` in the private
# `across_cache`, instead of exposing the whole cache to the caller.
across_cache_get = function(key) {
  private$across_cache[[key]]
},

across_cache_add = function(key, value) {
Expand Down
28 changes: 28 additions & 0 deletions tests/testthat/test-across.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,34 @@ test_that("monitoring cache - across() usage can depend on the group id", {
)
})

test_that("monitoring cache - across() internal cache key depends on all inputs", {
  grouped <- group_by(tibble(g = rep(1:2, each = 2), a = 1:4), g)

  # Two across() calls share the same selection but use different functions;
  # each must get its own result rather than a reused cached setup.
  actual <- mutate(
    grouped,
    tibble(x = across(is.numeric, mean)$a, y = across(is.numeric, max)$a)
  )
  expected <- mutate(grouped, x = mean(a), y = max(a))

  expect_identical(actual, expected)
})

test_that("across() rejects non vectors", {
  one_row <- data.frame(x = 1)

  # The function returns a symbol, which is not a vector, so across() must fail.
  expect_error(
    summarise(one_row, across(everything(), ~sym("foo")))
  )
})

test_that("across() uses tidy recycling rules", {
  # `rep(42, .)` yields a result whose length equals the column value.
  # Lengths 1 and 2 are compatible: the length-1 result is recycled to 2.
  compatible <- data.frame(x = 1, y = 2)
  recycled <- summarise(compatible, across(everything(), ~rep(42, .)))
  expect_equal(
    recycled,
    data.frame(x = rep(42, 2), y = rep(42, 2))
  )

  # Lengths 2 and 3 cannot be recycled to a common size.
  incompatible <- data.frame(x = 2, y = 3)
  expect_error(
    summarise(incompatible, across(everything(), ~rep(42, .)))
  )
})


# c_across ----------------------------------------------------------------

test_that("selects and combines columns", {
Expand Down

0 comments on commit 3e3850b

Please sign in to comment.