Commit 1764ac3c authored by aria1th's avatar aria1th
Browse files

use hash to check valid optim

parent 0b143c11
Loading
Loading
Loading
Loading
+9 −4
Original line number Diff line number Diff line
@@ -177,12 +177,13 @@ class Hypernetwork:
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
        if self.optimizer_name is not None:
            optimizer_saved_dict['optimizer_name'] = self.optimizer_name

        torch.save(state_dict, filename)
        if self.optimizer_state_dict:
            optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
            torch.save(optimizer_saved_dict, filename + '.optim')

        torch.save(state_dict, filename)

    def load(self, filename):
        self.filename = filename
        if self.name is None:
@@ -204,7 +205,10 @@ class Hypernetwork:
        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
        print(f"Optimizer name is {self.optimizer_name}")
        if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
        else:
            self.optimizer_state_dict = None
        if self.optimizer_state_dict:
            print("Loaded existing optimizer from checkpoint")
        else:
@@ -229,7 +233,7 @@ def list_hypernetworks(path):
        name = os.path.splitext(os.path.basename(filename))[0]
        # Prevent a hypothetical "None.pt" from being listed.
        if name != "None":
            res[name] = filename
            res[name + f"({sd_models.model_hash(filename)})"] = filename
    return res


@@ -375,6 +379,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    else:
        hypernetwork_dir = None

    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)