Commit b7f0e815 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

fix error that causes some extra networks to be disabled if both <lora:> and...

fix error that causes some extra networks to be disabled if both <lora:> and <lyco:> are present in the prompt
parent 72ee347e
Loading
Loading
Loading
Loading
+42 −16
Original line number Diff line number Diff line
import json
import os
import re
import logging
from collections import defaultdict

from modules import errors
@@ -86,27 +87,55 @@ class ExtraNetwork:
        raise NotImplementedError


def activate(p, extra_network_data):
    """call activate for extra networks in extra_network_data in specified order, then call
    activate for all remaining registered networks with an empty argument list"""
def lookup_extra_networks(extra_network_data):
    """returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.

    activated = []
    Example input:
    {
        'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],
        'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
        'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
    }

    Example output:

    for extra_network_name, extra_network_args in extra_network_data.items():
    {
        <extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
        <modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
    }
    """

    res = {}

    for extra_network_name, extra_network_args in list(extra_network_data.items()):
        extra_network = extra_network_registry.get(extra_network_name, None)
        alias = extra_network_aliases.get(extra_network_name, None)

        if extra_network is None:
            extra_network = extra_network_aliases.get(extra_network_name, None)
        if alias is not None and extra_network is None:
            extra_network = alias

        if extra_network is None:
            print(f"Skipping unknown extra network: {extra_network_name}")
            logging.info(f"Skipping unknown extra network: {extra_network_name}")
            continue

        res.setdefault(extra_network, []).extend(extra_network_args)

    return res


def activate(p, extra_network_data):
    """call activate for extra networks in extra_network_data in specified order, then call
    activate for all remaining registered networks with an empty argument list"""

    activated = []

    for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items():

        try:
            extra_network.activate(p, extra_network_args)
            activated.append(extra_network)
        except Exception as e:
            errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
            errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}")

    for extra_network_name, extra_network in extra_network_registry.items():
        if extra_network in activated:
@@ -125,19 +154,16 @@ def deactivate(p, extra_network_data):
    """call deactivate for extra networks in extra_network_data in specified order, then call
    deactivate for all remaining registered networks"""

    for extra_network_name in extra_network_data:
        extra_network = extra_network_registry.get(extra_network_name, None)
        if extra_network is None:
            continue
    data = lookup_extra_networks(extra_network_data)

    for extra_network in data:
        try:
            extra_network.deactivate(p)
        except Exception as e:
            errors.display(e, f"deactivating extra network {extra_network_name}")
            errors.display(e, f"deactivating extra network {extra_network.name}")

    for extra_network_name, extra_network in extra_network_registry.items():
        args = extra_network_data.get(extra_network_name, None)
        if args is not None:
        if extra_network in data:
            continue

        try: