Skip to content

Commit

Permalink
simplify hybrid mean(), var(), and sum()
Browse files Browse the repository at this point in the history
  • Loading branch information
krlmlr committed Feb 2, 2018
1 parent 88a886b commit 54b8191
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 138 deletions.
76 changes: 22 additions & 54 deletions inst/include/dplyr/Result/Mean.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,47 @@
#include <dplyr/Result/Processor.h>

namespace dplyr {

namespace internal {

// version for NA_RM == true
template <int RTYPE, bool NA_RM, typename Index>
struct Mean_internal {
static double process(typename Rcpp::traits::storage_type<RTYPE>::type* ptr, const Index& indices) {
typedef typename Rcpp::traits::storage_type<RTYPE>::type STORAGE;
long double res = 0.0;
int n = indices.size();
int m = 0;
int m = n;
for (int i = 0; i < n; i++) {
STORAGE value = ptr[ indices[i] ];
if (! Rcpp::traits::is_na<RTYPE>(value)) {
res += value;
m++;

// REALSXP and !NA_RM: we don't test for NA here because += NA will give NA
// this is faster in the most common case where there are no NA
// if there are NA, we could return quicker as in the version for
// INTSXP, but we would penalize the most common case
//
// INTSXP: no shortcut, need to test
if (NA_RM || RTYPE == INTSXP) {
if (Rcpp::traits::is_na<RTYPE>(value)) {
if (!NA_RM) {
return NA_REAL;
}

--m;
continue;
}
}

res += value;
}
if (m == 0) return R_NaN;
res /= m;

// Correcting accuracy of result, see base R implementation
if (R_FINITE(res)) {
long double t = 0.0;
for (int i = 0; i < n; i++) {
STORAGE value = ptr[indices[i]];
if (! Rcpp::traits::is_na<RTYPE>(value)) {
if (!NA_RM || ! Rcpp::traits::is_na<RTYPE>(value)) {
t += value - res;
}
}
Expand All @@ -39,54 +55,6 @@ struct Mean_internal {
}
};

// special cases for NA_RM == false
template <typename Index>
struct Mean_internal<INTSXP, false, Index> {
  // Two-pass mean of the integers selected by `indices`, for the case where
  // missing values are not removed (second template argument is false).
  // Returns NA_REAL as soon as an NA_INTEGER is encountered.
  static double process(int* ptr, const Index& indices) {
    const int count = indices.size();

    // First pass: running total; a missing value short-circuits the result.
    long double mean = 0.0;
    for (int pos = 0; pos < count; ++pos) {
      const int current = ptr[indices[pos]];
      if (current == NA_INTEGER) {
        return NA_REAL;
      }
      mean += current;
    }
    mean /= count;

    // A non-finite first estimate is returned unchanged.
    if (!R_FINITE(static_cast<double>(mean))) {
      return static_cast<double>(mean);
    }

    // Second pass: accumulate the residuals around the first estimate and
    // fold their average back in to tighten the result.
    long double residual = 0.0;
    for (int pos = 0; pos < count; ++pos) {
      residual += ptr[indices[pos]] - mean;
    }
    mean += residual / count;

    return static_cast<double>(mean);
  }
};

template <typename Index>
struct Mean_internal<REALSXP, false, Index> {
  // Two-pass mean of the doubles selected by `indices`, for the case where
  // missing values are not removed (second template argument is false).
  // No explicit missing-value test is made; values are summed as they are.
  static double process(double* ptr, const Index& indices) {
    const int count = indices.size();

    // First pass: plain running total in extended precision.
    long double mean = 0.0;
    for (int pos = 0; pos < count; ++pos) {
      mean += ptr[indices[pos]];
    }
    mean /= count;

    // A non-finite first estimate is returned unchanged.
    if (!R_FINITE(static_cast<double>(mean))) {
      return static_cast<double>(mean);
    }

    // Second pass: accumulate the residuals around the first estimate and
    // fold their average back in to tighten the result.
    long double residual = 0.0;
    for (int pos = 0; pos < count; ++pos) {
      residual += ptr[indices[pos]] - mean;
    }
    mean += residual / count;

    return static_cast<double>(mean);
  }
};

} // namespace internal

template <int RTYPE, bool NA_RM>
Expand Down
72 changes: 29 additions & 43 deletions inst/include/dplyr/Result/Sum.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ namespace dplyr {

namespace internal {

// this one is actually only used for RTYPE = REALSXP and NA_RM = true
template <int RTYPE, bool NA_RM, typename Index>
struct Sum {
typedef typename Rcpp::traits::storage_type<RTYPE>::type STORAGE;
Expand All @@ -16,66 +15,53 @@ struct Sum {
int n = indices.size();
for (int i = 0; i < n; i++) {
double value = ptr[indices[i]];
if (! Rcpp::traits::is_na<RTYPE>(value)) res += value;
}
return (double)res;
}
};

template <typename Index>
struct Sum<INTSXP, true, Index> {
static int process(int* ptr, const Index& indices) {
long double res = 0;
int n = indices.size();
for (int i = 0; i < n; i++) {
int value = ptr[indices[i]];
if (! Rcpp::traits::is_na<INTSXP>(value)) res += value;
}
if (res > INT_MAX || res <= INT_MIN) {
warning("integer overflow - use sum(as.numeric(.))");
return IntegerVector::get_na();
// !NA_RM: we don't test for NA here because += NA will give NA
// this is faster in the most common case where there are no NA
// if there are NA, we could return quicker as in the version for
// INTSXP, but we would penalize the most common case
if (NA_RM && Rcpp::traits::is_na<RTYPE>(value)) {
continue;
}

res += value;
}
return (int)res;

return (STORAGE)res;
}
};

template <typename Index>
struct Sum<INTSXP, false, Index> {
static int process(int* ptr, const Index& indices) {
// Special case for INTSXP:
template <bool NA_RM, typename Index>
struct Sum<INTSXP, NA_RM, Index> {
enum { RTYPE = INTSXP };
typedef typename Rcpp::traits::storage_type<RTYPE>::type STORAGE;
static STORAGE process(typename Rcpp::traits::storage_type<RTYPE>::type* ptr, const Index& indices) {
long double res = 0;
int n = indices.size();
for (int i = 0; i < n; i++) {
int value = ptr[indices[i]];
if (Rcpp::traits::is_na<INTSXP>(value)) {
return NA_INTEGER;
double value = ptr[indices[i]];

if (Rcpp::traits::is_na<RTYPE>(value)) {
if (NA_RM) {
continue;
}

return Rcpp::traits::get_na<RTYPE>();
}

res += value;
}

if (res > INT_MAX || res <= INT_MIN) {
warning("integer overflow - use sum(as.numeric(.))");
return IntegerVector::get_na();
return Rcpp::traits::get_na<RTYPE>();
}
return (int)res;
}
};

template <typename Index>
struct Sum<REALSXP, false, Index> {
static double process(double* ptr, const Index& indices) {
long double res = 0.0;
int n = indices.size();
for (int i = 0; i < n; i++) {
// we don't test for NA here because += NA will give NA
// this is faster in the most common case where there are no NA
// if there are NA, we could return quicker as in the version for
// INTSXP above, but we would penalize the most common case
res += ptr[ indices[i] ];
}
return (double)res;
return (STORAGE)res;
}
};


} // namespace internal

template <int RTYPE, bool NA_RM>
Expand Down
45 changes: 6 additions & 39 deletions inst/include/dplyr/Result/Var.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
#include <dplyr/Result/Processor.h>

namespace dplyr {

namespace internal {

// Returns x multiplied by itself.
inline double square(double x) {
  const double squared = x * x;
  return squared;
}

}

// version for NA_RM = false
template <int RTYPE, bool NA_RM>
class Var : public Processor<REALSXP, Var<RTYPE, NA_RM> > {
public:
Expand All @@ -25,62 +27,27 @@ class Var : public Processor<REALSXP, Var<RTYPE, NA_RM> > {

inline double process_chunk(const SlicingIndex& indices) {
int n = indices.size();
if (n == 1) return NA_REAL;
if (n <= 1) return NA_REAL;
double m = internal::Mean_internal<RTYPE, NA_RM, SlicingIndex>::process(data_ptr, indices);

if (!R_FINITE(m)) return m;

double sum = 0.0;
for (int i = 0; i < n; i++) {
sum += internal::square(data_ptr[indices[i]] - m);
}
return sum / (n - 1);
}

private:
STORAGE* data_ptr;
};


// version for NA_RM = true
template <int RTYPE>
class Var<RTYPE, true> : public Processor<REALSXP, Var<RTYPE, true> > {
public:
typedef Processor<REALSXP, Var<RTYPE, true> > Base;
typedef typename Rcpp::traits::storage_type<RTYPE>::type STORAGE;

explicit Var(SEXP x) :
Base(x),
data_ptr(Rcpp::internal::r_vector_start<RTYPE>(x))
{}
~Var() {}

inline double process_chunk(const SlicingIndex& indices) {
int n = indices.size();
if (n == 1) return NA_REAL;
double m = internal::Mean_internal<RTYPE, true, SlicingIndex>::process(data_ptr, indices);

if (!R_FINITE(m)) return m;

double sum = 0.0;
int count = 0;
for (int i = 0; i < n; i++) {
STORAGE current = data_ptr[indices[i]];
if (Rcpp::Vector<RTYPE>::is_na(current)) continue;
if (NA_RM && Rcpp::Vector<RTYPE>::is_na(current)) continue;
sum += internal::square(current - m);
count++;
}
if (count == 1) return NA_REAL;
if (count <= 1) return NA_REAL;
return sum / (count - 1);
}

private:
STORAGE* data_ptr;
};




}

#endif
4 changes: 2 additions & 2 deletions tests/testthat/test-summarise.r
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ test_that("summarise creates an empty data frame when no parameters are used", {
})

test_that("summarise works with zero-row data frames", {
res <- summarise(mtcars[0, ], n = n())
expect_equal(res, data.frame(n = 0L))
res <- summarise(mtcars[0, ], n = n(), sum = sum(cyl), mean = mean(mpg), var = var(drat))
expect_equal(res, data.frame(n = 0L, sum = 0, mean = NaN, var = NA_real_))
})

test_that("summarise works with zero-column data frames (#3071)", {
Expand Down

0 comments on commit 54b8191

Please sign in to comment.