| import os |
| import subprocess |
| from pathlib import Path |
|
|
| import gradio as gr |
|
|
| from config import hparams as hp |
| from config import hparams_gradio as hp_gradio |
| from nota_wav2lip import Wav2LipModelComparisonGradio |
|
|
| |
# Runtime configuration pulled from the Gradio-specific hyperparameters:
# the target device for inference and the label -> path mappings of the
# bundled sample videos/audios shown in the UI below.
device = hp_gradio.device
print(f'Using {device} for inference.')
video_label_dict = hp_gradio.sample.video
audio_label_dict = hp_gradio.sample.audio
|
|
# Download URLs for the pretrained checkpoints and the bundled inference
# samples are supplied via environment variables; when unset, the
# corresponding download is skipped.
LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE', None)


def _download_url(url: str, dest: str) -> None:
    """Fetch *url* into *dest* with wget (TLS verification disabled, as before)."""
    # Argument-list form (shell=False) avoids shell injection through the
    # environment-supplied URL; the previous shell-string form would break
    # (or worse, execute) on URLs containing shell metacharacters.
    subprocess.call(["wget", "--no-check-certificate", "-O", dest, url])


# Fetch each model checkpoint only if it is not already on disk.
if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None:
    _download_url(LRS_ORIGINAL_URL, hp.inference.model.wav2lip.checkpoint)
if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
    _download_url(LRS_COMPRESSED_URL, hp.inference.model.nota_wav2lip.checkpoint)

# Fetch and unpack the demo sample archive into the working directory.
path_inference_sample = "sample.tar.gz"
if not Path(path_inference_sample).exists() and LRS_INFERENCE_SAMPLE is not None:
    _download_url(LRS_INFERENCE_SAMPLE, path_inference_sample)
    subprocess.call(["tar", "-zxvf", path_inference_sample])
|
|
|
|
| if __name__ == "__main__": |
|
|
| servicer = Wav2LipModelComparisonGradio( |
| device=device, |
| video_label_dict=video_label_dict, |
| audio_label_list=audio_label_dict, |
| default_video='v1', |
| default_audio='a1' |
| ) |
|
|
| for video_name in sorted(video_label_dict): |
| video_stem = Path(video_label_dict[video_name]) |
| servicer.update_video(video_stem, video_stem.with_suffix('.json'), |
| name=video_name) |
|
|
| for audio_name in sorted(audio_label_dict): |
| audio_path = Path(audio_label_dict[audio_name]) |
| servicer.update_audio(audio_path, name=audio_name) |
|
|
| with gr.Blocks(theme='nota-ai/theme', css=Path('docs/main.css').read_text()) as demo: |
| gr.Markdown(Path('docs/header.md').read_text()) |
| gr.Markdown(Path('docs/description.md').read_text()) |
| with gr.Row(): |
| with gr.Column(variant='panel'): |
|
|
| gr.Markdown('## Select input video and audio', sanitize_html=False) |
| |
| sample_video = gr.Video(interactive=False, label="Input Video") |
| sample_audio = gr.Audio(interactive=False, label="Input Audio") |
|
|
| |
| video_selection = gr.components.Radio(video_label_dict, |
| type='value', label="Select an input video:") |
| audio_selection = gr.components.Radio(audio_label_dict, |
| type='value', label="Select an input audio:") |
| |
| with gr.Row(equal_height=True): |
| generate_original_button = gr.Button(value="Generate with Original Model", variant="primary") |
| generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary") |
| with gr.Column(variant='panel'): |
| |
| gr.Markdown('## Original Wav2Lip') |
| original_model_output = gr.Video(label="Original Model", interactive=False) |
| with gr.Column(): |
| with gr.Row(equal_height=True): |
| original_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)") |
| original_model_fps = gr.Textbox(value="", label="FPS") |
| original_model_params = gr.Textbox(value=servicer.params['wav2lip'], label="# Parameters") |
| with gr.Column(variant='panel'): |
| |
| gr.Markdown('## Compressed Wav2Lip (Ours)') |
| compressed_model_output = gr.Video(label="Compressed Model", interactive=False) |
| with gr.Column(): |
| with gr.Row(equal_height=True): |
| compressed_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)") |
| compressed_model_fps = gr.Textbox(value="", label="FPS") |
| compressed_model_params = gr.Textbox(value=servicer.params['nota_wav2lip'], label="# Parameters") |
|
|
| |
| video_selection.change(fn=servicer.switch_video_samples, inputs=video_selection, outputs=sample_video) |
| audio_selection.change(fn=servicer.switch_audio_samples, inputs=audio_selection, outputs=sample_audio) |
|
|
| |
| generate_original_button.click(servicer.generate_original_model, |
| inputs=[video_selection, audio_selection], |
| outputs=[original_model_output, original_model_inference_time, original_model_fps]) |
| |
| generate_compressed_button.click(servicer.generate_compressed_model, |
| inputs=[video_selection, audio_selection], |
| outputs=[compressed_model_output, compressed_model_inference_time, compressed_model_fps]) |
|
|
| gr.Markdown(Path('docs/footer.md').read_text()) |
|
|
| demo.queue().launch() |
|
|