Commit 0b8acce6 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

separate part of denoiser code into a function to make it easier for extensions to override it

parent 03d7b394
Loading
Loading
Loading
Loading
+11 −6
Original line number Original line Diff line number Diff line
@@ -288,6 +288,16 @@ class CFGDenoiser(torch.nn.Module):
        self.init_latent = None
        self.init_latent = None
        self.step = 0
        self.step = 0


    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
        if state.interrupted or state.skipped:
        if state.interrupted or state.skipped:
            raise InterruptedException
            raise InterruptedException
@@ -329,12 +339,7 @@ class CFGDenoiser(torch.nn.Module):


            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})


        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)


        if self.mask is not None:
        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised
            denoised = self.init_latent * self.mask + self.nmask * denoised