Commit 4b0dc206 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

use modelloader for #4956

parent 2a649154
Loading
Loading
Loading
Loading
+8 −14
Original line number Diff line number Diff line
import contextlib
import os
import sys
import traceback
@@ -11,12 +10,9 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import modules.shared as shared
from modules import devices, paths, lowvram
from modules import devices, paths, lowvram, modelloader

blip_image_eval_size = 384
blip_local_dir = os.path.join('models', 'Interrogator')
blip_local_file = os.path.join(blip_local_dir, 'model_base_caption_capfilt_large.pth')
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'

Category = namedtuple("Category", ["name", "topn", "items"])
@@ -49,16 +45,14 @@ class InterrogateModels:
    def load_blip_model(self):
        import models.blip

        if not os.path.isfile(blip_local_file):
            if not os.path.isdir(blip_local_dir):
                os.mkdir(blip_local_dir)
        files = modelloader.load_models(
            model_path=os.path.join(paths.models_path, "BLIP"),
            model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
            ext_filter=[".pth"],
            download_name='model_base_caption_capfilt_large.pth',
        )

            print("Downloading BLIP...")
            from requests import get as reqget
            open(blip_local_file, 'wb').write(reqget(blip_model_url, allow_redirects=True).content)
            print("BLIP downloaded to", blip_local_file + '.')

        blip_model = models.blip.blip_decoder(pretrained=blip_local_file, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model.eval()

        return blip_model