# test_quick.py
"""Quick sanity check of StyleTTS2 speaker discrimination.

Extracts a style embedding from two reference clips of *different* speakers
with the model's ``style_encoder`` and prints their cosine similarity.
A very high cross-speaker similarity suggests the style encoder is producing
nearly the same embedding regardless of speaker.

NOTE(review): a single pair of clips is a noisy estimate; for a reliable
verdict, compare many cross-speaker pairs against same-speaker pairs.
"""
import torch
import torchaudio
import librosa
import yaml
from scipy.spatial.distance import cosine

# Project-local helpers (load_ASR_models, load_F0_models, build_model,
# recursive_munch). Wildcard imports are only legal at module level.
from models import *
from utils import *
from Utils.PLBERT.util import load_plbert

CONFIG_PATH = "./Configs/config_ft.yml"
CHECKPOINT_PATH = "/u01/colombo/hungnt/hieuld/tts/StyleTTS2/hieuducle/styletts2-ver2-model-bestmodel/best_model_ver2.pth"
# Two clips from two DIFFERENT speakers.
REF_AUDIO_A = "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/sena30.wav"
REF_AUDIO_B = "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/megame.wav"

SAMPLE_RATE = 24000
TRIM_TOP_DB = 30
# Log-mel normalisation constants (values taken from the inference code).
MEL_MEAN = -4
MEL_STD = 4

# Verdict thresholds on cross-speaker cosine similarity.
COLLAPSE_THRESHOLD = 0.85
WEAK_THRESHOLD = 0.75

# Built once instead of on every call. sample_rate is deliberately left at
# torchaudio's default, as in the original code: the filterbank must match
# whatever the model was trained with.
_TO_MEL = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300
)


def extract_style_simple(audio_path, model, device="cuda"):
    """Return the style-encoder embedding of one reference clip.

    Args:
        audio_path: path to an audio file; librosa resamples it to 24 kHz.
        model: mapping of sub-modules; only ``model['style_encoder']`` is used.
        device: device the style encoder lives on (default ``"cuda"``).

    Returns:
        The style embedding as a NumPy array (leading batch dimension of 1;
        the remaining shape depends on the style encoder).

    Raises:
        ValueError: if no audio remains after trimming leading/trailing silence.
    """
    wave, _ = librosa.load(audio_path, sr=SAMPLE_RATE)
    audio, _ = librosa.effects.trim(wave, top_db=TRIM_TOP_DB)
    if audio.size == 0:
        # An empty signal would otherwise fail deep inside the STFT.
        raise ValueError(f"No non-silent audio left after trimming: {audio_path}")

    mel = _TO_MEL(torch.from_numpy(audio).float())  # (n_mels, frames)
    # Add a batch dim, log-compress (1e-5 avoids log(0)), then normalise.
    mel = (torch.log(1e-5 + mel.unsqueeze(0)) - MEL_MEAN) / MEL_STD
    mel = mel.to(device)

    with torch.no_grad():
        # unsqueeze(1) adds a channel dim -> (batch, 1, n_mels, frames).
        ref_s = model["style_encoder"](mel.unsqueeze(1))
    return ref_s.cpu().numpy()


def _strip_module_prefix(state_dict):
    """Drop the ``module.`` key prefix (typically added by (Distributed)DataParallel)."""
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def _load_model(config_path, checkpoint_path, device):
    """Build the model from the YAML config and load the checkpoint's ``net`` weights."""
    with open(config_path) as f:
        config = yaml.safe_load(f)

    text_aligner = load_ASR_models(config["ASR_path"], config["ASR_config"])
    pitch_extractor = load_F0_models(config["F0_path"])
    plbert = load_plbert(config["PLBERT_dir"])
    model = build_model(
        recursive_munch(config["model_params"]), text_aligner, pitch_extractor, plbert
    )

    # Load on CPU so the whole checkpoint is not materialised on the GPU;
    # only the model weights are moved to `device` below.
    params = torch.load(checkpoint_path, map_location="cpu")["net"]
    for key in model:
        model[key].load_state_dict(_strip_module_prefix(params[key]), strict=True)
        model[key].eval().to(device)
    return model


def main():
    """Compare the style embeddings of two different speakers and print a verdict."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _load_model(CONFIG_PATH, CHECKPOINT_PATH, device)

    s1 = extract_style_simple(REF_AUDIO_A, model, device)
    s2 = extract_style_simple(REF_AUDIO_B, model, device)
    # scipy's `cosine` is a distance, so 1 - distance is the similarity.
    similarity = 1 - cosine(s1.flatten(), s2.flatten())

    print(f"\n{'='*60}")
    print(f"Cross-speaker similarity: {similarity:.4f}")
    print(f"{'='*60}")
    if similarity > COLLAPSE_THRESHOLD:
        print("❌ CRITICAL: Speaker Collapse!")
        print(" → Style encoder học 'average voice'")
        print(" → CẦN RETRAIN với style_dim=256")
    elif similarity > WEAK_THRESHOLD:
        print("⚠️ WARNING: Weak discrimination")
        print(" → Có thể fine-tune thêm")
    else:
        print("✅ Style encoder OK")
        print(" → Vấn đề ở chỗ khác")


if __name__ == "__main__":
    main()