Commit e0e80050 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

make StableDiffusionProcessing class not hold a reference to shared.sd_model object

parent 9991967f
Loading
Loading
Loading
Loading
+5 −4
Original line number Diff line number Diff line
@@ -94,7 +94,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
    return image_conditioning


class StableDiffusionProcessing():
class StableDiffusionProcessing:
    """
    The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
    """
@@ -102,7 +102,6 @@ class StableDiffusionProcessing():
        if sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
@@ -156,6 +155,10 @@ class StableDiffusionProcessing():
        self.all_subseeds = None
        self.iteration = 0

    @property
    def sd_model(self):
        return shared.sd_model

    def txt2img_image_conditioning(self, x, width=None, height=None):
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

@@ -236,7 +239,6 @@ class StableDiffusionProcessing():
        raise NotImplementedError()

    def close(self):
        self.sd_model = None
        self.sampler = None


@@ -471,7 +473,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:

            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()  # make onchange call for changing SD model
                p.sd_model = shared.sd_model

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()  # make onchange call for changing VAE
+0 −1
Original line number Diff line number Diff line
@@ -86,7 +86,6 @@ def apply_checkpoint(p, x, xs):
    if info is None:
        raise RuntimeError(f"Unknown checkpoint: {x}")
    modules.sd_models.reload_model_weights(shared.sd_model, info)
    p.sd_model = shared.sd_model


def confirm_checkpoints(p, xs):