Skip to content

Commit

Permalink
fix long shuffle implementations for windows (ROCm#1895)
Browse files Browse the repository at this point in the history
Fixes for SWDEV-223694
  • Loading branch information
Nick Curtis authored Feb 26, 2020
1 parent 69404d8 commit b7dd073
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions include/hip/hcc_detail/device_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ __device__
inline
long __shfl(long var, int src_lane, int width = warpSize)
{
#ifndef _MSC_VER
static_assert(sizeof(long) == 2 * sizeof(int), "");
static_assert(sizeof(long) == sizeof(uint64_t), "");

Expand All @@ -333,6 +334,10 @@ long __shfl(long var, int src_lane, int width = warpSize)
uint64_t tmp0 = (static_cast<uint64_t>(tmp[1]) << 32ull) | static_cast<uint32_t>(tmp[0]);
long tmp1; __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0));
return tmp1;
#else
static_assert(sizeof(long) == sizeof(int), "");
return static_cast<long>(__shfl(static_cast<int>(var), src_lane, width));
#endif
}
__device__
inline
Expand Down Expand Up @@ -390,6 +395,7 @@ __device__
inline
long __shfl_up(long var, unsigned int lane_delta, int width = warpSize)
{
#ifndef _MSC_VER
static_assert(sizeof(long) == 2 * sizeof(int), "");
static_assert(sizeof(long) == sizeof(uint64_t), "");

Expand All @@ -400,6 +406,10 @@ long __shfl_up(long var, unsigned int lane_delta, int width = warpSize)
uint64_t tmp0 = (static_cast<uint64_t>(tmp[1]) << 32ull) | static_cast<uint32_t>(tmp[0]);
long tmp1; __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0));
return tmp1;
#else
static_assert(sizeof(long) == sizeof(int), "");
return static_cast<long>(__shfl_up(static_cast<int>(var), lane_delta, width));
#endif
}
__device__
inline
Expand Down Expand Up @@ -455,6 +465,7 @@ __device__
inline
long __shfl_down(long var, unsigned int lane_delta, int width = warpSize)
{
#ifndef _MSC_VER
static_assert(sizeof(long) == 2 * sizeof(int), "");
static_assert(sizeof(long) == sizeof(uint64_t), "");

Expand All @@ -465,6 +476,10 @@ long __shfl_down(long var, unsigned int lane_delta, int width = warpSize)
uint64_t tmp0 = (static_cast<uint64_t>(tmp[1]) << 32ull) | static_cast<uint32_t>(tmp[0]);
long tmp1; __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0));
return tmp1;
#else
static_assert(sizeof(long) == sizeof(int), "");
return static_cast<long>(__shfl_down(static_cast<int>(var), lane_delta, width));
#endif
}
__device__
inline
Expand Down Expand Up @@ -520,6 +535,7 @@ __device__
inline
long __shfl_xor(long var, int lane_mask, int width = warpSize)
{
#ifndef _MSC_VER
static_assert(sizeof(long) == 2 * sizeof(int), "");
static_assert(sizeof(long) == sizeof(uint64_t), "");

Expand All @@ -530,6 +546,10 @@ long __shfl_xor(long var, int lane_mask, int width = warpSize)
uint64_t tmp0 = (static_cast<uint64_t>(tmp[1]) << 32ull) | static_cast<uint32_t>(tmp[0]);
long tmp1; __builtin_memcpy(&tmp1, &tmp0, sizeof(tmp0));
return tmp1;
#else
static_assert(sizeof(long) == sizeof(int), "");
return static_cast<long>(__shfl_xor(static_cast<int>(var), lane_mask, width));
#endif
}
__device__
inline
Expand Down

0 comments on commit b7dd073

Please sign in to comment.