Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial refiner support #12371

Merged
merged 9 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
initial refiner support
  • Loading branch information
AUTOMATIC1111 committed Aug 6, 2023
commit f1975b0213f5be400889ec04b3891d1cb571fe20
4 changes: 4 additions & 0 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

try:
# after running refiner, the refiner model is not unloaded - webui swaps back to main model here
if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
sd_models.reload_model_weights()

# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
Expand Down
18 changes: 17 additions & 1 deletion modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
return res


class SkipWritingToConfig:
"""This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""

skip = False
previous = None

def __enter__(self):
self.previous = SkipWritingToConfig.skip
SkipWritingToConfig.skip = True
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
SkipWritingToConfig.skip = self.previous


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")

shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
Expand Down
19 changes: 18 additions & 1 deletion modules/sd_samplers_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state

SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
Expand Down Expand Up @@ -127,3 +127,20 @@ def torchsde_randn(size, dtype, device, seed):


replace_torchsde_browinan()


def apply_refiner(sampler):
completed_ratio = sampler.step / sampler.steps
if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint:
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
if refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')

with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info)

devices.torch_gc()

sampler.update_inner_model()

sampler.p.setup_conds()
12 changes: 11 additions & 1 deletion modules/sd_samplers_compvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
self.p = None
self.sampler = constructor(shared.sd_model)
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
Expand All @@ -32,6 +33,7 @@ def __init__(self, constructor, sd_model):
self.nmask = None
self.init_latent = None
self.sampler_noises = None
self.steps = None
self.step = 0
self.stop_at = None
self.eta = None
Expand All @@ -44,6 +46,7 @@ def number_of_needed_noises(self, p):
return 0

def launch_sampling(self, steps, func):
self.steps = steps
state.sampling_steps = steps
state.sampling_step = 0

Expand All @@ -61,10 +64,15 @@ def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args,

return res

def update_inner_model(self):
self.sampler.model = shared.sd_model

def before_sample(self, x, ts, cond, unconditional_conditioning):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException

sd_samplers_common.apply_refiner(self)

if self.stop_at is not None and self.step > self.stop_at:
raise sd_samplers_common.InterruptedException

Expand Down Expand Up @@ -134,6 +142,8 @@ def unipc_after_update(self, x, model_x):
self.update_step(x)

def initialize(self, p):
self.p = p

if self.is_ddim:
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
else:
Expand Down
30 changes: 24 additions & 6 deletions modules/sd_samplers_kdiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import inspect
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra, sd_models

from modules.processing import StableDiffusionProcessing
from modules.shared import opts, state
Expand Down Expand Up @@ -87,15 +87,25 @@ class CFGDenoiser(torch.nn.Module):
negative prompt.
"""

def __init__(self, model):
def __init__(self):
super().__init__()
self.inner_model = model
self.model_wrap = None
self.mask = None
self.nmask = None
self.init_latent = None
self.steps = None
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
self.p = None

@property
def inner_model(self):
if self.model_wrap is None:
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)

return self.model_wrap

def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
Expand All @@ -113,10 +123,15 @@ def combine_denoised_for_edit_model(self, x_out, cond_scale):

return denoised

def update_inner_model(self):
self.model_wrap = None

def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException

sd_samplers_common.apply_refiner(self)

# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
# so is_edit_model is set to False to support AND composition.
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
Expand Down Expand Up @@ -267,13 +282,13 @@ def randn_like(self, x):

class KDiffusionSampler:
def __init__(self, funcname, sd_model):
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser

self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.p = None
self.funcname = funcname
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.model_wrap_cfg = CFGDenoiser()
self.model_wrap = self.model_wrap_cfg.inner_model
self.sampler_noises = None
self.stop_at = None
self.eta = None
Expand Down Expand Up @@ -305,6 +320,7 @@ def callback_state(self, d):
shared.total_tqdm.update()

def launch_sampling(self, steps, func):
self.model_wrap_cfg.steps = steps
state.sampling_steps = steps
state.sampling_step = 0

Expand All @@ -324,6 +340,8 @@ def number_of_needed_noises(self, p):
return p.steps

def initialize(self, p: StableDiffusionProcessing):
self.p = p
self.model_wrap_cfg.p = p
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
Expand Down
2 changes: 2 additions & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,8 @@ def list_samplers():
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"),
"sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
}))

options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
Expand Down
Loading