Commit 10421f93 authored by brkirch's avatar brkirch
Browse files

Fix full previews, --no-half-vae

parent 6cff4401
Loading
Loading
Loading
Loading
+4 −4
Original line number Original line Diff line number Diff line
@@ -172,7 +172,7 @@ class StableDiffusionProcessing:
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)


        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
        conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
        conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
        conditioning = torch.nn.functional.interpolate(
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            self.sd_model.depth_model(midas_in),
@@ -217,7 +217,7 @@ class StableDiffusionProcessing:
        )
        )


        # Encode the new masked image using first stage of network.
        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))


        # Create the concatenated conditioning tensor to be fed to `c_concat`
        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -417,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see


def decode_first_stage(model, x):
def decode_first_stage(model, x):
    with devices.autocast(disable=x.dtype == devices.dtype_vae):
    with devices.autocast(disable=x.dtype == devices.dtype_vae):
        x = model.decode_first_stage(x)
        x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)


    return x
    return x


@@ -1001,7 +1001,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):


        image = torch.from_numpy(batch_images)
        image = torch.from_numpy(batch_images)
        image = 2. * image - 1.
        image = 2. * image - 1.
        image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
        image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)


        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))


+1 −1
Original line number Original line Diff line number Diff line
@@ -5,7 +5,7 @@ class CondFunc:
        self = super(CondFunc, cls).__new__(cls)
        self = super(CondFunc, cls).__new__(cls)
        if isinstance(orig_func, str):
        if isinstance(orig_func, str):
            func_path = orig_func.split('.')
            func_path = orig_func.split('.')
            for i in range(len(func_path)-2, -1, -1):
            for i in range(len(func_path)-1, -1, -1):
                try:
                try:
                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
                    break
                    break