Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial refiner support #12371

Merged
merged 9 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'dev' into refiner
  • Loading branch information
AUTOMATIC1111 committed Aug 8, 2023
commit 54c3e5c913b17622bed4ff4d03df488b80611e21
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import math

import gradio as gr
from modules import scripts, shared, ui_components, ui_settings
from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
from modules.ui_components import FormColumn


Expand All @@ -19,18 +21,37 @@ def show(self, is_img2img):
def ui(self, is_img2img):
self.comps = []
self.setting_names = []
self.infotext_fields = []

mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}

with gr.Blocks() as interface:
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
for setting_name in shared.opts.extra_options:
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group():

row_count = math.ceil(len(shared.opts.extra_options) / shared.opts.extra_options_cols)

for row in range(row_count):
with gr.Row():
for col in range(shared.opts.extra_options_cols):
index = row * shared.opts.extra_options_cols + col
if index >= len(shared.opts.extra_options):
break

setting_name = shared.opts.extra_options[index]

self.comps.append(comp)
self.setting_names.append(setting_name)
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)

self.comps.append(comp)
self.setting_names.append(setting_name)

setting_infotext_name = mapping.get(setting_name)
if setting_infotext_name is not None:
self.infotext_fields.append((comp, setting_infotext_name))

def get_settings_values():
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
res = [ui_settings.get_value_for_setting(key) for key in self.setting_names]
return res[0] if len(res) == 1 else res

interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)

Expand All @@ -44,5 +65,8 @@ def before_process(self, p, *args):

shared.options_templates.update(shared.options_section(('ui', "User interface"), {
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(),
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion").needs_restart()
"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
}))


5 changes: 5 additions & 0 deletions javascript/imageviewer.js
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ function setupImageForLightbox(e) {
var event = isFirefox ? 'mousedown' : 'click';

e.addEventListener(event, function(evt) {
if (evt.button == 1) {
open(evt.target.src);
evt.preventDefault();
return;
}
if (!opts.js_modal_lightbox || evt.button != 0) return;

modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
Expand Down
5 changes: 5 additions & 0 deletions modules/generation_parameters_copypaste.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,10 +416,15 @@ def paste_func(prompt):
return res

if override_settings_component is not None:
already_handled_fields = {key: 1 for _, key in paste_fields}

def paste_settings(params):
vals = {}

for param_name, setting_name in infotext_to_setting_name_mapping:
if param_name in already_handled_fields:
continue

v = params.get(param_name, None)
if v is None:
continue
Expand Down
6 changes: 3 additions & 3 deletions modules/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
import gradio as gr

from modules import sd_samplers, images as imgutil
from modules import images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
Expand Down Expand Up @@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
process_images(p)


def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)

is_batch = mode == 5
Expand Down Expand Up @@ -172,7 +172,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
Expand Down
29 changes: 26 additions & 3 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,42 @@ def check_run_python(code: str) -> bool:
return result.returncode == 0


def git_fix_workspace(dir, name):
run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
return


def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
try:
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
except RuntimeError:
pass

if not autofix:
return None

print(f"{errdesc}, attempting autofix...")
git_fix_workspace(dir, name)

return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)


def git_clone(url, dir, name, commithash=None):
# TODO clone into temporary dir and move if successful

if os.path.exists(dir):
if commithash is None:
return

current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
if current_hash == commithash:
return

run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")

run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)

return

run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
Expand Down
3 changes: 0 additions & 3 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,9 +1119,6 @@ def save_intermediate(image, index):

img2img_sampler_name = self.hr_sampler_name or self.sampler_name

if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
img2img_sampler_name = 'DDIM'

self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

if self.latent_scale_mode is not None:
Expand Down
4 changes: 1 addition & 3 deletions modules/sd_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
Expand Down Expand Up @@ -34,8 +34,6 @@
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print

sd_hijack_inpainting.do_inpainting_hijack()

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None

Expand Down
95 changes: 0 additions & 95 deletions modules/sd_hijack_inpainting.py

This file was deleted.

3 changes: 2 additions & 1 deletion modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer

sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
sd_vae.load_vae(model, vae_file, vae_source)
timer.record("load VAE")

Expand Down Expand Up @@ -715,6 +715,7 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")

model_data.set_sd_model(sd_model)
sd_unet.apply_unet()

return sd_model

Expand Down
19 changes: 11 additions & 8 deletions modules/sd_samplers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared

# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401

all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
*sd_samplers_compvis.samplers_data_compvis,
*sd_samplers_timesteps.samplers_data_timesteps,
]
all_samplers_map = {x.name: x for x in all_samplers}

samplers = []
samplers_for_img2img = []
samplers_map = {}
samplers_hidden = {}


def find_sampler_config(name):
Expand All @@ -38,13 +39,11 @@ def create_sampler(name, model):


def set_samplers():
global samplers, samplers_for_img2img
global samplers, samplers_for_img2img, samplers_hidden

hidden = set(shared.opts.hide_samplers)
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])

samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
samplers_hidden = set(shared.opts.hide_samplers)
samplers = all_samplers
samplers_for_img2img = all_samplers

samplers_map.clear()
for sampler in all_samplers:
Expand All @@ -53,4 +52,8 @@ def set_samplers():
samplers_map[alias.lower()] = sampler.name


def visible_sampler_names():
return [x.name for x in samplers if x.name not in samplers_hidden]


set_samplers()
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.