Skip to content

Commit

Permalink
<algorithm>: find/count vectorize moar (microsoft#3267)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicole Mazzuca <[email protected]>
Co-authored-by: Stephan T. Lavavej <[email protected]>
  • Loading branch information
3 people committed Dec 15, 2022
1 parent 032120d commit c3217d2
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 11 deletions.
43 changes: 34 additions & 9 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -5563,17 +5563,42 @@ _NODISCARD constexpr auto lexicographical_compare_three_way(
}
#endif // __cpp_lib_concepts

template <class _Ty, class _Elem>
struct _Vector_alg_in_find_is_safe_object_pointers : false_type {};
template <class _Ty1, class _Ty2>
struct _Vector_alg_in_find_is_safe_object_pointers<_Ty1*, _Ty2*>
: conjunction<
// _Ty1* is an object pointer type
disjunction<is_object<_Ty1>, is_void<_Ty1>>,
// _Ty2* is an object pointer type
disjunction<is_object<_Ty2>, is_void<_Ty2>>,
// either _Ty1 is the same as _Ty2 (ignoring cv-qualifiers), or one of the two is void
disjunction<is_same<remove_cv_t<_Ty1>, remove_cv_t<_Ty2>>, is_void<_Ty1>, is_void<_Ty2>>> {};

// Can we activate the vector algorithms for find/count?
template <class _Iter, class _Ty, class _Elem = _Iter_value_t<_Iter>>
_INLINE_VAR constexpr bool _Vector_alg_in_find_is_safe = // Can we activate the vector algorithms for find/count?
_Iterator_is_contiguous<_Iter> // The iterator must be contiguous so we can get raw pointers.
&& !_Iterator_is_volatile<_Iter> // The iterator must not be volatile.
&& disjunction_v< // And one of the following conditions must be met:
_INLINE_VAR constexpr bool _Vector_alg_in_find_is_safe =
// The iterator must be contiguous so we can get raw pointers.
_Iterator_is_contiguous<_Iter>
// The iterator must not be volatile.
&& !_Iterator_is_volatile<_Iter>
// And one of the following conditions must be met:
&& disjunction_v<
#ifdef __cpp_lib_byte
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>, // We're finding a std::byte in a range of std::byte.
// We're finding a std::byte in a range of std::byte.
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>,
#endif // __cpp_lib_byte
conjunction<is_integral<_Ty>, is_integral<_Elem>>, // We're finding an integer in a range of integers.
// The integer types can be different, which requires careful handling.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>>; // We're finding a U* in a range of U* (identical types).
// We're finding an integer in a range of integers.
// This case is the one that requires careful runtime handling in _Could_compare_equal_to_value_type.
conjunction<is_integral<_Ty>, is_integral<_Elem>>,
// We're finding an (object or function) pointer in a range of pointers of the same type.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>,
// We're finding a nullptr in a range of (object or function) pointers.
conjunction<is_same<_Ty, nullptr_t>, is_pointer<_Elem>>,
// We're finding an object pointer in a range of object pointers, and:
// - One of the pointer types is a cv void*.
// - One of the pointer types is a cv1 U* and the other is a cv2 U*.
_Vector_alg_in_find_is_safe_object_pointers<_Ty, _Elem>>;

template <class _InIt, class _Ty>
_NODISCARD constexpr bool _Could_compare_equal_to_value_type(const _Ty& _Val) {
Expand All @@ -5584,7 +5609,7 @@ _NODISCARD constexpr bool _Could_compare_equal_to_value_type(const _Ty& _Val) {
#ifdef __cpp_lib_byte
is_same<_Ty, byte>,
#endif // __cpp_lib_byte
is_same<_Ty, bool>, is_pointer<_Ty>>) {
is_same<_Ty, bool>, is_pointer<_Ty>, is_same<_Ty, nullptr_t>>) {
return true;
} else {
using _Elem = _Iter_value_t<_InIt>;
Expand Down
51 changes: 49 additions & 2 deletions tests/std/tests/Dev11_0316853_find_memchr_optimization/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,20 +475,67 @@ int main() {
{ // Test pointers
const char* s = "xxxyyy";
const char* arr[]{s, s + 1, s + 1, s + 5, s, s + 4, nullptr};
const void* arr_void[]{s, s + 1, s + 1, s + 5, s, s + 4, nullptr};

static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr)), decltype(s + 1)>, "should optimize");
static_assert(!_Vector_alg_in_find_is_safe<decltype(begin(arr)), nullptr_t>, "should not optimize");

static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr)), char*>, "should optimize");
static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr)), void*>, "should optimize");
static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr)), const volatile void*>, "should optimize");
static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr)), nullptr_t>, "should optimize");
static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr_void)), char*>, "should optimize");
static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr_void)), const char*>, "should optimize");
static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr_void)), void*>, "should optimize");
static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr_void)), const void*>, "should optimize");

