Skip to content

Commit

Permalink
🎨 Multi-Inputs more user friendly UI (Mikubill#2533)
Browse files Browse the repository at this point in the history
* format

* 🎨 multi-inputs UI refresh

* nit
  • Loading branch information
huchenlei committed Jan 21, 2024
1 parent 7660993 commit 81b5dde
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 29 deletions.
2 changes: 1 addition & 1 deletion internal_controlnet/external_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class ControlNetUnit:
module: str = "none"
model: str = "None"
weight: float = 1.0
image: Optional[InputImage] = None
image: Optional[Union[InputImage, List[InputImage]]] = None
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
low_vram: bool = False
processor_res: int = -1
Expand Down
3 changes: 1 addition & 2 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,7 @@ def get_element(obj, strict=False):
return None

attribute_value = get_element(getattr(p, attribute, None), strict)
default_value = get_element(default)
return attribute_value if attribute_value is not None else default_value
return attribute_value if attribute_value is not None else default

@staticmethod
def parse_remote_call(p, unit: external_code.ControlNetUnit, idx):
Expand Down
111 changes: 85 additions & 26 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import gradio as gr
import functools
from copy import copy
from typing import List, Optional, Union, Callable, Dict, Tuple
from typing import List, Optional, Union, Callable, Dict, Tuple, Literal
from dataclasses import dataclass
import numpy as np

from scripts.utils import svg_preprocess, read_image_dir
from scripts.utils import svg_preprocess, read_image
from scripts import (
global_state,
external_code,
Expand Down Expand Up @@ -83,8 +83,11 @@ def ui_initialized(self) -> bool:
"img2img_inpaint_sketch_tab": "img2img_inpaint_sketch_tab",
"img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab",
}
return all(c for name, c in vars(self).items()
if name not in optional_components.values())
return all(
c
for name, c in vars(self).items()
if name not in optional_components.values()
)

def set_component(self, component: gr.components.IOComponent):
id_mapping = {
Expand All @@ -111,7 +114,9 @@ def set_component(self, component: gr.components.IOComponent):
if elem_id in id_mapping:
setattr(self, id_mapping[elem_id], component)
logger.debug(f"Setting {elem_id}.")
logger.debug(f"A1111 initialized {sum(c is not None for c in vars(self).values())}/{len(vars(self).keys())}.")
logger.debug(
f"A1111 initialized {sum(c is not None for c in vars(self).values())}/{len(vars(self).keys())}."
)


class UiControlNetUnit(external_code.ControlNetUnit):
Expand All @@ -122,8 +127,10 @@ def __init__(
input_mode: InputMode = InputMode.SIMPLE,
batch_images: Optional[Union[str, List[external_code.InputImage]]] = None,
output_dir: str = "",
merge_image_dir: Optional[str] = None,
loopback: bool = False,
merge_gallery_files: List[
Dict[Union[Literal["name"], Literal["data"]], str]
] = [],
use_preview_as_input: bool = False,
generated_image: Optional[np.ndarray] = None,
mask_image: Optional[np.ndarray] = None,
Expand All @@ -146,28 +153,40 @@ def __init__(
assert isinstance(input_image, dict)
input_image["mask"] = mask_image

if merge_gallery_files and input_mode == InputMode.MERGE:
input_image = [
{"image": read_image(file["name"])} for file in merge_gallery_files
]

super().__init__(enabled, module, model, weight, input_image, *args, **kwargs)
self.is_ui = True
self.input_mode = input_mode
self.batch_images = batch_images
self.output_dir = output_dir
self.merge_image_dir = merge_image_dir
self.loopback = loopback

def accepts_multiple_inputs(self) -> bool:
"""This unit can accept multiple input images."""
return self.module == "ip-adapter_face_id"

def unfold_merged(self) -> List[external_code.ControlNetUnit]:
"""Unfolds a merged unit to multiple units."""
"""Unfolds a merged unit to multiple units. Keeps the unit merged for
preprocessors that can accept multiple input images.
"""
if self.input_mode != InputMode.MERGE:
return [copy(self)]

assert self.merge_image_dir
# if self.accepts_multiple_inputs():
# self.input_mode = InputMode.SIMPLE
# return [copy(self)]

assert isinstance(self.image, list)
result = []
for image in read_image_dir(self.merge_image_dir):
for image in self.image:
unit = copy(self)
unit.image = image
unit.image = image["image"]
unit.input_mode = InputMode.SIMPLE
result.append(unit)
if not result:
logger.warn(f"No image detected in '{self.merge_image_dir}'.")
return result


Expand Down Expand Up @@ -217,6 +236,10 @@ def __init__(
self.webcam_mirrored = False

# Note: All gradio elements declared in `render` will be defined as member variable.
# Update counter to trigger a force update of UiControlNetUnit.
# This is useful when a field with no event subscriber available changes.
# e.g. gr.Gallery, gr.State, etc.
self.update_unit_counter = None
self.upload_tab = None
self.image = None
self.generated_image_group = None
Expand All @@ -226,7 +249,9 @@ def __init__(
self.batch_tab = None
self.batch_image_dir = None
self.merge_tab = None
self.merge_image_dir = None
self.merge_gallery = None
self.merge_upload_button = None
self.merge_clear_button = None
self.create_canvas = None
self.canvas_width = None
self.canvas_height = None
Expand Down Expand Up @@ -289,6 +314,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
Returns:
None
"""
self.update_unit_counter = gr.Number(value=0, visible=False)
self.openpose_editor = OpenposeEditor()

with gr.Group(visible=not self.is_img2img) as self.image_upload_panel:
Expand Down Expand Up @@ -361,11 +387,16 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
)

with gr.Tab(label="Multi-Inputs") as self.merge_tab:
self.merge_image_dir = gr.Textbox(
label="Input Directory",
placeholder="All images in the input directory will be taken as inputs to this unit",
elem_id=f"{elem_id_tabname}_{tabname}_merge_image_dir",
self.merge_gallery = gr.Gallery(
columns=[4], rows=[2], object_fit="contain", height="auto"
)
with gr.Row():
self.merge_upload_button = gr.UploadButton(
"Upload Images",
file_types=["image"],
file_count="multiple",
)
self.merge_clear_button = gr.Button("Clear Images")

if self.photopea:
self.photopea.attach_photopea_output(self.generated_image)
Expand Down Expand Up @@ -623,11 +654,11 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
self.input_mode,
self.batch_image_dir_state,
self.output_dir_state,
self.merge_image_dir,
self.loopback,
# Non-persistent fields.
# Following inputs will not be persistent on `ControlNetUnit`.
# They are only used during object construction.
self.merge_gallery,
self.use_preview_as_input,
self.generated_image,
self.mask_image,
Expand All @@ -651,7 +682,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
)

unit = gr.State(self.default_unit)
for comp in unit_args:
for comp in unit_args + (self.update_unit_counter,):
event_subscribers = []
if hasattr(comp, "edit"):
event_subscribers.append(comp.edit)
Expand Down Expand Up @@ -889,13 +920,15 @@ def filter_selected(k: str):
)

def sd_version_changed(type_filter: str, current_model: str):
""" When SD version changes, update model dropdown choices. """
"""When SD version changes, update model dropdown choices."""
(
filtered_preprocessor_list,
filtered_model_list,
default_option,
default_model,
) = global_state.select_control_type(type_filter, global_state.get_sd_version())
) = global_state.select_control_type(
type_filter, global_state.get_sd_version()
)

if current_model in filtered_model_list:
return gr.update()
Expand Down Expand Up @@ -1236,6 +1269,34 @@ def clear_preview(x):
outputs=[self.use_preview_as_input, self.generated_image],
)

def register_multi_images_upload(self):
"""Register callbacks on merge tab multiple images upload."""
self.merge_clear_button.click(
fn=lambda: [],
inputs=[],
outputs=[self.merge_gallery],
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.update_unit_counter],
outputs=[self.update_unit_counter],
)

def upload_file(files, current_files):
return {file_d["name"] for file_d in current_files} | {
file.name for file in files
}

self.merge_upload_button.upload(
upload_file,
inputs=[self.merge_upload_button, self.merge_gallery],
outputs=[self.merge_gallery],
queue=False,
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.update_unit_counter],
outputs=[self.update_unit_counter],
)

def register_callbacks(self):
"""Register callbacks on the UI elements."""
# Prevent infinite recursion.
Expand All @@ -1254,6 +1315,7 @@ def register_callbacks(self):
self.register_create_canvas()
self.register_sync_batch_dir()
self.register_clear_preview()
self.register_multi_images_upload()
self.openpose_editor.register_callbacks(
self.generated_image,
self.use_preview_as_input,
Expand Down Expand Up @@ -1310,10 +1372,7 @@ def register_input_mode_sync(ui_groups: List["ControlNetUiGroup"]):
fn=lambda *mode_values: (
(
gr.update(
visible=any(
m == InputMode.BATCH
for m in mode_values
)
visible=any(m == InputMode.BATCH for m in mode_values)
),
)
* len(ui_groups)
Expand Down

0 comments on commit 81b5dde

Please sign in to comment.