Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sort enhance images #62

Merged
merged 3 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions args_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
args_parser.parser.add_argument("--disable-preset-download", action='store_true',
help="Disables downloading models for presets", default=False)

args_parser.parser.add_argument("--disable-enhance-output-sorting", action='store_true',
help="Disables enhance output sorting for final image gallery.")

args_parser.parser.add_argument("--enable-auto-describe-image", action='store_true',
help="Enables automatic description of uov and enhance image when prompt is empty", default=False)

Expand Down
22 changes: 16 additions & 6 deletions modules/async_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class AsyncTask:
def __init__(self, args):
from modules.flags import Performance, MetadataScheme, ip_list, controlnet_image_count
from modules.flags import Performance, MetadataScheme, ip_list, controlnet_image_count, disabled
from modules.util import get_enabled_loras
from modules.config import default_max_lora_number
import args_manager
Expand Down Expand Up @@ -155,7 +155,9 @@ def __init__(self, args):
enhance_inpaint_erode_or_dilate,
enhance_mask_invert
])

self.should_enhance = self.enhance_checkbox and (self.enhance_uov_method != disabled.casefold() or len(self.enhance_ctrls) > 0)
self.images_to_enhance_count = 0
self.enhance_stats = {}

async_tasks = []

Expand Down Expand Up @@ -1276,8 +1278,8 @@ def callback(step, x0, x, total_steps, y):
int(current_progress + async_task.callback_steps),
f'Sampling step {step + 1}/{total_steps}, image {current_task_id + 1}/{total_count} ...', y)])

show_intermediate_results = len(tasks) > 1 or should_enhance
persist_image = not should_enhance or not async_task.save_final_enhanced_image_only
show_intermediate_results = len(tasks) > 1 or async_task.should_enhance
persist_image = not async_task.should_enhance or not async_task.save_final_enhanced_image_only

for current_task_id, task in enumerate(tasks):
progressbar(async_task, current_progress, f'Preparing task {current_task_id + 1}/{async_task.image_number} ...')
Expand Down Expand Up @@ -1309,7 +1311,7 @@ def callback(step, x0, x, total_steps, y):
execution_time = time.perf_counter() - execution_start_time
print(f'Generating and saving time: {execution_time:.2f} seconds')

if not should_enhance:
if not async_task.should_enhance:
print(f'[Enhance] Skipping, preconditions aren\'t met')
stop_processing(async_task, processing_start_time)
return
Expand All @@ -1325,14 +1327,16 @@ def callback(step, x0, x, total_steps, y):
enhance_uov_before = async_task.enhance_uov_processing_order == flags.enhancement_uov_before
enhance_uov_after = async_task.enhance_uov_processing_order == flags.enhancement_uov_after
total_count = len(images_to_enhance) * active_enhance_tabs
async_task.images_to_enhance_count = len(images_to_enhance)

base_progress = current_progress
current_task_id = -1
done_steps_upscaling = 0
done_steps_inpainting = 0
enhance_steps, _, _, _ = apply_overrides(async_task, async_task.original_steps, height, width)
exception_result = None
for img in images_to_enhance:
for index, img in enumerate(images_to_enhance):
async_task.enhance_stats[index] = 0
enhancement_image_start_time = time.perf_counter()

last_enhance_prompt = async_task.prompt
Expand All @@ -1346,6 +1350,8 @@ def callback(step, x0, x, total_steps, y):
current_task_id, denoising_strength, done_steps_inpainting, done_steps_upscaling, enhance_steps,
async_task.prompt, async_task.negative_prompt, final_scheduler_name, height, img, preparation_steps,
switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner, width, persist_image)
async_task.enhance_stats[index] += 1

if exception_result == 'continue':
continue
elif exception_result == 'break':
Expand Down Expand Up @@ -1389,6 +1395,7 @@ def callback(step, x0, x, total_steps, y):
async_task.yields.append(['preview', (current_progress, 'Loading ...', mask)])
yield_result(async_task, mask, current_progress, async_task.black_out_nsfw, False,
async_task.disable_intermediate_results)
async_task.enhance_stats[index] += 1

print(f'[Enhance] {dino_detection_count} boxes detected')
print(f'[Enhance] {sam_detection_count} segments detected in boxes')
Expand All @@ -1408,6 +1415,7 @@ def callback(step, x0, x, total_steps, y):
enhance_prompt, enhance_negative_prompt, final_scheduler_name, goals_enhance, height, img, mask,
preparation_steps, enhance_steps, switch, tiled, total_count, use_expansion, use_style,
use_synthetic_refiner, width, persist_image=persist_image)
async_task.enhance_stats[index] += 1

if (should_process_enhance_uov and async_task.enhance_uov_processing_order == flags.enhancement_uov_after
and async_task.enhance_uov_prompt_type == flags.enhancement_uov_prompt_type_last_filled):
Expand Down Expand Up @@ -1444,6 +1452,8 @@ def callback(step, x0, x, total_steps, y):
last_enhance_prompt, last_enhance_negative_prompt, final_scheduler_name, height, img,
preparation_steps, switch, tiled, total_count, use_expansion, use_style, use_synthetic_refiner,
width, persist_image)
async_task.enhance_stats[index] += 1

if exception_result == 'continue':
continue
elif exception_result == 'break':
Expand Down
1 change: 1 addition & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,7 @@ entry_with_update.py [-h] [--listen [IP]] [--port PORT]
[--disable-offload-from-vram] [--theme THEME]
[--disable-image-log] [--disable-analytics]
[--disable-metadata] [--disable-preset-download]
[--disable-enhance-output-sorting]
[--enable-auto-describe-image]
[--always-download-new-model]
[--rebuild-hash-cache [CPU_NUM_THREADS]]
Expand Down
22 changes: 22 additions & 0 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def generate_clicked(task: worker.AsyncTask):
gr.update(visible=True, value=product), \
gr.update(visible=False)
if flag == 'finish':
if not args_manager.args.disable_enhance_output_sorting:
product = sort_enhance_images(product, task)

yield gr.update(visible=False), \
gr.update(visible=False), \
gr.update(visible=False), \
Expand All @@ -90,6 +93,25 @@ def generate_clicked(task: worker.AsyncTask):
return


def sort_enhance_images(images, task):
if not task.should_enhance or len(images) <= task.images_to_enhance_count:
return images

sorted_images = []
walk_index = task.images_to_enhance_count

for index, enhanced_img in enumerate(images[:task.images_to_enhance_count]):
sorted_images.append(enhanced_img)
if index not in task.enhance_stats:
continue
target_index = walk_index + task.enhance_stats[index]
if walk_index < len(images) and target_index <= len(images):
sorted_images += images[walk_index:target_index]
walk_index += task.enhance_stats[index]

return sorted_images


def inpaint_mode_change(mode, inpaint_engine_version):
assert mode in modules.flags.inpaint_options

Expand Down