"""Offline neural-codec inference script.

Loads a trained codec checkpoint (acoustic CodecEncoder / CodecDecoderVocos
plus a semantic branch distilled from facebook/w2v-bert-2.0), then round-trips
every .wav/.flac/.mp3 file found under --input-dir through
encode -> quantize -> decode and writes the reconstructions to --output-dir.
Requires a CUDA device.
"""
import os
from argparse import ArgumentParser
from collections import OrderedDict
from glob import glob
from os.path import basename, join, exists
from time import time

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModel, AutoFeatureExtractor, Wav2Vec2BertModel

from vq.codec_encoder import CodecEncoder
from vq.codec_decoder_vocos import CodecDecoderVocos
from vq.module import SemanticDecoder, SemanticEncoder


def _split_state_dict(state_dict, prefixes):
    """Partition a flat checkpoint state dict into one dict per submodule.

    Args:
        state_dict: flat mapping of ``"<prefix><param>" -> tensor`` as stored
            in the Lightning checkpoint's ``state_dict``.
        prefixes: iterable of submodule prefixes, each ending with ``'.'``.

    Returns:
        ``{prefix: OrderedDict}`` where every key has its prefix stripped, so
        each bucket can be fed straight to that submodule's
        ``load_state_dict``. Keys matching no prefix (e.g. discriminator
        weights) are silently skipped, as in the original hand-written loop.
    """
    buckets = {p: OrderedDict() for p in prefixes}
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                buckets[prefix][key[len(prefix):]] = value
                break
    return buckets


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--input-dir', type=str, default='test_audio/input_test')
    parser.add_argument('--ckpt', type=str, default='ckpt/epoch=4-step=1400000.ckpt')
    parser.add_argument('--output-dir', type=str, default='test_audio/output_test')
    args = parser.parse_args()

    sr = 16000  # every model below assumes 16 kHz audio

    print(f'Load codec ckpt from {args.ckpt}')
    ckpt = torch.load(args.ckpt, map_location='cpu')
    state_dict = ckpt['state_dict']

    # Split the monolithic Lightning state dict into per-module weight dicts.
    buckets = _split_state_dict(
        state_dict,
        ('CodecEnc.', 'generator.', 'fc_post_a.',
         'SemanticEncoder_module.', 'fc_prior.'),
    )

    # Frozen semantic teacher: w2v-bert-2.0 with all hidden states exposed so
    # layer 16 can be tapped in the loop below.
    semantic_model = Wav2Vec2BertModel.from_pretrained(
        "facebook/w2v-bert-2.0", output_hidden_states=True)
    semantic_model.eval().cuda()

    SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024)
    SemanticEncoder_module.load_state_dict(buckets['SemanticEncoder_module.'])
    SemanticEncoder_module = SemanticEncoder_module.eval().cuda()

    encoder = CodecEncoder()
    encoder.load_state_dict(buckets['CodecEnc.'])
    encoder = encoder.eval().cuda()

    decoder = CodecDecoderVocos()
    decoder.load_state_dict(buckets['generator.'])
    decoder = decoder.eval().cuda()

    # Linear projections between the concatenated (semantic + acoustic)
    # 2048-dim space and the quantizer / vocoder input spaces.
    fc_post_a = nn.Linear(2048, 1024)
    fc_post_a.load_state_dict(buckets['fc_post_a.'])
    fc_post_a = fc_post_a.eval().cuda()

    fc_prior = nn.Linear(2048, 2048)
    fc_prior.load_state_dict(buckets['fc_prior.'])
    fc_prior = fc_prior.eval().cuda()

    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

    wav_dir = args.output_dir
    os.makedirs(wav_dir, exist_ok=True)

    # Collect all supported audio files recursively.
    wav_paths = []
    for pattern in ('*.wav', '*.flac', '*.mp3'):
        wav_paths.extend(glob(os.path.join(args.input_dir, '**', pattern), recursive=True))
    print(f'Found {len(wav_paths)} wavs in {args.input_dir}')

    st = time()
    for wav_path in tqdm(wav_paths):
        # NOTE(review): output dir is flat, so inputs from different
        # subdirectories that share a base name overwrite each other — confirm
        # this is acceptable for the test sets used.
        target_wav_path = join(wav_dir, basename(wav_path))

        wav = librosa.load(wav_path, sr=sr)[0]
        wav = torch.from_numpy(wav).unsqueeze(0).cuda()  # (1, T)

        # Right-pad so the length is a multiple of the 320-sample hop.
        # (Fixed: the original computed `320 - (T % 320)`, which appended a
        # full extra frame of silence when T was already aligned.)
        pad_for_wav = (-wav.shape[1]) % 320
        wav = torch.nn.functional.pad(wav, (0, pad_for_wav))

        # w2v-bert input features are computed on CPU from the padded
        # waveform; the extra 160 samples on each side presumably center the
        # analysis windows on the codec frames — TODO confirm against training.
        feat = feature_extractor(
            F.pad(wav[0, :].cpu(), (160, 160)),
            sampling_rate=16000,
            return_tensors="pt",
        ).data['input_features'].cuda()

        with torch.no_grad():
            # Acoustic embedding from the codec encoder: (1, 1, T) -> (1, C, T').
            vq_emb = encoder(wav.unsqueeze(1))
            vq_emb = vq_emb.transpose(1, 2)

            # Semantic embedding: layer-16 hidden states of w2v-bert-2.0,
            # refined by the trained SemanticEncoder.
            semantic_target = semantic_model(feat[:, :, :])
            semantic_target = semantic_target.hidden_states[16]
            semantic_target = semantic_target.transpose(1, 2)
            semantic_target = SemanticEncoder_module(semantic_target)

            # Fuse both streams along the channel dim and project before VQ.
            vq_emb = torch.cat([semantic_target, vq_emb], dim=1)
            vq_emb = fc_prior(vq_emb.transpose(1, 2)).transpose(1, 2)

            # Quantize, then decode the codebook indices back to a waveform.
            _, vq_code, _ = decoder(vq_emb, vq=True)
            vq_post_emb = decoder.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
            vq_post_emb = vq_post_emb.transpose(1, 2)
            vq_post_emb = fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)
            recon = decoder(vq_post_emb.transpose(1, 2), vq=False)[0].squeeze().detach().cpu().numpy()

        sf.write(target_wav_path, recon, sr)

    et = time()
    print(f'Inference ends, time: {(et-st)/60:.2f} mins')