File size: 5,087 Bytes
dbbd709 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import os
import json
import random
import torch
import torchaudio
import torchaudio.transforms as AT
import csv
import numpy as np
import librosa
import pandas as pd
import laion_clap
from model.CLAPSep_infer import LightningModule
from model.CLAPSep_decoder import HTSAT_Decoder
import argparse
import pytorch_lightning as pl
from helpers import utils as local_utils
class AudioCapsTest(torch.utils.data.Dataset):  # type: ignore
    """Evaluation dataset pairing audio clips with their text labels.

    Audio files are located through a CSV that maps a video id (first column)
    to an audio file path (last column). Labels come from a JSON object that
    maps the same video ids to lists of label strings. Only ids present in both
    files are kept, in the order they appear in the JSON.

    Each item is ``(wav, wav_resampled, labels, path)``:

    * ``wav`` is the clip loaded at ``sr`` and cropped or zero-padded to
      ``clip_seconds`` seconds.
    * ``wav_resampled`` is the same clip resampled to ``resample_rate``, or
      the unchanged ``wav`` when ``resample_rate`` is None.
    * ``labels`` is the label list joined with ``'|'``.
    * ``path`` is the audio file path.
    """

    def __init__(self, audioset_json, video2path_map_csv, sr=32000, resample_rate=48000,
                 clip_seconds=10):
        """Index the clips that have both a file path and labels.

        Args:
            audioset_json: Path to a JSON object ``{video_id: [label, ...]}``.
            video2path_map_csv: Path to a CSV whose rows are
                ``video_id, ..., audio_path``.
            sr: Sample rate the audio is loaded at.
            resample_rate: Target rate for the second returned waveform;
                ``None`` disables resampling.
            clip_seconds: Fixed clip length in seconds (10 s by default).
        """
        video2path = {}
        # newline='' is what the csv module expects for files it reads.
        with open(video2path_map_csv, 'r', newline='') as f:
            for row in csv.reader(f):
                if row:  # blank lines yield [] and would crash on row[0]
                    video2path[row[0]] = row[-1]
        with open(audioset_json, 'r') as f:
            video2labels = json.load(f)

        self.data_names = []
        self.data_labels = []
        for video, labels in video2labels.items():
            if video in video2path:
                self.data_names.append(video2path[video])
                self.data_labels.append(labels)

        self.sr = sr
        self.clip_seconds = clip_seconds
        self.resample_rate = resample_rate
        # Previously the resampler was left undefined when resample_rate was None,
        # so __getitem__ raised AttributeError; None now means "no resampling".
        self.resampler = AT.Resample(sr, resample_rate) if resample_rate is not None else None

    def __len__(self):
        return len(self.data_names)

    @staticmethod
    def _fit_length(wav, max_length):
        """Crop ``wav`` to ``max_length`` samples or right-pad it with zeros."""
        if len(wav) > max_length:
            return wav[:max_length]
        if len(wav) < max_length:
            return np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav

    def load_wav(self, path):
        """Load ``path`` with librosa at ``self.sr`` and fix it to ``clip_seconds`` long."""
        wav = librosa.core.load(path, sr=self.sr)[0]
        return self._fit_length(wav, int(self.sr * self.clip_seconds))

    def __getitem__(self, idx):
        path = self.data_names[idx]
        labels = self.data_labels[idx]
        mixed = torch.tensor(self.load_wav(path))
        resampled = self.resampler(mixed) if self.resampler is not None else mixed
        return mixed, resampled, '|'.join(labels), path
def main(args):
    """Build the test loader and model, load a checkpoint and run ``trainer.test``.

    ``args`` is the CLI namespace merged with the experiment's ``Params``. Besides
    the CLI flags, it must provide the following attributes:

    * ``clap_path``
    * ``model`` (kwargs for ``HTSAT_Decoder``)
    * ``optim['lr']``
    * ``lora``
    * ``lora_rank``
    * ``nfft``
    * ``epochs``
    """
    torch.set_float32_matmul_precision('highest')

    # Load dataset
    data_test = AudioCapsTest(audioset_json=args.audioset_json,
                              video2path_map_csv=args.video2path_map_csv,
                              sr=args.sample_rate,
                              resample_rate=48000)
    test_loader = torch.utils.data.DataLoader(data_test,
                                              batch_size=1,
                                              num_workers=1,
                                              pin_memory=True,
                                              shuffle=False)

    # Build the model: pretrained CLAP backbone + separation decoder.
    clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
    clap_model.load_ckpt(args.clap_path)
    decoder = HTSAT_Decoder(**args.model)
    lightning_module = LightningModule(clap_model, decoder, lr=args.optim['lr'],
                                       use_lora=args.lora,
                                       rank=args.lora_rank,
                                       nfft=args.nfft)

    # --gpu_ids is documented as "all GPUs by default": map a missing/empty list
    # to "auto" instead of handing devices=None to the Trainer.
    devices = args.gpu_ids if (args.use_cuda and args.gpu_ids) else "auto"
    # NOTE(review): with the "ddp" strategy on several devices, Lightning shards
    # the test set with a DistributedSampler, which may duplicate samples to even
    # out the shards; consider devices=1 if exact test metrics matter.
    trainer = pl.Trainer(
        default_root_dir=os.path.join(args.exp_dir, 'checkpoint'),
        devices=devices,
        accelerator="gpu" if args.use_cuda else "cpu",
        benchmark=False,
        gradient_clip_val=5.0,
        precision='bf16-mixed',
        limit_train_batches=1.0,
        max_epochs=args.epochs,
        strategy="ddp",
        logger=False,
    )

    # torch.load unpickles the file: only point --ckpt_path at trusted checkpoints.
    weights = torch.load(args.ckpt_path, map_location='cpu')
    # A full Lightning checkpoint nests the parameters under 'state_dict'; loading
    # that outer dict with strict=False would silently match no parameter keys.
    if isinstance(weights, dict) and 'state_dict' in weights:
        weights = weights['state_dict']
    load_result = lightning_module.load_state_dict(weights, strict=False)
    # strict=False tolerates partial checkpoints; still surface keys that did not
    # match anything so a wrong checkpoint file does not go unnoticed.
    unexpected = getattr(load_result, 'unexpected_keys', None) or []
    if unexpected:
        print(f'Warning: {len(unexpected)} unexpected keys in checkpoint {args.ckpt_path}')
    trainer.test(model=lightning_module, dataloaders=test_loader)
def build_arg_parser():
    """Return the command-line parser for this evaluation script."""
    parser = argparse.ArgumentParser()
    # Data Params
    # nargs='?' makes the positional optional so its default is actually used;
    # without it argparse requires the argument and never applies ``default``.
    parser.add_argument('exp_dir', type=str, nargs='?',
                        default='experiments',
                        help="Path to save checkpoints and logs.")
    parser.add_argument('--sample_rate', type=int, default=16000)
    parser.add_argument('--ckpt_path', type=str, default='')
    parser.add_argument('--audioset_json', type=str, default='')
    parser.add_argument('--video2path_map_csv', type=str, default='')
    parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
                        help="Whether to use cuda")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for training. "
                             "Eg., --gpu_ids 2 4. All GPUs are used by default.")
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    # Set the random seed for reproducible experiments
    pl.seed_everything(114514)
    # Set up checkpoints
    os.makedirs(args.exp_dir, exist_ok=True)
    # Load model and training params from <exp_dir>/config.json. Every attribute
    # of the Params object overwrites the CLI argument of the same name.
    params = local_utils.Params(os.path.join(args.exp_dir, 'config.json'))
    for k, v in params.__dict__.items():
        setattr(args, k, v)
    main(args)
|