Unverified Commit 5524301a authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #9169 from space-nuko/extension-settings-backup

Extension settings backup/restore feature
parents c018eefe 78d0ee3b
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -33,3 +33,4 @@ notification.mp3
/test/stdout.txt
/test/stderr.txt
/cache.json*
/config_states/
+22 −0
Original line number Diff line number Diff line
@@ -47,3 +47,25 @@ function install_extension_from_index(button, url){

    gradioApp().querySelector('#install_extension_button').click()
}

function config_state_confirm_restore(_, config_state_name, config_restore_type) {
    if (config_state_name == "Current") {
        return [false, config_state_name, config_restore_type];
    }
    let restored = "";
    if (config_restore_type == "extensions") {
        restored = "all saved extension versions";
    } else if (config_restore_type == "webui") {
        restored = "the webui version";
    } else {
        restored = "the webui version and all saved extension versions";
    }
    let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
    if (confirmed) {
        restart_reload();
        gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
            x.innerHTML = "Loading..."
        })
    }
    return [confirmed, config_state_name, config_restore_type];
}
+200 −0
Original line number Diff line number Diff line
"""
Supports saving and restoring webui and extensions from a known working set of commits
"""

import os
import sys
import traceback
import json
import time
import tqdm

from datetime import datetime
from collections import OrderedDict
import git

from modules import shared, extensions
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir


all_config_states = OrderedDict()


def list_config_states():
    global all_config_states

    all_config_states.clear()
    os.makedirs(config_states_dir, exist_ok=True)

    config_states = []
    for filename in os.listdir(config_states_dir):
        if filename.endswith(".json"):
            path = os.path.join(config_states_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                j = json.load(f)
                j["filepath"] = path
                config_states.append(j)

    config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))

    for cs in config_states:
        timestamp = time.asctime(time.gmtime(cs["created_at"]))
        name = cs.get("name", "Config")
        full_name = f"{name}: {timestamp}"
        all_config_states[full_name] = cs

    return all_config_states


def get_webui_config():
    webui_repo = None

    try:
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    webui_remote = None
    webui_commit_hash = None
    webui_commit_date = None
    webui_branch = None
    if webui_repo and not webui_repo.bare:
        try:
            webui_remote = next(webui_repo.remote().urls, None)
            head = webui_repo.head.commit
            webui_commit_date = webui_repo.head.commit.committed_date
            webui_commit_hash = head.hexsha
            webui_branch = webui_repo.active_branch.name

        except Exception:
            webui_remote = None

    return {
        "remote": webui_remote,
        "commit_hash": webui_commit_hash,
        "commit_date": webui_commit_date,
        "branch": webui_branch,
    }


def get_extension_config():
    ext_config = {}

    for ext in extensions.extensions:
        entry = {
            "name": ext.name,
            "path": ext.path,
            "enabled": ext.enabled,
            "is_builtin": ext.is_builtin,
            "remote": ext.remote,
            "commit_hash": ext.commit_hash,
            "commit_date": ext.commit_date,
            "branch": ext.branch,
            "have_info_from_repo": ext.have_info_from_repo
        }

        ext_config[ext.name] = entry

    return ext_config


def get_config():
    creation_time = datetime.now().timestamp()
    webui_config = get_webui_config()
    ext_config = get_extension_config()

    return {
        "created_at": creation_time,
        "webui": webui_config,
        "extensions": ext_config
    }


def restore_webui_config(config):
    print("* Restoring webui state...")

    if "webui" not in config:
        print("Error: No webui data saved to config")
        return

    webui_config = config["webui"]

    if "commit_hash" not in webui_config:
        print("Error: No commit saved to webui config")
        return

    webui_commit_hash = webui_config.get("commit_hash", None)
    webui_repo = None

    try:
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return

    try:
        webui_repo.git.fetch(all=True)
        webui_repo.git.reset(webui_commit_hash, hard=True)
        print(f"* Restored webui to commit {webui_commit_hash}.")
    except Exception:
        print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)


