Skip to content

Commit

Permalink
simplify across() selection (tidyverse#5991)
Browse files Browse the repository at this point in the history
* rm handling of across cache, only remaining use in c_across()

* simplify <DataMask>$across_cols() to avoid the used/unused handling

* simplify across_setup() by removing the .top_level argument

* <DataMask>$add_one() contributes to <DataMask>$all_types

* test for tidyverse#5951

closes tidyverse#5951

* DataMask only keeps `current_data` instead of both `all_vars` and `all_data`

* <DataMask>$current_data uses full data and not just ptypes

* internal dplyr_summarise_recycle_chunks() signals that chunks have been recycled and therefore that the associated result needs to be regenerated.
  • Loading branch information
romainfrancois committed Sep 7, 2021
1 parent e644961 commit ed2d8f5
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 148 deletions.
39 changes: 3 additions & 36 deletions R/across.R
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,7 @@ if_across <- function(op, df) {
#' )
c_across <- function(cols = everything()) {
cols <- enquo(cols)
key <- key_deparse(cols)
vars <- c_across_setup(!!cols, key = key)
vars <- c_across_setup(!!cols)

mask <- peek_mask("c_across()")

Expand All @@ -250,30 +249,14 @@ across_glue_mask <- function(.col, .fn, .caller_env) {
glue_mask
}

# TODO: The usage of a cache in `c_across_setup()` is a stopgap solution, and
# this idea should not be used anywhere else. This should be replaced by either
# expansions of expressions (as we now use for `across()`) or the
# next version of hybrid evaluation, which should offer a way for any function
# to do any required "set up" work (like the `eval_select()` call) a single
# time per top-level call, rather than once per group.
across_setup <- function(cols,
fns,
names,
.caller_env,
mask = peek_mask("across()"),
.top_level = FALSE,
inline = FALSE) {
cols <- enquo(cols)

if (.top_level) {
# FIXME: this is a little bit hacky to make top_across()
# work, otherwise mask$across_cols() fails when calling
# self$current_cols(across_vars_used)
# it should not affect anything because it is expected that
# across_setup() is only ever called on the first group anyway
# but perhaps it is time to review how across_cols() work
mask$set_current_group(1L)
}
# `across()` is evaluated in a data mask so we need to remove the
# mask layer from the quosure environment (#5460)
cols <- quo_set_env(cols, data_mask_top(quo_get_env(cols), recursive = FALSE, inherit = FALSE))
Expand All @@ -290,6 +273,7 @@ across_setup <- function(cols,
))
}
across_cols <- mask$across_cols()

vars <- tidyselect::eval_select(cols, data = across_cols)
names_vars <- names(vars)
vars <- names(across_cols)[vars]
Expand Down Expand Up @@ -356,34 +340,18 @@ data_mask_top <- function(env, recursive = FALSE, inherit = FALSE) {
env
}

c_across_setup <- function(cols, key) {
c_across_setup <- function(cols) {
mask <- peek_mask("c_across()")

value <- mask$across_cache_get(key)
if (!is.null(value)) {
return(value)
}

cols <- enquo(cols)
across_cols <- mask$across_cols()

vars <- tidyselect::eval_select(expr(!!cols), across_cols)
value <- names(vars)

mask$across_cache_add(key, value)

value
}

# FIXME: Should not cache `cols` when it includes env-expressions
# https://github.com/r-lib/tidyselect/issues/235
key_deparse <- function(cols) {
paste(
paste0(deparse(quo_get_expr(cols)), collapse = "\n"),
format(quo_get_env(cols))
)
}

new_dplyr_quosure <- function(quo, ...) {
attr(quo, "dplyr:::data") <- list2(...)
quo
Expand Down Expand Up @@ -511,7 +479,6 @@ expand_across <- function(quo) {
fns = eval_tidy(expr$.fns, mask, env = env),
names = eval_tidy(expr$.names, mask, env = env),
.caller_env = dplyr_mask$get_caller_env(),
.top_level = TRUE,
inline = TRUE
)

Expand Down
68 changes: 9 additions & 59 deletions R/data-mask.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ DataMask <- R6Class("DataMask",
abort("Can't transform a data frame with duplicate names.")
}
names(data) <- names_bindings
private$all_vars <- names_bindings
private$data <- data
private$caller <- caller
private$current_data <- unclass(data)

private$chops <- .Call(dplyr_lazy_vec_chop_impl, data, rows)
private$mask <- .Call(dplyr_data_masks_setup, private$chops, data, rows)
Expand All @@ -28,7 +28,7 @@ DataMask <- R6Class("DataMask",

},

add_one = function(name, chunks) {
add_one = function(name, chunks, result) {
if (inherits(private$data, "rowwise_df")){
is_scalar_list <- function(.x) {
vec_is_list(.x) && length(.x) == 1L
Expand All @@ -38,14 +38,7 @@ DataMask <- R6Class("DataMask",
}
}

.Call(`dplyr_mask_add`, private, name, chunks)
},

add_many = function(ptype, chunks) {
chunks_extracted <- .Call(dplyr_extract_chunks, chunks, ptype)
map2(seq_along(ptype), names(ptype), function(j, nm) {
self$add_one(nm, chunks_extracted[[j]])
})
.Call(`dplyr_mask_add`, private, name, result, chunks)
},

remove = function(name) {
Expand Down Expand Up @@ -91,7 +84,7 @@ DataMask <- R6Class("DataMask",
},

current_vars = function() {
private$all_vars
names(private$current_data)
},

current_non_group_vars = function() {
Expand All @@ -111,7 +104,7 @@ DataMask <- R6Class("DataMask",
},

get_used = function() {
.Call(env_resolved, private$chops, private$all_vars)
.Call(env_resolved, private$chops, names(private$current_data))
},

unused_vars = function() {
Expand All @@ -125,46 +118,7 @@ DataMask <- R6Class("DataMask",
},

across_cols = function() {
original_data <- self$full_data()
original_data <- unclass(original_data)

across_vars <- self$current_non_group_vars()
unused_vars <- self$unused_vars()

across_vars_unused <- intersect(across_vars, unused_vars)
across_vars_used <- setdiff(across_vars, across_vars_unused)

# Pull unused vars from original data to keep from marking them as used.
# Column lengths will not match if `original_data` is grouped, but for
# the usage of tidyselect in `across()` we only need the column names
# and types to be correct.
cols_unused <- original_data[across_vars_unused]
cols_used <- self$current_cols(across_vars_used)

cols <- vec_c(cols_unused, cols_used)

# workaround until vctrs 0.3.5 is on CRAN
# (https://github.com/r-lib/vctrs/issues/1263)
if (length(cols) == 0) {
names(cols) <- character()
}

# Match original ordering
cols <- cols[across_vars]

cols
},

across_cache_get = function(key) {
private$across_cache[[key]]
},

across_cache_add = function(key, value) {
private$across_cache[[key]] <- value
},

across_cache_reset = function() {
private$across_cache <- list()
private$current_data[self$current_non_group_vars()]
},

forget = function(fn) {
Expand Down Expand Up @@ -217,9 +171,8 @@ DataMask <- R6Class("DataMask",
# in the parent environment of `mask`
mask = NULL,

# names of all the variables, this initially is names(data)
# grows (and sometimes shrinks) as new columns are added/removed
all_vars = character(),
# ptypes of all the variables
current_data = list(),

# names of the grouping variables
group_vars = character(),
Expand All @@ -231,9 +184,6 @@ DataMask <- R6Class("DataMask",
keys = NULL,

# caller environment of the verb (summarise(), ...)
caller = NULL,

# cache for across
across_cache = list()
caller = NULL
)
)
15 changes: 11 additions & 4 deletions R/mutate.R
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ mutate_cols <- function(.data, ..., caller_env) {

withCallingHandlers({
for (i in seq_along(dots)) {
mask$across_cache_reset()
context_poke("column", old_current_column)

# get results from all the quosures that are expanded from ..i
Expand Down Expand Up @@ -344,13 +343,21 @@ mutate_cols <- function(.data, ..., caller_env) {
chunks <- quo_result$chunks

if (!quo_data$is_named && is.data.frame(result)) {
new_columns[names(result)] <- result
mask$add_many(result, chunks)
types <- vec_ptype(result)
types_names <- names(types)
chunks_extracted <- .Call(dplyr_extract_chunks, chunks, types)

for (j in seq_along(types)) {
mask$add_one(types_names[j], chunks_extracted[[j]], result = result[[j]])
}

new_columns[types_names] <- result
} else {
# treat as a single output otherwise
name <- quo_data$name_auto
mask$add_one(name = name, chunks = chunks, result = result)

new_columns[[name]] <- result
mask$add_one(name, chunks)
}

}
Expand Down
30 changes: 19 additions & 11 deletions R/summarise.R
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ summarise_cols <- function(.data, ..., caller_env) {
types <- vector("list", length(dots))

chunks <- list()
results <- list()
types <- list()
out_names <- character()

withCallingHandlers({
for (i in seq_along(dots)) {
mask$across_cache_reset()
context_poke("column", old_current_column)

quosures <- expand_across(dots[[i]])
Expand All @@ -248,8 +248,8 @@ summarise_cols <- function(.data, ..., caller_env) {
}
)
chunks_k <- vec_cast_common(!!!chunks_k, .to = types_k)

quosures_results[[k]] <- list(chunks = chunks_k, types = types_k)
result_k <- vec_c(!!!chunks_k, .ptype = types_k)
quosures_results[[k]] <- list(chunks = chunks_k, types = types_k, results = result_k)
}

for (k in seq_along(quosures)) {
Expand All @@ -262,35 +262,43 @@ summarise_cols <- function(.data, ..., caller_env) {
}
types_k <- quo_result$types
chunks_k <- quo_result$chunks
results_k <- quo_result$results

if (!quo_data$is_named && is.data.frame(types_k)) {
chunks_extracted <- .Call(dplyr_extract_chunks, chunks_k, types_k)

walk2(chunks_extracted, names(types_k), function(chunks_k_j, nm) {
mask$add_one(nm, chunks_k_j)
})
types_k_names <- names(types_k)
for (j in seq_along(chunks_extracted)) {
mask$add_one(
name = types_k_names[j],
chunks = chunks_extracted[[j]],
result = results_k[[j]]
)
}

chunks <- append(chunks, chunks_extracted)
types <- append(types, as.list(types_k))
out_names <- c(out_names, names(types_k))
results <- append(results, results_k)
out_names <- c(out_names, types_k_names)
} else {
name <- quo_data$name_auto
mask$add_one(name, chunks_k)
mask$add_one(name = name, chunks = chunks_k, result = results_k)
chunks <- append(chunks, list(chunks_k))
types <- append(types, list(types_k))
results <- append(results, list(results_k))
out_names <- c(out_names, name)
}

}
}

recycle_info <- .Call(`dplyr_summarise_recycle_chunks`, chunks, mask$get_rows(), types)
recycle_info <- .Call(`dplyr_summarise_recycle_chunks`, chunks, mask$get_rows(), types, results)
chunks <- recycle_info$chunks
sizes <- recycle_info$sizes
results <- recycle_info$results

# materialize columns
for (i in seq_along(chunks)) {
result <- vec_c(!!!chunks[[i]], .ptype = types[[i]])
result <- results[[i]] %||% vec_c(!!!chunks[[i]], .ptype = types[[i]])
cols[[ out_names[i] ]] <- result
}

Expand Down
6 changes: 3 additions & 3 deletions src/dplyr.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct symbols {
static SEXP current_expression;
static SEXP rows;
static SEXP caller;
static SEXP all_vars;
static SEXP current_data;
static SEXP dot_drop;
static SEXP abort_glue;
static SEXP dot_indices;
Expand Down Expand Up @@ -106,14 +106,14 @@ SEXP dplyr_mask_eval_all(SEXP quo, SEXP env_private);
SEXP dplyr_mask_eval_all_summarise(SEXP quo, SEXP env_private);
SEXP dplyr_mask_eval_all_mutate(SEXP quo, SEXP env_private);
SEXP dplyr_mask_eval_all_filter(SEXP quos, SEXP env_private, SEXP s_n, SEXP env_filter);
SEXP dplyr_summarise_recycle_chunks(SEXP chunks, SEXP rows, SEXP ptypes);
SEXP dplyr_summarise_recycle_chunks(SEXP chunks, SEXP rows, SEXP ptypes, SEXP results);
SEXP dplyr_group_indices(SEXP data, SEXP rows);
SEXP dplyr_group_keys(SEXP group_data);
SEXP dplyr_reduce_lgl_or(SEXP, SEXP);
SEXP dplyr_reduce_lgl_and(SEXP, SEXP);

SEXP dplyr_mask_remove(SEXP env_private, SEXP s_name);
SEXP dplyr_mask_add(SEXP env_private, SEXP s_name, SEXP chunks);
SEXP dplyr_mask_add(SEXP env_private, SEXP s_name, SEXP ptype, SEXP chunks);

SEXP dplyr_lazy_vec_chop(SEXP data, SEXP rows);
SEXP dplyr_data_masks_setup(SEXP chops, SEXP data, SEXP rows);
Expand Down
10 changes: 6 additions & 4 deletions src/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ SEXP get_names_expanded() {
}

SEXP get_names_summarise_recycle_chunks(){
SEXP names = Rf_allocVector(STRSXP, 2);
SEXP names = Rf_allocVector(STRSXP, 3);
R_PreserveObject(names);

SET_STRING_ELT(names, 0, Rf_mkChar("chunks"));
SET_STRING_ELT(names, 1, Rf_mkChar("sizes"));
SET_STRING_ELT(names, 2, Rf_mkChar("results"));
return names;
}

Expand All @@ -44,7 +46,7 @@ SEXP symbols::dot_current_group = Rf_install(".current_group");
SEXP symbols::current_expression = Rf_install("current_expression");
SEXP symbols::rows = Rf_install("rows");
SEXP symbols::caller = Rf_install("caller");
SEXP symbols::all_vars = Rf_install("all_vars");
SEXP symbols::current_data = Rf_install("current_data");
SEXP symbols::dot_drop = Rf_install(".drop");
SEXP symbols::abort_glue = Rf_install("abort_glue");
SEXP symbols::dot_indices = Rf_install(".indices");
Expand Down Expand Up @@ -105,13 +107,13 @@ static const R_CallMethodDef CallEntries[] = {
{"dplyr_mask_eval_all_mutate", (DL_FUNC)& dplyr_mask_eval_all_mutate, 2},
{"dplyr_mask_eval_all_filter", (DL_FUNC)& dplyr_mask_eval_all_filter, 4},

{"dplyr_summarise_recycle_chunks", (DL_FUNC)& dplyr_summarise_recycle_chunks, 3},
{"dplyr_summarise_recycle_chunks", (DL_FUNC)& dplyr_summarise_recycle_chunks, 4},

{"dplyr_group_indices", (DL_FUNC)& dplyr_group_indices, 2},
{"dplyr_group_keys", (DL_FUNC)& dplyr_group_keys, 1},

{"dplyr_mask_remove", (DL_FUNC)& dplyr_mask_remove, 2},
{"dplyr_mask_add", (DL_FUNC)& dplyr_mask_add, 3},
{"dplyr_mask_add", (DL_FUNC)& dplyr_mask_add, 4},

{"dplyr_lazy_vec_chop_impl", (DL_FUNC)& dplyr_lazy_vec_chop, 2},
{"dplyr_data_masks_setup", (DL_FUNC)& dplyr_data_masks_setup, 3},
Expand Down
Loading

0 comments on commit ed2d8f5

Please sign in to comment.