// const char pointer range
assert(find(begin(arr), end(arr), s) == begin(arr));
assert(find(begin(arr), end(arr), const_cast<char*>(s)) == begin(arr));
assert(find(begin(arr), end(arr), const_cast<volatile char*>(s)) == begin(arr));

assert(find(begin(arr), end(arr), s + 1) == begin(arr) + 1);
assert(find(begin(arr), end(arr), static_cast<const void*>(s + 1)) == begin(arr) + 1);

assert(find(begin(arr), end(arr), s + 3) == end(arr));
assert(find(begin(arr), end(arr), static_cast<const void*>(s + 3)) == end(arr));

assert(find(begin(arr), end(arr), static_cast<const char*>(nullptr)) == begin(arr) + 6);
assert(find(begin(arr), end(arr), static_cast<const void*>(nullptr)) == begin(arr) + 6);
assert(find(begin(arr), end(arr), nullptr) == begin(arr) + 6);

assert(count(begin(arr), end(arr), s + 1) == 2);
assert(count(begin(arr), end(arr), s + 5) == 1);
assert(count(begin(arr), end(arr), s + 3) == 0);
assert(count(begin(arr), end(arr), static_cast<const char*>(nullptr)) == 1);
assert(count(begin(arr), end(arr), nullptr) == 1);

// const void pointer range
assert(find(begin(arr_void), end(arr_void), s) == begin(arr_void));
assert(find(begin(arr_void), end(arr_void), const_cast<char*>(s)) == begin(arr_void));
assert(find(begin(arr_void), end(arr_void), const_cast<volatile char*>(s)) == begin(arr_void));

assert(find(begin(arr_void), end(arr_void), s + 1) == begin(arr_void) + 1);
assert(find(begin(arr_void), end(arr_void), static_cast<const void*>(s + 1)) == begin(arr_void) + 1);

assert(find(begin(arr_void), end(arr_void), s + 3) == end(arr_void));
assert(find(begin(arr_void), end(arr_void), static_cast<const void*>(s + 3)) == end(arr_void));

assert(find(begin(arr_void), end(arr_void), static_cast<const char*>(nullptr)) == begin(arr_void) + 6);
assert(find(begin(arr_void), end(arr_void), static_cast<const void*>(nullptr)) == begin(arr_void) + 6);
assert(find(begin(arr_void), end(arr_void), nullptr) == begin(arr_void) + 6);

assert(count(begin(arr_void), end(arr_void), s + 1) == 2);
assert(count(begin(arr_void), end(arr_void), s + 5) == 1);
assert(count(begin(arr_void), end(arr_void), s + 3) == 0);
assert(count(begin(arr_void), end(arr_void), static_cast<const char*>(nullptr)) == 1);
assert(count(begin(arr_void), end(arr_void), nullptr) == 1);
}

{ // random other checks for _Vector_alg_in_find_is_safe
static_assert(!_Vector_alg_in_find_is_safe<void (**)(), void*>, "should not optimize");
static_assert(!_Vector_alg_in_find_is_safe<void**, void (*)()>, "should not optimize");
static_assert(_Vector_alg_in_find_is_safe<void (**)(), void (*)()>, "should optimize");
static_assert(_Vector_alg_in_find_is_safe<int (**)(int), int (*)(int)>, "should optimize");
static_assert(!_Vector_alg_in_find_is_safe<void (**)(int), int (*)(int)>, "should not optimize");
static_assert(!_Vector_alg_in_find_is_safe<int (**)(), int (*)(int)>, "should not optimize");
}
}

0 comments on commit c3217d2

Please sign in to comment.