Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial refiner support #12371

Merged
merged 9 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ def setup_conds(self):
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)

def get_conds(self):
return self.c, self.uc

def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)

Expand Down Expand Up @@ -611,6 +614,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

try:
# after running refiner, the refiner model is not unloaded - webui swaps back to main model here
if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
sd_models.reload_model_weights()

# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
Expand Down Expand Up @@ -710,6 +717,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break

sd_models.reload_model_weights() # model can be changed for example by refiner

p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
Expand Down Expand Up @@ -1201,6 +1210,13 @@ def setup_conds(self):
with devices.autocast():
extra_networks.activate(self, self.extra_network_data)

def get_conds(self):
if self.is_hr_pass:
return self.hr_c, self.hr_uc

return super().get_conds()


def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()

Expand Down
25 changes: 22 additions & 3 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
return res


class SkipWritingToConfig:
"""This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""

skip = False
previous = None

def __enter__(self):
self.previous = SkipWritingToConfig.skip
SkipWritingToConfig.skip = True
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
SkipWritingToConfig.skip = self.previous


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")

shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
Expand Down Expand Up @@ -624,8 +640,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
timer.record("send model to device")

model_data.set_sd_model(already_loaded)
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256

if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256

print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
Expand Down
23 changes: 21 additions & 2 deletions modules/sd_samplers_cfg_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,24 @@ class CFGDenoiser(torch.nn.Module):
negative prompt.
"""

def __init__(self, model, sampler):
def __init__(self, sampler):
super().__init__()
self.inner_model = model
self.model_wrap = None
self.mask = None
self.nmask = None
self.init_latent = None
self.steps = None
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
self.sampler = sampler
self.model_wrap = None
self.p = None

@property
def inner_model(self):
raise NotImplementedError()


def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
Expand All @@ -68,10 +76,21 @@ def combine_denoised_for_edit_model(self, x_out, cond_scale):
def get_pred_x0(self, x_in, x_out, sigma):
return x_out

def update_inner_model(self):
self.model_wrap = None

c, uc = self.p.get_conds()
self.sampler.sampler_extra_args['cond'] = c
self.sampler.sampler_extra_args['uncond'] = uc

def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException

if sd_samplers_common.apply_refiner(self):
cond = self.sampler.sampler_extra_args['cond']
uncond = self.sampler.sampler_extra_args['uncond']

# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
# so is_edit_model is set to False to support AND composition.
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
Expand Down
37 changes: 35 additions & 2 deletions modules/sd_samplers_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
import k_diffusion.sampling

Expand Down Expand Up @@ -131,6 +131,35 @@ def torchsde_randn(size, dtype, device, seed):
replace_torchsde_browinan()


def apply_refiner(sampler):
completed_ratio = sampler.step / sampler.steps

if completed_ratio <= shared.opts.sd_refiner_switch_at:
return False

if shared.opts.sd_refiner_checkpoint == "None":
return False

if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
return False

refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
if refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')

sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at

with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info)

devices.torch_gc()
sampler.p.setup_conds()
sampler.update_inner_model()

return True


class TorchHijack:
"""This is here to replace torch.randn_like of k-diffusion.

Expand Down Expand Up @@ -176,8 +205,9 @@ def __init__(self, funcname):

self.conditioning_key = shared.sd_model.model.conditioning_key

self.model_wrap = None
self.p = None
self.model_wrap_cfg = None
self.sampler_extra_args = None

def callback_state(self, d):
step = d['i']
Expand All @@ -189,6 +219,7 @@ def callback_state(self, d):
shared.total_tqdm.update()

def launch_sampling(self, steps, func):
self.model_wrap_cfg.steps = steps
state.sampling_steps = steps
state.sampling_step = 0

Expand All @@ -208,6 +239,8 @@ def number_of_needed_noises(self, p):
return p.steps

def initialize(self, p) -> dict:
self.p = p
self.model_wrap_cfg.p = p
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
Expand Down
Empty file.
29 changes: 18 additions & 11 deletions modules/sd_samplers_kdiffusion.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import torch
import inspect
import k_diffusion.sampling
from modules import sd_samplers_common, sd_samplers_extra
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser

from modules.shared import opts
import modules.shared as shared
Expand Down Expand Up @@ -53,17 +52,24 @@
}


class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
@property
def inner_model(self):
if self.model_wrap is None:
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)

return self.model_wrap


class KDiffusionSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model):

super().__init__(funcname)

self.extra_params = sampler_extra_params.get(funcname, [])
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)

denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.model_wrap_cfg = CFGDenoiser(self.model_wrap, self)
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
self.model_wrap = self.model_wrap_cfg.inner_model

def get_sigmas(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
Expand Down Expand Up @@ -144,15 +150,15 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,

self.model_wrap_cfg.init_latent = x
self.last_latent = x
extra_args = {
self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}

samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
Expand Down Expand Up @@ -184,13 +190,14 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
extra_params_kwargs['noise_sampler'] = noise_sampler

self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
}
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
Expand Down
27 changes: 17 additions & 10 deletions modules/sd_samplers_timesteps.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def forward(self, input, timesteps, **kwargs):

class CFGDenoiserTimesteps(CFGDenoiser):

def __init__(self, model, sampler):
super().__init__(model, sampler)
def __init__(self, sampler):
super().__init__(sampler)

self.alphas = model.inner_model.alphas_cumprod
self.alphas = shared.sd_model.alphas_cumprod

def get_pred_x0(self, x_in, x_out, sigma):
ts = int(sigma.item())
Expand All @@ -61,6 +61,14 @@ def get_pred_x0(self, x_in, x_out, sigma):

return pred_x0

@property
def inner_model(self):
if self.model_wrap is None:
denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
self.model_wrap = denoiser(shared.sd_model)

return self.model_wrap


class CompVisSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model):
Expand All @@ -69,9 +77,7 @@ def __init__(self, funcname, sd_model):
self.eta_option_field = 'eta_ddim'
self.eta_infotext_field = 'Eta DDIM'

denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser
self.model_wrap = denoiser(sd_model)
self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self)
self.model_wrap_cfg = CFGDenoiserTimesteps(self)

def get_timesteps(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
Expand Down Expand Up @@ -107,15 +113,15 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,

self.model_wrap_cfg.init_latent = x
self.last_latent = x
extra_args = {
self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}

samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
Expand All @@ -133,13 +139,14 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
extra_params_kwargs['timesteps'] = timesteps

self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
}
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
Expand Down
2 changes: 2 additions & 0 deletions modules/shared_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
"sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"),
"sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
}))

options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
Expand Down