# ClearSep/infer_data_engine_json.py
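"""Run ClearSep separation inference over an AudioSet-style label JSON.

For each clip listed in --audioset_json whose audio file is known from
--video2path_map_csv, the 10 s mixture is loaded, resampled to 48 kHz for the
CLAP encoder, and passed through the CLAPSep Lightning module for testing.

A hedged usage sketch (flags match the argparse setup below; paths are
placeholders):

    python infer_data_engine_json.py experiments/my_exp \
        --ckpt_path checkpoints/model.ckpt \
        --audioset_json data/video2labels.json \
        --video2path_map_csv data/video2path.csv \
        --use_cuda --gpu_ids 0
"""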
import os
import json
import csv
import argparse

import numpy as np
import torch
import torchaudio.transforms as AT
import librosa
import laion_clap
import pytorch_lightning as pl

from model.CLAPSep_infer import LightningModule
from model.CLAPSep_decoder import HTSAT_Decoder
from helpers import utils as local_utils
class AudioCapsTest(torch.utils.data.Dataset):  # type: ignore
    """AudioSet/AudioCaps-style test set mapping video ids to wav paths and label lists."""

    def __init__(self, audioset_json, video2path_map_csv, sr=32000, resample_rate=48000):
        self.data_names = []
        self.data_labels = []
        # video2path_map_csv rows: video id in the first column, wav path in the last.
        video2path = {}
        for item in csv.reader(open(video2path_map_csv, 'r')):
            video2path[item[0]] = item[-1]
        # audioset_json: {video_id: [label, ...]}; keep only clips we have audio for.
        video2labels = json.load(open(audioset_json, 'r'))
        for video, labels in video2labels.items():
            if video in video2path:
                self.data_names.append(video2path[video])
                self.data_labels.append(labels)
        self.sr = sr
        self.resample_rate = resample_rate
        # A second copy of the mixture is resampled (e.g. to 48 kHz for CLAP).
        # If no resample_rate is given, __getitem__ falls back to the raw waveform.
        self.resampler = AT.Resample(sr, resample_rate) if resample_rate is not None else None
def __len__(self):
return len(self.data_names)
    def load_wav(self, path):
        # AudioCaps clips are 10 s; truncate or zero-pad to exactly sr * 10 samples.
        max_length = self.sr * 10
        wav = librosa.load(path, sr=self.sr)[0]
        if len(wav) > max_length:
            wav = wav[:max_length]
        if len(wav) < max_length:
            wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
        return wav
    def __getitem__(self, idx):
        tgt_name = self.data_names[idx]
        tgt_labels = self.data_labels[idx]
        mixed = torch.tensor(self.load_wav(tgt_name))
        # Fall back to the un-resampled mixture if no resampler was configured.
        resampled = self.resampler(mixed) if self.resampler is not None else mixed
        return mixed, resampled, '|'.join(tgt_labels), tgt_name
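# A minimal sketch of one dataset item (assuming sr=16000 and resample_rate=48000,
# as configured in main() below):
#
#   ds = AudioCapsTest('video2labels.json', 'video2path.csv', sr=16000)
#   mixed, mixed_48k, labels, path = ds[0]
#   # mixed:     float32 tensor of shape (160000,)  -> 10 s at 16 kHz
#   # mixed_48k: float32 tensor of shape (480000,)  -> 10 s at 48 kHz (for CLAP)
#   # labels:    '|'-joined label string, e.g. 'Speech|Music'
#   # path:      the source wav path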
def main(args):
torch.set_float32_matmul_precision('highest')
# Load dataset
data_test = AudioCapsTest(audioset_json=args.audioset_json,
video2path_map_csv=args.video2path_map_csv,
sr=args.sample_rate,
resample_rate=48000)
test_loader = torch.utils.data.DataLoader(data_test,
batch_size=1,
num_workers=1,
pin_memory=True,
shuffle=False)
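    # CLAP supplies the query embeddings and the HTSAT-based decoder performs the
    # separation; both the decoder kwargs (`model`) and the CLAP checkpoint path
    # (`clap_path`) are read from the experiment's config.json (see __main__).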
clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
clap_model.load_ckpt(args.clap_path)
decoder = HTSAT_Decoder(**args.model)
lightning_module = LightningModule(clap_model, decoder, lr=args.optim['lr'],
use_lora=args.lora,
rank=args.lora_rank,
nfft=args.nfft)
    # The Trainer is only used for .test() below; training-oriented options
    # (gradient clipping, max_epochs, limit_train_batches) are unused during testing.
    trainer = pl.Trainer(
        default_root_dir=os.path.join(args.exp_dir, 'checkpoint'),
        devices=args.gpu_ids if args.use_cuda else "auto",
        accelerator="gpu" if args.use_cuda else "cpu",
        benchmark=False,
        gradient_clip_val=5.0,
        precision='bf16-mixed',
        limit_train_batches=1.0,
        max_epochs=args.epochs,
        strategy="ddp",
        logger=False,
    )
    # Lightning checkpoints wrap weights under a 'state_dict' key; plain state dicts load as-is.
    weights = torch.load(args.ckpt_path, map_location='cpu')
    weights = weights.get('state_dict', weights)
    lightning_module.load_state_dict(weights, strict=False)
    trainer.test(model=lightning_module, dataloaders=test_loader)
    # Alternatively, let Lightning restore the checkpoint itself:
    # trainer.test(model=lightning_module, dataloaders=test_loader, ckpt_path=args.ckpt_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Data Params
    # nargs='?' makes the positional argument optional so the default actually applies.
    parser.add_argument('exp_dir', type=str, nargs='?',
                        default='experiments',
                        help="Path to save checkpoints and logs.")
parser.add_argument('--sample_rate', type=int, default=16000)
parser.add_argument('--ckpt_path', type=str, default='')
parser.add_argument('--audioset_json', type=str, default='')
parser.add_argument('--video2path_map_csv', type=str, default='')
parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
help="Whether to use cuda")
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help="List of GPU ids used for inference. "
                             "E.g., --gpu_ids 2 4. All GPUs are used by default.")
args = parser.parse_args()
# Set the random seed for reproducible experiments
pl.seed_everything(114514)
# Set up checkpoints
    os.makedirs(args.exp_dir, exist_ok=True)
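    # A hedged sketch of the config.json keys this script reads (key names are
    # inferred from the attribute accesses in main(); values are placeholders):
    #
    #   {
    #     "model": { ... HTSAT_Decoder kwargs ... },
    #     "optim": {"lr": 1e-4},
    #     "lora": true,
    #     "lora_rank": 16,
    #     "nfft": 1024,
    #     "clap_path": "path/to/clap_checkpoint.pt",
    #     "epochs": 100
    #   }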
    # Load model and training params from config.json and merge them into args.
    params = local_utils.Params(os.path.join(args.exp_dir, 'config.json'))
    for k, v in params.__dict__.items():
        vars(args)[k] = v
main(args)