Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: inference with both config file and cli arguments #89

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions configs/inference/default.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
source_image: ./default.png
driving_audio: default.wav
source_image: examples/reference_images/1.jpg
driving_audio: examples/driving_audios/1.wav

weight_dtype: fp16

Expand Down Expand Up @@ -38,10 +38,10 @@ vae:

save_path: ./.cache

face_expand_ratio: 1.1
pose_weight: 1.1
face_weight: 1.1
lip_weight: 1.1
face_expand_ratio: 1.2
pose_weight: 1.0
face_weight: 1.0
lip_weight: 1.0

unet_additional_kwargs:
use_inflated_groupnorm: true
Expand Down
25 changes: 25 additions & 0 deletions hallo/utils/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
This module provides utility functions for configuration manipulation.
"""

from typing import Dict


def filter_non_none(dict_obj: Dict):
    """
    Drop every entry whose value is ``None`` from a dictionary, in place.

    Only ``None`` is treated as "missing"; other falsy values such as ``0``,
    ``""`` or ``False`` are kept. The relative order of the surviving keys is
    unchanged.

    Args:
        dict_obj (Dict): The dictionary to prune. It is modified in place.

    Returns:
        Dict: The very same dictionary object that was passed in, now without
        any ``None``-valued entries.

    Note:
        Because the input is mutated, callers that hand in a live mapping
        (for example the result of ``vars(namespace)``) will see the removed
        keys disappear from that mapping as well.
    """
    # Collect the keys first: deleting while iterating over the dictionary
    # itself would raise a RuntimeError.
    keys_to_drop = [key for key, value in dict_obj.items() if value is None]
    for key in keys_to_drop:
        del dict_obj[key]
    return dict_obj
21 changes: 11 additions & 10 deletions scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from hallo.models.image_proj import ImageProjModel
from hallo.models.unet_2d_condition import UNet2DConditionModel
from hallo.models.unet_3d import UNet3DConditionModel
from hallo.utils.config import filter_non_none
from hallo.utils.util import tensor_to_video


Expand Down Expand Up @@ -125,16 +126,16 @@ def inference_process(args: argparse.Namespace):
modules and variables to prepare for the upcoming inference steps.
"""
# 1. init config
cli_args = filter_non_none(vars(args))
config = OmegaConf.load(args.config)
config = OmegaConf.merge(config, vars(args))
config = OmegaConf.merge(config, cli_args)
source_image_path = config.source_image
driving_audio_path = config.driving_audio
save_path = config.save_path
if not os.path.exists(save_path):
os.makedirs(save_path)
motion_scale = [config.pose_weight, config.face_weight, config.lip_weight]
if args.checkpoint is not None:
config.audio_ckpt_dir = args.checkpoint

# 2. runtime variables
device = torch.device(
"cuda") if torch.cuda.is_available() else torch.device("cpu")
Expand Down Expand Up @@ -353,21 +354,21 @@ def inference_process(args: argparse.Namespace):
parser.add_argument(
"-c", "--config", default="configs/inference/default.yaml")
parser.add_argument("--source_image", type=str, required=False,
help="source image", default="test_data/source_images/6.jpg")
help="source image")
parser.add_argument("--driving_audio", type=str, required=False,
help="driving audio", default="test_data/driving_audios/singing/sing_4.wav")
help="driving audio")
parser.add_argument(
"--output", type=str, help="output video file name", default=".cache/output.mp4")
parser.add_argument(
"--pose_weight", type=float, help="weight of pose", default=1.0)
"--pose_weight", type=float, help="weight of pose", required=False)
parser.add_argument(
"--face_weight", type=float, help="weight of face", default=1.0)
"--face_weight", type=float, help="weight of face", required=False)
parser.add_argument(
"--lip_weight", type=float, help="weight of lip", default=1.0)
"--lip_weight", type=float, help="weight of lip", required=False)
parser.add_argument(
"--face_expand_ratio", type=float, help="face region", default=1.2)
"--face_expand_ratio", type=float, help="face region", required=False)
parser.add_argument(
"--checkpoint", type=str, help="which checkpoint", default=None)
"--audio_ckpt_dir", "--checkpoint", type=str, help="specific checkpoint dir", required=False)


command_line_args = parser.parse_args()
Expand Down
Loading