# StyleTTS2_vi / check_style.py
# Author: hieuducle
# Commit 84f3a60 (verified): "Upload folder using huggingface_hub"
# test_quick.py - quick style-encoder check (the core logic is only ~10 lines of code)
import torch
import torchaudio
import librosa
from scipy.spatial.distance import cosine
# Function adapted from your inference code
def extract_style_simple(audio_path, model, device=None):
    """Compute the style-encoder embedding for one reference audio file.

    Preprocessing: load at 24 kHz, trim leading/trailing silence
    (``top_db=30``), take an 80-bin mel spectrogram, log-compress it and
    normalise with ``(log_mel - (-4)) / 4``. Then run
    ``model['style_encoder']`` on a ``(1, 1, n_mels, frames)`` tensor.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        model: Mapping of loaded sub-modules. Only ``model['style_encoder']``
            is used.
        device: Device to run the encoder on. When ``None`` (the default),
            use the device of the encoder's first parameter. Fall back to
            ``"cuda"`` if the encoder has no parameters, so the mel always
            lands next to the weights.

    Returns:
        The style embedding as a NumPy array (copied back to the CPU).

    Raises:
        ValueError: If no audio is left after silence trimming.
    """
    encoder = model['style_encoder']
    if device is None:
        # Follow the weights instead of hard-coding "cuda", so the helper also
        # works when the model was loaded on the CPU or another GPU.
        first_param = next(encoder.parameters(), None)
        device = first_param.device if first_param is not None else "cuda"

    # librosa resamples to 24 kHz on load, so the returned rate is not needed.
    wave, _ = librosa.load(audio_path, sr=24000)
    audio, _ = librosa.effects.trim(wave, top_db=30)
    if audio.size == 0:
        # A fully silent clip trims to nothing. Without this check it would
        # fail later inside the STFT with an unhelpful error.
        raise ValueError(f"No audio left after silence trimming: {audio_path!r}")

    # NOTE(review): sample_rate is left at torchaudio's default (16000) even
    # though the audio is 24 kHz. This presumably matches the preprocessing
    # the model was trained with; confirm before changing any of these values.
    to_mel = torchaudio.transforms.MelSpectrogram(
        n_mels=80, n_fft=2048, win_length=1200, hop_length=300
    )
    mel = to_mel(torch.from_numpy(audio).float())  # (n_mels, frames)
    # Add a batch axis, log-compress (1e-5 avoids log(0)) and normalise.
    mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
    mel = mel.to(device)

    with torch.no_grad():
        # Add a channel axis: (1, 1, n_mels, frames).
        ref_s = encoder(mel.unsqueeze(1))
    return ref_s.cpu().numpy()
# Load your model (same code as in inference)
from models import *
from utils import *
import yaml
# Build the StyleTTS2 sub-modules from the fine-tuning config, then load the
# checkpoint weights into each one and move it to CUDA in eval mode.
CONFIG_PATH = "./Configs/config_ft.yml"
CHECKPOINT_PATH = "/u01/colombo/hungnt/hieuld/tts/StyleTTS2/hieuducle/styletts2-ver2-model-bestmodel/best_model_ver2.pth"
MODULE_PREFIX = "module."

# Use a context manager so the config file handle is closed after parsing.
with open(CONFIG_PATH, encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])
pitch_extractor = load_F0_models(config['F0_path'])
from Utils.PLBERT.util import load_plbert
plbert = load_plbert(config['PLBERT_dir'])
model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)

# The checkpoint stores one state dict per sub-module under 'net'.
params = torch.load(CHECKPOINT_PATH, map_location='cuda')['net']
for key in model:
    # Strip the "module." prefix. Checkpoints typically carry it when saved
    # from an nn.DataParallel / DistributedDataParallel wrapper.
    state_dict = {
        (k[len(MODULE_PREFIX):] if k.startswith(MODULE_PREFIX) else k): v
        for k, v in params[key].items()
    }
    # strict=True: a missing or unexpected key is an error, not a silent skip.
    model[key].load_state_dict(state_dict, strict=True)
    model[key].eval().to("cuda")
# TEST - embed two reference clips from two DIFFERENT speakers and compare
# their style vectors. A high similarity means the style encoder does not
# separate speakers well. The thresholds below are heuristics.
import math

REF_AUDIO_SPEAKER_A = "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/sena30.wav"
REF_AUDIO_SPEAKER_B = "/u01/colombo/hungnt/hieuld/tts/styletts2_vastai/audio_ref/megame.wav"

s1 = extract_style_simple(REF_AUDIO_SPEAKER_A, model)
s2 = extract_style_simple(REF_AUDIO_SPEAKER_B, model)
# scipy's `cosine` is a distance (1 - cosine similarity); convert it back.
similarity = 1 - cosine(s1.flatten(), s2.flatten())
print(f"\n{'='*60}")
print(f"Cross-speaker similarity: {similarity:.4f}")
print(f"{'='*60}")
if math.isnan(similarity):
    # NaN compares False against every threshold. Without this branch, an
    # all-zero or NaN style vector would be reported as "OK".
    print("❌ ERROR: similarity is NaN")
    print(" → Style vectors are all-zero or contain NaN/Inf; check the checkpoint")
elif similarity > 0.85:
    # Speaker collapse: the encoder learned an "average voice"; retraining
    # with style_dim=256 is suggested.
    print("❌ CRITICAL: Speaker Collapse!")
    print(" → Style encoder học 'average voice'")
    print(" → CẦN RETRAIN với style_dim=256")
elif similarity > 0.75:
    # Weak speaker discrimination; further fine-tuning may help.
    print("⚠️ WARNING: Weak discrimination")
    print(" → Có thể fine-tune thêm")
else:
    # The encoder separates the speakers; the problem lies elsewhere.
    print("✅ Style encoder OK")
    print(" → Vấn đề ở chỗ khác")