// miscUtils.hlsli
// Forked from Const-me/Whisper
// Shared rounding decision used by adjustFp16() and fp16Rounded() below.
// trunc16 must be f32tof16( src ), i.e. the FP16 bit pattern obtained by truncating src towards 0.
// Returns true when round-to-nearest-even of src is trunc16 + 1 (the next FP16 value away from 0),
// false when it is trunc16 itself. Always returns false when trunc16 is INF or NAN.
inline bool fp16RoundsAwayFromZero( const float src, const uint trunc16 )
{
    // Exponent field 0x1F: the truncated value is already INF or NAN, keep it as is
    const uint truncExp = ( trunc16 >> 10 ) & 0x1F;
    if( truncExp == 0x1F )
        return false;

    // FP16 is sign-magnitude, so +1 on the bit pattern moves one ULP away from 0 for either sign
    const uint next16 = trunc16 + 1;
    const float trunc32 = f16tof32( trunc16 );
    float next32;
    if( ( next16 & 0x7FFF ) == 0x7C00 )
    {
        // trunc16 is the largest finite FP16 magnitude (65504) and next16 encodes INF.
        // IEEE round-to-nearest decides overflow as if the exponent range were unbounded,
        // i.e. by comparing against 2^16 = 65536, so |src| >= 65520 must round to INF.
        // Measuring the distance to INF itself would never select it.
        next32 = ( 0 != ( trunc16 & 0x8000 ) ) ? -65536.0f : 65536.0f;
    }
    else
    {
        next32 = f16tof32( next16 );
    }

    const float errTrunc = abs( src - trunc32 );
    const float errNext = abs( src - next32 );
    if( errTrunc != errNext )
        return errNext < errTrunc;

    // Exactly half way: banker's rounding, pick the candidate whose mantissa is even
    return 0 != ( trunc16 & 1 );
}

// When GPUs are converting FP32 to FP16, they always truncate towards 0, as documented here:
// https://learn.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-data-conversion#conververting-from-a-higher-range-representation-to-a-lower-range-representation
// Whisper code uses _mm_cvtps_ph( x, 0 ), where 0 stands for "round to nearest even": https://www.felixcloutier.com/x86/vcvtps2ph
// This function adjusts an FP32 value so that truncating it towards 0 yields the same FP16 value the CPU produces.
// Returns src unchanged when truncation already gives the correctly rounded value; otherwise returns
// the FP32 value of the next FP16 away from 0 (which truncates to itself), or +-INF on overflow.
inline float adjustFp16( const float src )
{
    const uint trunc16 = f32tof16( src );
    return fp16RoundsAwayFromZero( src, trunc16 ) ? f16tof32( trunc16 + 1 ) : src;
}

// Convert an FP32 number to FP16 bits using round-to-nearest-even, matching _mm_cvtps_ph( x, 0 ) on the CPU.
// Values whose magnitude rounds past 65504 become +-INF; INF and NAN inputs keep their truncated pattern.
inline uint fp16Rounded( const float src )
{
    const uint trunc16 = f32tof16( src );
    return fp16RoundsAwayFromZero( src, trunc16 ) ? ( trunc16 + 1 ) : trunc16;
}
// Round x up to the next multiple of 32; multiples of 32 (including 0) are returned unchanged.
// Arithmetic is unsigned, so inputs above 0xFFFFFFE0 wrap around to 0.
inline uint roundUp32( uint x )
{
    // Number of 32-element groups needed to cover x, then scale back to an element count
    const uint groups = ( x + 31 ) >> 5;
    return groups << 5;
}