From 07ffd49627f20638cb0d98cc6ad6ca5e832e538f Mon Sep 17 00:00:00 2001
From: cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com>
Date: Thu, 20 Jun 2024 02:01:24 -0400
Subject: [PATCH] feat: gradio WebUI (#51)

* WebUI + Audio Fix

1. audio fix: explicitly specify the audio codec in `util.py`; otherwise the
   video is technically corrupt and doesn't play sound
2. web ui: gradio web ui
3. print the current step while running inference

* gradio

* lint

* update

---
 hallo/utils/util.py  |  2 +-
 requirements.txt     |  3 +-
 scripts/app.py       | 65 ++++++++++++++++++++++++++++++++++++++++++++
 scripts/inference.py |  2 ++
 4 files changed, 70 insertions(+), 2 deletions(-)
 create mode 100644 scripts/app.py

diff --git a/hallo/utils/util.py b/hallo/utils/util.py
index 3a460f7..f4b6563 100644
--- a/hallo/utils/util.py
+++ b/hallo/utils/util.py
@@ -315,7 +315,7 @@ def make_frame(t):
     new_video_clip = VideoClip(make_frame, duration=tensor.shape[0] / fps)
     audio_clip = AudioFileClip(audio_source).subclip(0, tensor.shape[0] / fps)
     new_video_clip = new_video_clip.set_audio(audio_clip)
-    new_video_clip.write_videofile(output_video_file, fps=fps)
+    new_video_clip.write_videofile(output_video_file, fps=fps, audio_codec='aac')
 
 
 silhouette_ids = [
diff --git a/requirements.txt b/requirements.txt
index 40eff18..7c3c5dc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,4 +27,5 @@ transformers==4.39.2
 xformers==0.0.25.post1
 isort==5.13.2
 pylint==3.2.2
-pre-commit==3.7.1
\ No newline at end of file
+pre-commit==3.7.1
+gradio==4.36.1
diff --git a/scripts/app.py b/scripts/app.py
new file mode 100644
index 0000000..e106c02
--- /dev/null
+++ b/scripts/app.py
@@ -0,0 +1,65 @@
+"""
+This script is a gradio web ui.
+
+The script takes an image and an audio clip, and lets you configure all the
+variables such as cfg_scale, pose_weight, face_weight, lip_weight, etc.
+
+Usage:
+This script can be run from the command line with the following command:
+
+python scripts/app.py
+"""
+import argparse
+
+import gradio as gr
+from inference import inference_process
+
+
+def predict(image, audio, size, steps, fps, cfg, pose_weight, face_weight, lip_weight, face_expand_ratio):
+    """
+    Run the inference process with the settings collected from the UI.
+    """
+    config = {
+        'data': {
+            'source_image': {
+                'width': size,
+                'height': size
+            },
+            'export_video': {
+                'fps': fps
+            }
+        },
+        'cfg_scale': cfg,
+        'source_image': image,
+        'driving_audio': audio,
+        'pose_weight': pose_weight,
+        'face_weight': face_weight,
+        'lip_weight': lip_weight,
+        'face_expand_ratio': face_expand_ratio,
+        'config': 'configs/inference/default.yaml',
+        'checkpoint': None,
+        'output': ".cache/output.mp4",
+        'inference_steps': steps
+    }
+    args = argparse.Namespace()
+    for key, value in config.items():
+        setattr(args, key, value)
+    return inference_process(args)
+
+app = gr.Interface(
+    fn=predict,
+    inputs=[
+        gr.Image(label="source image (no webp)", type="filepath", format="jpeg"),
+        gr.Audio(label="source audio", type="filepath"),
+        gr.Number(label="size", value=512, minimum=256, maximum=512, step=64, precision=0),
+        gr.Number(label="steps", value=40, minimum=1, step=1, precision=0),
+        gr.Number(label="fps", value=25, minimum=1, step=1, precision=0),
+        gr.Slider(label="CFG Scale", value=3.5, minimum=0, maximum=10, step=0.01),
+        gr.Number(label="pose weight", value=1.0),
+        gr.Number(label="face weight", value=1.0),
+        gr.Number(label="lip weight", value=1.0),
+        gr.Number(label="face expand ratio", value=1.2),
+    ],
+    outputs=[gr.Video()],
+)
+app.launch()
diff --git a/scripts/inference.py b/scripts/inference.py
index 8bbc5cc..c2ef0bb 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -288,6 +288,7 @@ def inference_process(args: argparse.Namespace):
     generator = torch.manual_seed(42)
 
     for t in range(times):
+        print(f"[{t+1}/{times}]")
         if len(tensor_result) == 0:
             # The first iteration
 
@@ -342,6 +343,7 @@ def inference_process(args: argparse.Namespace):
     output_file = config.output
     # save the result after all iteration
     tensor_to_video(tensor_result, output_file, driving_audio_path)
+    return output_file
 
 
 if __name__ == "__main__":
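
For context on item 1 of the commit message: without an explicit `audio_codec`,
the file MoviePy writes is, as the message puts it, technically corrupt and
plays no sound in common players. Below is a minimal standalone sketch of the
fix, assuming MoviePy 1.x (the same `subclip`/`set_audio` API the patch itself
uses) and hypothetical input paths:

    from moviepy.editor import AudioFileClip, VideoFileClip

    # Hypothetical inputs; substitute real files.
    video = VideoFileClip("face.mp4")
    audio = AudioFileClip("speech.wav").subclip(0, video.duration)

    # Attach the audio track, then encode it as AAC so the resulting
    # .mp4 carries a sound stream that standard players can decode.
    video = video.set_audio(audio)
    video.write_videofile("output.mp4", fps=video.fps, audio_codec="aac")

After applying the patch, the web UI starts with `python scripts/app.py`, as
the module docstring in the new scripts/app.py notes.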