def restore_extension_config(config):
    print("* Restoring extension state...")

    if "extensions" not in config:
        print("Error: No extension data saved to config")
        return

    ext_config = config["extensions"]

    results = []
    disabled = []

    for ext in tqdm.tqdm(extensions.extensions):
        if ext.is_builtin:
            continue

        ext.read_info_from_repo()
        current_commit = ext.commit_hash

        if ext.name not in ext_config:
            ext.disabled = True
            disabled.append(ext.name)
            results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
            continue

        entry = ext_config[ext.name]

        if "commit_hash" in entry and entry["commit_hash"]:
            try:
                ext.fetch_and_reset_hard(entry["commit_hash"])
                ext.read_info_from_repo()
                if current_commit != entry["commit_hash"]:
                    results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
            except Exception as ex:
                results.append((ext, current_commit[:8], False, ex))
        else:
            results.append((ext, current_commit[:8], False, "No commit hash found in config"))

        if not entry.get("enabled", False):
            ext.disabled = True
            disabled.append(ext.name)
        else:
            ext.disabled = False

    shared.opts.disabled_extensions = disabled
    shared.opts.save(shared.config_filename)

    print("* Finished restoring extensions. Results:")
    for ext, prev_commit, success, result in results:
        if success:
            print(f"  + {ext.name}: {prev_commit} -> {result}")
        else:
            print(f"  ! {ext.name}: FAILURE ({result})")
+30 −9
Original line number Diff line number Diff line
@@ -3,10 +3,11 @@ import sys
import traceback

import time
from datetime import datetime
import git

from modules import shared
from modules.paths_internal import extensions_dir, extensions_builtin_dir
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path

extensions = []

@@ -31,12 +32,15 @@ class Extension:
        self.status = ''
        self.can_update = False
        self.is_builtin = is_builtin
        self.commit_hash = ''
        self.commit_date = None
        self.version = ''
        self.branch = None
        self.remote = None
        self.have_info_from_repo = False

    def read_info_from_repo(self):
        if self.have_info_from_repo:
        if self.is_builtin or self.have_info_from_repo:
            return

        self.have_info_from_repo = True
@@ -56,10 +60,15 @@ class Extension:
                self.status = 'unknown'
                self.remote = next(repo.remote().urls, None)
                head = repo.head.commit
                ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
                self.version = f'{head.hexsha[:8]} ({ts})'

            except Exception:
                self.commit_date = repo.head.commit.committed_date
                ts = time.asctime(time.gmtime(self.commit_date))
                if repo.active_branch:
                    self.branch = repo.active_branch.name
                self.commit_hash = head.hexsha
                self.version = f'{self.commit_hash[:8]} ({ts})'

            except Exception as ex:
                print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
                self.remote = None

    def list_files(self, subdir, extension):
@@ -82,18 +91,30 @@ class Extension:
        for fetch in repo.remote().fetch(dry_run=True):
            if fetch.flags != fetch.HEAD_UPTODATE:
                self.can_update = True
                self.status = "behind"
                self.status = "new commits"
                return

        try:
            origin = repo.rev_parse('origin')
            if repo.head.commit != origin:
                self.can_update = True
                self.status = "behind HEAD"
                return
        except Exception:
            self.can_update = False
            self.status = "unknown (remote error)"
            return

        self.can_update = False
        self.status = "latest"

    def fetch_and_reset_hard(self):
    def fetch_and_reset_hard(self, commit='origin'):
        repo = git.Repo(self.path)
        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
        repo.git.fetch(all=True)
        repo.git.reset('origin', hard=True)
        repo.git.reset(commit, hard=True)
        self.have_info_from_repo = False


def list_extensions():
+1 −0
Original line number Diff line number Diff line
@@ -20,3 +20,4 @@ data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states")
Loading