Commit 03e57088 authored by frostydad's avatar frostydad Committed by AUTOMATIC1111
Browse files

Fix incorrect sampler name in output

parent 122d4268
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line

import json
import math
import os
@@ -46,6 +47,12 @@ def apply_color_correction(correction, image):
    return image


def get_correct_sampler(p):
    if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
        return sd_samplers.samplers
    elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
        return sd_samplers.samplers_for_img2img

class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
        self.sd_model = sd_model
@@ -272,7 +279,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration

    generation_params = {
        "Steps": p.steps,
        "Sampler": sd_samplers.samplers[p.sampler_index].name,
        "Sampler": get_correct_sampler(p)[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
+9 −7
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ import modules.scripts as scripts
import gradio as gr

from modules import images
from modules.processing import process_images, Processed
from modules.processing import process_images, Processed, get_correct_sampler
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.sd_samplers
@@ -56,15 +56,17 @@ def apply_order(p, x, xs):
    p.prompt = prompt_tmp + p.prompt
    

def build_samplers_dict(p):
    samplers_dict = {}
for i, sampler in enumerate(modules.sd_samplers.samplers):
    for i, sampler in enumerate(get_correct_sampler(p)):
        samplers_dict[sampler.name.lower()] = i
        for alias in sampler.aliases:
            samplers_dict[alias.lower()] = i
    return samplers_dict


def apply_sampler(p, x, xs):
    sampler_index = samplers_dict.get(x.lower(), None)
    sampler_index = build_samplers_dict(p).get(x.lower(), None)
    if sampler_index is None:
        raise RuntimeError(f"Unknown sampler: {x}")