| import torch |
| from speechbrain.inference.interfaces import Pretrained |
|
|
|
|
class CustomEncoderBestRQ(Pretrained):
    """Inference wrapper that encodes waveforms with a pretrained encoder.

    ``encode_batch`` applies, in order, ``hparams.compute_features``,
    ``mods.normalizer``, ``mods.extractor`` and ``mods.encoder``. Those
    components are not defined here; they must be provided by the
    hyperparameters/modules the instance is constructed with.
    """

    def __init__(self, *args, **kwargs):
        """Pass every argument through unchanged to ``Pretrained.__init__``."""
        super().__init__(*args, **kwargs)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encode a batch of waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            Waveforms with the batch on the first dimension. A 1-D tensor is
            treated as a single waveform and gets a leading batch dimension.
        wav_lens : torch.Tensor, optional
            One length per batch item. When omitted, a tensor of ones is used
            (i.e. every item is taken at full relative length).
        normalize : bool
            Accepted for interface compatibility but currently NOT used: no
            normalisation is applied to the encoder output.
            NOTE(review): confirm whether output normalisation was intended.

        Returns
        -------
        The output of ``mods.encoder`` for the batch.
        """
        # Promote a single un-batched waveform to a batch of one.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        # Default: one length of 1.0 per batch item.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Inputs must live on the model's device; features are computed on
        # floating-point audio.
        wavs = wavs.to(self.device).float()
        wav_lens = wav_lens.to(self.device)

        feats = self.hparams.compute_features(wavs)
        feats = self.mods.normalizer(feats, wav_lens)
        src = self.mods.extractor(feats)
        return self.mods.encoder(src, wav_lens)

    def encode_file(self, path, normalize=False):
        """Load one audio file and encode it as a batch of size one.

        Arguments
        ---------
        path : str
            Location of the audio file, handed to ``self.load_audio``.
        normalize : bool
            Forwarded to ``encode_batch`` (which currently ignores it).

        Returns
        -------
        The output of ``encode_batch`` for the single-item batch.
        """
        waveform = self.load_audio(path)
        # Fake a batch of one item, used at its full length.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        # Forward `normalize` so this entry point stays consistent with
        # `forward` instead of silently dropping the caller's argument.
        return self.encode_batch(batch, rel_length, normalize=normalize)

    def forward(self, wavs, wav_lens=None, normalize=False):
        """Run ``encode_batch`` with the same arguments and return its result."""
        return self.encode_batch(
            wavs=wavs, wav_lens=wav_lens, normalize=normalize
        )
|
|