Commit d2c97fc3 authored by flamelaw
Browse files

fix dropout, implement train/eval mode

parent 89d8ecff
Loading
Loading
Loading
Loading
+18 −6
Original line number Diff line number Diff line
@@ -154,16 +154,28 @@ class Hypernetwork:
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
            )
        self.eval_mode()

    def weights(self):
        """Collect the parameters of every layer held in ``self.layers``.

        Returns:
            A flat list built by concatenating ``layer.parameters()`` for each
            layer, following the iteration order of ``self.layers`` and of
            each group of layers inside it.
        """
        return [
            param
            for _, group in self.layers.items()
            for module in group
            for param in module.parameters()
        ]

    def train_mode(self):
        """Switch every layer to training mode and enable gradients.

        Calls ``train()`` on each layer in ``self.layers`` and sets
        ``requires_grad = True`` on each of that layer's parameters.

        Returns:
            None. Use ``weights()`` to obtain the parameter list.
        """
        # NOTE: the former `res += layer.trainables()` / `return res` lines
        # referenced a local that was never initialised, so every call raised
        # UnboundLocalError. Callers only use this method for its side
        # effects, so nothing is accumulated or returned here.
        for _, group in self.layers.items():
            for layer in group:
                layer.train()
                for param in layer.parameters():
                    param.requires_grad = True

    def eval_mode(self):
        """Put every layer into evaluation mode and freeze its parameters.

        Each layer in ``self.layers`` has ``eval()`` called on it, after which
        ``requires_grad`` is set to ``False`` on each of its parameters.
        """
        every_module = (
            member
            for _, group in self.layers.items()
            for member in group
        )
        for module in every_module:
            module.eval()
            for weight in module.parameters():
                weight.requires_grad = False

    def save(self, filename):
        state_dict = {}
@@ -426,8 +438,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
        shared.sd_model.first_stage_model.to(devices.cpu)
    
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True
    hypernetwork.train_mode()

    # Here we use optimizer from saved HN, or we can specify as UI option.
    if hypernetwork.optimizer_name in optimizer_dict:
@@ -538,7 +549,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{hypernetwork_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)

                    hypernetwork.eval_mode()
                    shared.sd_model.cond_stage_model.to(devices.device)
                    shared.sd_model.first_stage_model.to(devices.device)

@@ -571,7 +582,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                    if unload:
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                        shared.sd_model.first_stage_model.to(devices.cpu)

                    hypernetwork.train_mode()
                    if image is not None:
                        shared.state.current_image = image
                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
@@ -593,6 +604,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
    finally:
        pbar.leave = False
        pbar.close()
        hypernetwork.eval_mode()
        #report_statistics(loss_dict)

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')