Spaces:
Sleeping
Sleeping
| # test_quick.py - CHỈ CẦN 10 DÒNG CODE | |
| import torch | |
| import torchaudio | |
| import librosa | |
| from scipy.spatial.distance import cosine | |
# Style-extraction helper adapted from the project's StyleTTS2 inference code.
def extract_style_simple(audio_path, model, device="cuda"):
    """Return the style-encoder embedding of a single reference clip.

    The clip is loaded and resampled to 24 kHz. Leading and trailing audio
    more than 30 dB below the peak is trimmed. The result is converted to a
    log-mel spectrogram, normalised the same way the inference code does it,
    and passed through ``model['style_encoder']``.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        model: Mapping of sub-module name -> module. Only
            ``model['style_encoder']`` is used.
        device: Device the mel tensor is moved to. It must match the device
            the style encoder lives on. Defaults to ``"cuda"`` (the previously
            hard-coded value), so existing callers are unaffected.

    Returns:
        A NumPy array with the style encoder output, copied back to the CPU.
        Presumably shaped (1, style_dim); confirm against the encoder.
    """
    # Normalisation constants applied to the log-mel, as in inference.
    mel_mean, mel_std = -4, 4

    wave, _sr = librosa.load(audio_path, sr=24000)
    audio, _ = librosa.effects.trim(wave, top_db=30)

    to_mel = torchaudio.transforms.MelSpectrogram(
        n_mels=80, n_fft=2048, win_length=1200, hop_length=300
    )
    mel = to_mel(torch.from_numpy(audio).float())
    # Add a batch axis, take a log with a small floor to avoid log(0), then
    # standardise.
    mel = (torch.log(1e-5 + mel.unsqueeze(0)) - mel_mean) / mel_std
    mel = mel.to(device)

    with torch.no_grad():
        # Extra axis at dim 1: the encoder is fed a single-channel "image".
        ref_s = model['style_encoder'](mel.unsqueeze(1))
    return ref_s.cpu().numpy()
# Build the StyleTTS2 model the same way the inference code does and load the
# fine-tuned checkpoint into every sub-module.
# NOTE(review): the star imports are kept because build_model, load_ASR_models,
# load_F0_models and recursive_munch come from these project modules. Switch
# to explicit imports once their origin is confirmed.
from models import *
from utils import *
import yaml
from Utils.PLBERT.util import load_plbert

CONFIG_PATH = "./Configs/config_ft.yml"
CHECKPOINT_PATH = (
    "/u01/colombo/hungnt/hieuld/tts/StyleTTS2/hieuducle/"
    "styletts2-ver2-model-bestmodel/best_model_ver2.pth"
)
MODULE_PREFIX = "module."

# Use a context manager so the config file handle is closed after parsing.
# YAML is UTF-8 by spec.
with open(CONFIG_PATH, encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])
pitch_extractor = load_F0_models(config['F0_path'])
plbert = load_plbert(config['PLBERT_dir'])
model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)

# Load the checkpoint onto the CPU. load_state_dict copies the tensors into
# the modules anyway, and mapping straight to CUDA would keep a second, unused
# copy of every weight alive on the GPU through the module-level `params`.
params = torch.load(CHECKPOINT_PATH, map_location="cpu")['net']

for key in model:
    if key not in params:
        raise KeyError(
            f"checkpoint {CHECKPOINT_PATH!r} has no weights for sub-module {key!r}; "
            f"available: {sorted(params)}"
        )
    # Parameter names saved from a wrapped module (presumably DataParallel/DDP
    # during training) carry a "module." prefix. Strip it so the names match
    # the bare sub-modules built above.
    state_dict = {
        (k[len(MODULE_PREFIX):] if k.startswith(MODULE_PREFIX) else k): v
        for k, v in params[key].items()
    }
    # strict=True: missing or unexpected keys must fail loudly. Otherwise the
    # similarity test would silently run on randomly initialised weights.
    model[key].load_state_dict(state_dict, strict=True)
    model[key].eval().to("cuda")
# TEST: compare style embeddings of two DIFFERENT speakers. A very high cosine
# similarity suggests the style encoder maps everyone to one "average voice".
# NOTE(review): a single speaker pair with fixed cut-offs is a heuristic check;
# a same-speaker baseline would make the thresholds meaningful.
import math

SPEAKER_A_WAV = "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/sena30.wav"
SPEAKER_B_WAV = "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/megame.wav"
COLLAPSE_THRESHOLD = 0.85  # above this: treat as speaker collapse
WEAK_THRESHOLD = 0.75      # above this: weak speaker discrimination

s1 = extract_style_simple(SPEAKER_A_WAV, model)
s2 = extract_style_simple(SPEAKER_B_WAV, model)
# scipy's `cosine` is a distance (1 - cos θ), so convert it back to a similarity.
similarity = 1 - cosine(s1.flatten(), s2.flatten())

print(f"\n{'='*60}")
print(f"Cross-speaker similarity: {similarity:.4f}")
print(f"{'='*60}")

if not math.isfinite(similarity):
    # NaN comes from NaN/Inf inside an embedding or from a zero-norm vector.
    # Without this branch every comparison below is False and the script would
    # wrongly report the encoder as OK.
    print("❌ ERROR: similarity is not finite (NaN/Inf or zero-norm style vector)")
elif similarity > COLLAPSE_THRESHOLD:
    print("❌ CRITICAL: Speaker Collapse!")
    print(" → Style encoder học 'average voice'")
    print(" → CẦN RETRAIN với style_dim=256")
elif similarity > WEAK_THRESHOLD:
    print("⚠️ WARNING: Weak discrimination")
    print(" → Có thể fine-tune thêm")
else:
    print("✅ Style encoder OK")
    print(" → Vấn đề ở chỗ khác")