Commit b5050ad2 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

make SD2 compatible with --medvram setting

parent 64c7b797
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -51,6 +51,10 @@ def setup_for_low_vram(sd_model, use_medvram):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_decode(z)

    # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
    if hasattr(sd_model.cond_stage_model, 'model'):
        sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model

    # remove three big modules, cond, first_stage, and unet from the model and then
    # send the model to GPU. Then put modules back. the modules will be in CPU.
    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
@@ -65,6 +69,10 @@ def setup_for_low_vram(sd_model, use_medvram):
    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
    parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    if hasattr(sd_model.cond_stage_model, 'model'):
        sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
        del sd_model.cond_stage_model.transformer

    if use_medvram:
        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
    else: