Commit 9a3f35b0 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

repair medvram and lowvram

parent abb948da
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -100,6 +100,8 @@ def setup_for_low_vram(sd_model, use_medvram):
        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
    if sd_model.embedder:
        sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)

    if hasattr(sd_model, 'cond_stage_model'):
        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    if use_medvram:
+2 −2
Original line number Diff line number Diff line
@@ -32,7 +32,7 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
    def encode_embedding_init_text(self, init_text, nvpt):
        ids = tokenizer.encode(init_text)
        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)

        return embedded

@@ -66,6 +66,6 @@ class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWi
    def encode_embedding_init_text(self, init_text, nvpt):
        ids = tokenizer.encode(init_text)
        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)

        return embedded