Commit adb6cb76 authored by Billy Cao's avatar Billy Cao
Browse files

Patch UNet Forward to support resolutions that are not multiples of 64

Also modifed the UI to no longer step in 64
parent 828438b4
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.diffusionmodules.openaimodel

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
@@ -26,6 +27,7 @@ def apply_optimizations():
    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
        print("Applying xformers cross attention optimization.")
+31 −0
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import importlib

import torch
from torch import einsum
import torch.nn.functional as F

from ldm.util import default
from einops import rearrange
@@ -12,6 +13,8 @@ from einops import rearrange
from modules import shared
from modules.hypernetworks import hypernetwork

from ldm.modules.diffusionmodules.util import timestep_embedding


if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
@@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x):
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)

def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"
    hs = []
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)

    if self.num_classes is not None:
        assert y.shape == (x.shape[0],)
        emb = emb + self.label_emb(y)

    h = x.type(self.dtype)
    for module in self.input_blocks:
        h = module(h, emb, context)
        hs.append(h)
    h = self.middle_block(h, emb, context)
    for module in self.output_blocks:
        if h.shape[-2:] != hs[-1].shape[-2:]:
            h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(h, emb, context)
    h = h.type(x.dtype)
    if self.predict_codebook_ids:
        return self.id_predictor(h)
    else:
        return self.out(h)
+12 −12
Original line number Diff line number Diff line
@@ -380,8 +380,8 @@ def create_seed_inputs():

    with gr.Row(visible=False) as seed_extra_row_2:
        seed_extras.append(seed_extra_row_2)
        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0)
        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0)
        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=1, label="Resize seed from width", value=0)
        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=1, label="Resize seed from height", value=0)

    random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
    random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@@ -715,8 +715,8 @@ def create_ui(wrap_gradio_gpu_call):
                sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")

                with gr.Group():
                    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
                    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                    width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
                    height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)

                with gr.Row():
                    restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@@ -724,8 +724,8 @@ def create_ui(wrap_gradio_gpu_call):
                    enable_hr = gr.Checkbox(label='Highres. fix', value=False)

                with gr.Row(visible=False) as hr_options:
                    firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
                    firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
                    firstphase_width = gr.Slider(minimum=0, maximum=1024, step=1, label="Firstpass width", value=0)
                    firstphase_height = gr.Slider(minimum=0, maximum=1024, step=1, label="Firstpass height", value=0)
                    denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)

                with gr.Row(equal_height=True):
@@ -901,8 +901,8 @@ def create_ui(wrap_gradio_gpu_call):
                sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")

                with gr.Group():
                    width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
                    height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
                    width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512, elem_id="img2img_width")
                    height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512, elem_id="img2img_height")

                with gr.Row():
                    restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@@ -1231,8 +1231,8 @@ def create_ui(wrap_gradio_gpu_call):
                with gr.Tab(label="Preprocess images"):
                    process_src = gr.Textbox(label='Source directory')
                    process_dst = gr.Textbox(label='Destination directory')
                    process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
                    process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                    process_width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
                    process_height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
                    preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])

                    with gr.Row():
@@ -1289,8 +1289,8 @@ def create_ui(wrap_gradio_gpu_call):
                    dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
                    log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
                    template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
                    training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
                    training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
                    training_width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
                    training_height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
                    steps = gr.Number(label='Max steps', value=100000, precision=0)
                    create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
                    save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)