Unverified Commit 52674143 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #4271 from MarkovInequality/racecond_fix

Fixes #4137, which was caused by a race condition in training when the VAE is unloaded
parents 5cd5a672 c9a2cfdf
Loading
Loading
Loading
Loading
+5 −0
Original line number Original line Diff line number Diff line
@@ -433,7 +433,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,


    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)


    old_parallel_processing_allowed = shared.parallel_processing_allowed

    if unload:
    if unload:
        shared.parallel_processing_allowed = False
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)
    
    
@@ -612,10 +615,12 @@ Last saved image: {html.escape(last_saved_image)}<br/>
    if shared.opts.save_optimizer_state:
    if shared.opts.save_optimizer_state:
        hypernetwork.optimizer_state_dict = optimizer.state_dict()
        hypernetwork.optimizer_state_dict = optimizer.state_dict()
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)

    del optimizer
    del optimizer
    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
    shared.sd_model.cond_stage_model.to(devices.device)
    shared.sd_model.cond_stage_model.to(devices.device)
    shared.sd_model.first_stage_model.to(devices.device)
    shared.sd_model.first_stage_model.to(devices.device)
    shared.parallel_processing_allowed = old_parallel_processing_allowed


    return hypernetwork, filename
    return hypernetwork, filename


+3 −0
Original line number Original line Diff line number Diff line
@@ -269,6 +269,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_


   # dataset loading may take a while, so input validations and early returns should be done before this
   # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed
    
    
    pin_memory = shared.opts.pin_memory
    pin_memory = shared.opts.pin_memory
    
    
@@ -279,6 +280,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)


    if unload:
    if unload:
        shared.parallel_processing_allowed = False
        shared.sd_model.first_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)


    embedding.vec.requires_grad = True
    embedding.vec.requires_grad = True
@@ -450,6 +452,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
        pbar.leave = False
        pbar.leave = False
        pbar.close()
        pbar.close()
        shared.sd_model.first_stage_model.to(devices.device)
        shared.sd_model.first_stage_model.to(devices.device)
        shared.parallel_processing_allowed = old_parallel_processing_allowed


    return embedding, filename
    return embedding, filename