Skip to content

Commit

Permalink
implemented torchaudio based log Mel fb features
Browse files Browse the repository at this point in the history
* used torchaudio as much as possible
* could not use torchaudio's MelSpectrogram directly since it has fixed
  padding
   - the computation here is exactly the same apart from the padding
  • Loading branch information
menne committed Jul 16, 2019
1 parent 0cf520d commit 5ce9f13
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions neural_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,21 +651,38 @@ def forward(self, x):
class logMelFb(nn.Module):
    """Mel filterbank features computed from a single-channel time signal.

    The layer runs a non-centred STFT over the input, turns it into a power
    spectrum and applies the mel filterbank owned by a torchaudio
    ``MelSpectrogram`` object. torchaudio's transform is not called end to
    end; only its window function and its filterbank (``fm``) are reused, so
    the STFT can be run with ``center=False``.

    NOTE(review): despite the class name, no logarithm is applied in
    ``forward`` -- confirm whether callers expect log-compressed values.
    """

    def __init__(self, options, inp_dim):
        """Read the feature configuration and build the torchaudio helper.

        Args:
            options: Mapping of string settings. The keys read here are
                ``logmelfb_nr_sample_rate``, ``logmelfb_nr_filt``,
                ``logmelfb_stft_window_size``,
                ``logmelfb_stft_window_shift`` and ``use_cuda``.
            inp_dim: Not read by this layer; presumably kept so the
                constructor matches the other layers in this file.
        """
        super(logMelFb, self).__init__()
        # Function-scope import: torchaudio is only needed once this layer
        # is actually constructed.
        import torchaudio

        self._sample_rate = int(options['logmelfb_nr_sample_rate'])
        self._nr_of_filters = int(options['logmelfb_nr_filt'])
        self._stft_window_size = int(options['logmelfb_stft_window_size'])
        self._stft_window_shift = int(options['logmelfb_stft_window_shift'])
        # Still parsed so the option stays mandatory and the attribute stays
        # available; forward() follows the input tensor's device instead.
        self._use_cuda = strtobool(options['use_cuda'])
        self.out_dim = self._nr_of_filters
        # Keyword names (sr / ws / hop) are the ones this file was written
        # against -- presumably an early torchaudio release; verify before
        # upgrading the dependency.
        self._mspec = torchaudio.transforms.MelSpectrogram(
            sr=self._sample_rate,
            n_fft=self._stft_window_size,
            ws=self._stft_window_size,
            hop=self._stft_window_shift,
            n_mels=self._nr_of_filters,
        )

    def forward(self, x):
        """Compute mel filterbank energies for ``x``.

        Args:
            x: Tensor whose last axis must have size 1. Presumably laid out
                as (samples, batch, 1) -- TODO confirm against the caller.

        Returns:
            The filterbank output with its first two axes swapped;
            presumably (frames, batch, n_filters) -- TODO confirm.

        Raises:
            AssertionError: If the last axis of ``x`` is not of size 1.
        """
        assert x.shape[-1] == 1, 'Multi channel time signal processing not suppored yet'
        # Drop the channel axis and swap the two remaining axes for the STFT.
        signal = torch.squeeze(x, -1).transpose(0, 1)
        # Put the analysis window on the same device as the input rather
        # than trusting the use_cuda flag, so the two can never disagree.
        window = self._mspec.window(self._stft_window_size).to(x.device)
        x_stft = torch.stft(
            signal,
            self._stft_window_size,
            hop_length=self._stft_window_shift,
            center=False,
            window=window,
        )
        # Sum of squares over the last axis; assumes the real-valued STFT
        # layout whose last axis holds the (real, imag) pair -- verify for
        # the installed torch version.
        power = x_stft.pow(2).sum(-1)
        # The filterbank is applied after swapping axes 1 and 2 of the
        # power spectrum.
        mel = self._mspec.fm(power.transpose(1, 2))
        return mel.transpose(0, 1)

class liGRU(nn.Module):
Expand Down

0 comments on commit 5ce9f13

Please sign in to comment.