|
|
import os |
|
|
import json |
|
|
import random |
|
|
|
|
|
import torch |
|
|
import torchaudio |
|
|
import torchaudio.transforms as AT |
|
|
import csv |
|
|
import numpy as np |
|
|
import librosa |
|
|
import pandas as pd |
|
|
import laion_clap |
|
|
from model.CLAPSep_infer import LightningModule |
|
|
from model.CLAPSep_decoder import HTSAT_Decoder |
|
|
import argparse |
|
|
import pytorch_lightning as pl |
|
|
from helpers import utils as local_utils |
|
|
|
|
|
|
|
|
class AudioCapsTest(torch.utils.data.Dataset):
    """Test dataset yielding fixed-length audio clips together with their labels.

    The dataset is the intersection of two inputs: a JSON object mapping a
    video id to its labels, and a CSV file mapping a video id (first column)
    to an audio path (last column). Only ids present in both are kept.
    """

    # Every clip is trimmed or zero-padded to this many seconds.
    CLIP_SECONDS = 10

    def __init__(self, audioset_json, video2path_map_csv, sr=32000, resample_rate=48000):
        """Build the list of audio paths and labels.

        Args:
            audioset_json: Path to a JSON file holding ``{video_id: labels}``.
                Labels are later joined with ``'|'``, so they are assumed to be
                a sequence of strings -- TODO confirm against the data files.
            video2path_map_csv: Path to a CSV file whose first column is the
                video id and whose last column is the audio file path.
            sr: Sample rate used when loading audio.
            resample_rate: Target rate of the second waveform returned by
                ``__getitem__``. ``None`` disables resampling.
        """
        self.data_names = []
        self.data_labels = []

        # Context managers guarantee the file handles are closed; newline=''
        # is what the csv module asks for when reading. Empty rows are skipped
        # because indexing them would raise IndexError.
        with open(video2path_map_csv, 'r', newline='') as csv_file:
            video2path = {row[0]: row[-1] for row in csv.reader(csv_file) if row}

        with open(audioset_json, 'r') as json_file:
            video2labels = json.load(json_file)

        for video, labels in video2labels.items():
            if video in video2path:
                self.data_names.append(video2path[video])
                self.data_labels.append(labels)

        self.sr = sr
        self.resample_rate = resample_rate
        # Always define the attribute so __getitem__ works without resampling.
        if resample_rate is not None:
            self.resampler = AT.Resample(sr, resample_rate)
        else:
            self.resampler = None

    def __len__(self):
        """Return the number of clips found in both the JSON and the CSV."""
        return len(self.data_names)

    def load_wav(self, path):
        """Load ``path`` at ``self.sr`` and force it to ``CLIP_SECONDS`` seconds.

        Longer audio is truncated; shorter audio is zero-padded at the end.

        Returns:
            The array returned by librosa, with exactly
            ``self.sr * CLIP_SECONDS`` samples.
        """
        max_length = self.sr * self.CLIP_SECONDS
        wav = librosa.core.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[0:max_length]
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def __getitem__(self, idx):
        """Return ``(waveform, resampled_waveform, labels, path)`` for item ``idx``.

        ``labels`` is the ``'|'``-joined label text. When the dataset was built
        with ``resample_rate=None`` the second element is the waveform itself.
        """
        tgt_name = self.data_names[idx]
        tgt_labels = self.data_labels[idx]

        mixed = torch.tensor(self.load_wav(tgt_name))
        resampled = mixed if self.resampler is None else self.resampler(mixed)

        return mixed, resampled, '|'.join(tgt_labels), tgt_name
|
|
|
|
|
|
|
|
|
|
|
def main(args):
    """Evaluate a checkpoint on the AudioCaps test data with ``trainer.test``.

    Args:
        args: Namespace carrying the command-line options (``audioset_json``,
            ``video2path_map_csv``, ``sample_rate``, ``ckpt_path``,
            ``use_cuda``, ``gpu_ids``, ``exp_dir``) plus the fields this
            function reads that the parser does not define (``clap_path``,
            ``model``, ``optim``, ``lora``, ``lora_rank``, ``nfft``,
            ``epochs``). The latter are merged in from ``config.json`` by the
            script entry point.
    """
    torch.set_float32_matmul_precision('highest')

    data_test = AudioCapsTest(audioset_json=args.audioset_json,
                              video2path_map_csv=args.video2path_map_csv,
                              sr=args.sample_rate,
                              resample_rate=48000)

    test_loader = torch.utils.data.DataLoader(data_test,
                                              batch_size=1,
                                              num_workers=1,
                                              pin_memory=True,
                                              shuffle=False)

    # The CLAP model is created on CPU; device placement is left to the trainer.
    clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
    clap_model.load_ckpt(args.clap_path)
    decoder = HTSAT_Decoder(**args.model)
    lightning_module = LightningModule(clap_model, decoder, lr=args.optim['lr'],
                                       use_lora=args.lora,
                                       rank=args.lora_rank,
                                       nfft=args.nfft)

    # --gpu_ids defaults to None and its help text promises "all GPUs by
    # default"; passing None as `devices` does not express that, so fall back
    # to "auto" whenever no explicit ids were given.
    if args.use_cuda and args.gpu_ids:
        devices = args.gpu_ids
    else:
        devices = "auto"

    distributed_backend = "ddp"
    trainer = pl.Trainer(
        default_root_dir=os.path.join(args.exp_dir, 'checkpoint'),
        devices=devices,
        accelerator="gpu" if args.use_cuda else "cpu",
        benchmark=False,
        gradient_clip_val=5.0,
        precision='bf16-mixed',
        limit_train_batches=1.0,
        max_epochs=args.epochs,
        strategy=distributed_backend,
        logger=False
    )

    # NOTE(review): strict=False silently ignores missing/unexpected keys, so a
    # checkpoint with a different key layout would load without any error --
    # confirm the checkpoint format matches the module's state dict.
    weights = torch.load(args.ckpt_path, map_location='cpu')
    lightning_module.load_state_dict(weights, strict=False)

    trainer.test(model=lightning_module, dataloaders=test_loader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI, merge in the experiment's config.json
    # and run the evaluation.
    parser = argparse.ArgumentParser()

    # nargs='?' makes the positional optional; without it argparse always
    # requires the argument and the declared default is never used.
    parser.add_argument('exp_dir', type=str, nargs='?',
                        default='experiments',
                        help="Path to save checkpoints and logs.")

    parser.add_argument('--sample_rate', type=int, default=16000)
    parser.add_argument('--ckpt_path', type=str, default='')
    parser.add_argument('--audioset_json', type=str, default='')
    parser.add_argument('--video2path_map_csv', type=str, default='')

    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use cuda")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training. "
                        "Eg., --gpu_ids 2 4. All GPUs are used by default.")

    args = parser.parse_args()

    # Fixed seed so repeated evaluations are comparable.
    pl.seed_everything(114514)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.exp_dir, exist_ok=True)

    # Every attribute of the loaded params is copied onto args; a config entry
    # with the same name as a command-line option overrides the CLI value.
    # Presumably Params parses the JSON file into attributes -- verify in
    # helpers.utils.
    params = local_utils.Params(os.path.join(args.exp_dir, 'config.json'))
    for k, v in params.__dict__.items():
        setattr(args, k, v)
    main(args)
|
|
|