Skip to content

Commit

Permalink
whisper : do not launch log_mel threads when n_thread is 1 (ggerganov…
Browse files Browse the repository at this point in the history
  • Loading branch information
maxilevi authored Apr 14, 2023
1 parent d88461b commit 6ad5c17
Showing 1 changed file with 66 additions and 71 deletions.
137 changes: 66 additions & 71 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2284,6 +2284,60 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
}
}

static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> &hann, const float *samples,
int n_samples, int fft_size, int fft_step, int n_threads,
const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
std::vector<float> fft_in(fft_size, 0.0);
std::vector<float> fft_out(2 * fft_size);
int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2);

for (int i = ith; i < mel.n_len; i += n_threads) {
const int offset = i * fft_step;

// apply Hanning window
for (int j = 0; j < fft_size; j++) {
if (offset + j < n_samples) {
fft_in[j] = hann[j] * samples[offset + j];
} else {
fft_in[j] = 0.0;
}
}

// FFT -> mag^2
fft(fft_in, fft_out);

for (int j = 0; j < fft_size; j++) {
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
}
for (int j = 1; j < fft_size / 2; j++) {
fft_out[j] += fft_out[fft_size - j];
}

if (speed_up) {
// scale down in the frequency domain results in a speed up in the time domain
for (int j = 0; j < n_fft; j++) {
fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]);
}
}

// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;

for (int k = 0; k < n_fft; k++) {
sum += fft_out[k] * filters.data[j * n_fft + k];
}
if (sum < 1e-10) {
sum = 1e-10;
}

sum = log10(sum);

mel.data[j * mel.n_len + i] = sum;
}
}
}

// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
static bool log_mel_spectrogram(
whisper_state & wstate,
Expand All @@ -2310,81 +2364,22 @@ static bool log_mel_spectrogram(
mel.n_len = (n_samples)/fft_step;
mel.data.resize(mel.n_mel*mel.n_len);

const int n_fft = 1 + (speed_up ? fft_size/4 : fft_size/2);

//printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
//printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);

std::vector<std::thread> workers(n_threads);
for (int iw = 0; iw < n_threads; ++iw) {
workers[iw] = std::thread([&](int ith) {
std::vector<float> fft_in;
fft_in.resize(fft_size);
for (int i = 0; i < fft_size; i++) {
fft_in[i] = 0.0;
}

std::vector<float> fft_out;
fft_out.resize(2*fft_size);

for (int i = ith; i < mel.n_len; i += n_threads) {
const int offset = i*fft_step;

// apply Hanning window
for (int j = 0; j < fft_size; j++) {
if (offset + j < n_samples) {
fft_in[j] = hann[j]*samples[offset + j];
} else {
fft_in[j] = 0.0;
}
}

// FFT -> mag^2
fft(fft_in, fft_out);

for (int j = 0; j < fft_size; j++) {
fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
}
for (int j = 1; j < fft_size/2; j++) {
//if (i == 0) {
// printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
//}
fft_out[j] += fft_out[fft_size - j];
}
if (i == 0) {
//for (int j = 0; j < fft_size; j++) {
// printf("%d: %e\n", j, fft_out[j]);
//}
}

if (speed_up) {
// scale down in the frequency domain results in a speed up in the time domain
for (int j = 0; j < n_fft; j++) {
fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]);
}
}

// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
double sum = 0.0;

for (int k = 0; k < n_fft; k++) {
sum += fft_out[k]*filters.data[j*n_fft + k];
}
if (sum < 1e-10) {
sum = 1e-10;
}

sum = log10(sum);

mel.data[j*mel.n_len + i] = sum;
}
}
}, iw);
}
if (n_threads == 1) {
log_mel_spectrogram_worker_thread(0, hann, samples, n_samples, fft_size, fft_step, n_threads, filters, speed_up, mel);
} else {
std::vector<std::thread> workers(n_threads);
for (int iw = 0; iw < n_threads; ++iw) {
workers[iw] = std::thread(log_mel_spectrogram_worker_thread, iw, std::cref(hann), samples,
n_samples, fft_size, fft_step, n_threads,
std::cref(filters), speed_up, std::ref(mel));
}

for (int iw = 0; iw < n_threads; ++iw) {
workers[iw].join();
for (int iw = 0; iw < n_threads; ++iw) {
workers[iw].join();
}
}

// clamping and normalization
Expand Down

0 comments on commit 6ad5c17

Please sign in to comment.