Commit 617c5b48 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

make it possible for StableDiffusionProcessing to accept multiple different...

make it possible for StableDiffusionProcessing to accept multiple different negative prompts in a batch
parent e35d8b49
Loading
Loading
Loading
Loading
+24 −22
Original line number Original line Diff line number Diff line
@@ -124,6 +124,7 @@ class StableDiffusionProcessing():
        self.scripts = None
        self.scripts = None
        self.script_args = None
        self.script_args = None
        self.all_prompts = None
        self.all_prompts = None
        self.all_negative_prompts = None
        self.all_seeds = None
        self.all_seeds = None
        self.all_subseeds = None
        self.all_subseeds = None


@@ -202,7 +203,7 @@ class StableDiffusionProcessing():




class Processed:
class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
        self.images = images_list
        self.images = images_list
        self.prompt = p.prompt
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.negative_prompt = p.negative_prompt
@@ -241,16 +242,18 @@ class Processed:
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
        self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning


        self.all_prompts = all_prompts or [self.prompt]
        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_seeds = all_seeds or [self.seed]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_subseeds = all_subseeds or [self.subseed]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]
        self.infotexts = infotexts or [info]


    def js(self):
    def js(self):
        obj = {
        obj = {
            "prompt": self.prompt,
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "all_prompts": self.all_prompts,
            "negative_prompt": self.negative_prompt,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
            "seed": self.seed,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "subseed": self.subseed,
@@ -411,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration


    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])


    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
    negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if  p.all_negative_prompts[0] else ""


    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()


@@ -440,10 +443,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    else:
    else:
        assert p.prompt is not None
        assert p.prompt is not None


    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
        processed = Processed(p, [], p.seed, "")
        file.write(processed.infotext(p, 0))

    devices.torch_gc()
    devices.torch_gc()


    seed = get_fixed_seed(p.seed)
    seed = get_fixed_seed(p.seed)
@@ -453,15 +452,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    modules.sd_hijack.model_hijack.clear_comments()
    modules.sd_hijack.model_hijack.clear_comments()


    comments = {}
    comments = {}
    prompt_tmp = p.prompt
    negative_prompt_tmp = p.negative_prompt

    shared.prompt_styles.apply_styles(p)


    if type(p.prompt) == list:
    if type(p.prompt) == list:
        p.all_prompts = p.prompt
        p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
    else:
        p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]

    if type(p.negative_prompt) == list:
        p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
    else:
    else:
        p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
        p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]


    if type(seed) == list:
    if type(seed) == list:
        p.all_seeds = seed
        p.all_seeds = seed
@@ -476,6 +476,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    def infotext(iteration=0, position_in_batch=0):
    def infotext(iteration=0, position_in_batch=0):
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)


    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
        processed = Processed(p, [], p.seed, "")
        file.write(processed.infotext(p, 0))

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()
        model_hijack.embedding_db.load_textual_inversion_embeddings()


@@ -500,6 +504,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                break
                break


            prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]


@@ -510,7 +515,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
                p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)


            with devices.autocast():
            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)


            if len(model_hijack.comments) > 0:
            if len(model_hijack.comments) > 0:
@@ -596,14 +601,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:


    devices.torch_gc()
    devices.torch_gc()


    res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
    res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)


    if p.scripts is not None:
    if p.scripts is not None:
        p.scripts.postprocess(p, res)
        p.scripts.postprocess(p, res)


    p.prompt = prompt_tmp
    p.negative_prompt = negative_prompt_tmp

    return res
    return res




+0 −11
Original line number Original line Diff line number Diff line
@@ -65,17 +65,6 @@ class StyleDatabase:
    def apply_negative_styles_to_prompt(self, prompt, styles):
    def apply_negative_styles_to_prompt(self, prompt, styles):
        return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
        return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])


    def apply_styles(self, p: StableDiffusionProcessing) -> None:
        if isinstance(p.prompt, list):
            p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
        else:
            p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)

        if isinstance(p.negative_prompt, list):
            p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
        else:
            p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)

    def save_styles(self, path: str) -> None:
    def save_styles(self, path: str) -> None:
        # Write to temporary file first, so we don't nuke the file if something goes wrong
        # Write to temporary file first, so we don't nuke the file if something goes wrong
        fd, temp_path = tempfile.mkstemp(".csv")
        fd, temp_path = tempfile.mkstemp(".csv")