Unverified Commit 452ab8fe authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #13718 from avantcontra/bugfix_gfpgan_custom_path

fix bug when using --gfpgan-models-path
parents 399baa54 236dd55d
Loading
Loading
Loading
Loading
+20 −5
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ from modules import paths, shared, devices, modelloader, errors
model_dir = "GFPGAN"
user_path = None
model_path = os.path.join(paths.models_path, model_dir)
model_file_path = None
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None
@@ -17,6 +18,7 @@ loaded_gfpgan_model = None
def gfpgann():
    global loaded_gfpgan_model
    global model_path
    global model_file_path
    if loaded_gfpgan_model is not None:
        loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
        return loaded_gfpgan_model
@@ -24,17 +26,24 @@ def gfpgann():
    if gfpgan_constructor is None:
        return None

    models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
    models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])

    if len(models) == 1 and models[0].startswith("http"):
        model_file = models[0]
    elif len(models) != 0:
        latest_file = max(models, key=os.path.getctime)
        gfp_models = []
        for item in models:
            if 'GFPGAN' in os.path.basename(item):
                gfp_models.append(item)
        latest_file = max(gfp_models, key=os.path.getctime)
        model_file = latest_file
    else:
        print("Unable to load gfpgan model!")
        return None

    if hasattr(facexlib.detection.retinaface, 'device'):
        facexlib.detection.retinaface.device = devices.device_gfpgan
    model_file_path = model_file
    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
    loaded_gfpgan_model = model

@@ -77,19 +86,25 @@ def setup_model(dirname):
        global user_path
        global have_gfpgan
        global gfpgan_constructor
        global model_file_path

        facexlib_path = model_path

        if dirname is not None:
            facexlib_path = dirname

        load_file_from_url_orig = gfpgan.utils.load_file_from_url
        facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
        facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url

        def my_load_file_from_url(**kwargs):
            return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
            return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))

        def facex_load_file_from_url(**kwargs):
            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))

        def facex_load_file_from_url2(**kwargs):
            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))

        gfpgan.utils.load_file_from_url = my_load_file_from_url
        facexlib.detection.load_file_from_url = facex_load_file_from_url