Commit 0cd74602 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add script callback for before image save and change callback for after image...

add script callback for before image save and change callback for after image save to use a class with parameters
parent 1e428238
Loading
Loading
Loading
Loading
+24 −18
Original line number Diff line number Diff line
@@ -451,17 +451,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
    """
    namegen = FilenameGenerator(p, seed, prompt)

    if extension == 'png' and opts.enable_pnginfo and info is not None:
        pnginfo = PngImagePlugin.PngInfo()

        if existing_info is not None:
            for k, v in existing_info.items():
                pnginfo.add_text(k, str(v))

        pnginfo.add_text(pnginfo_section_name, info)
    else:
        pnginfo = None

    if save_to_dirs is None:
        save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)

@@ -489,19 +478,27 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
        if add_number:
            basecount = get_next_sequence_number(path, basename)
            fullfn = None
            fullfn_without_extension = None
            for i in range(500):
                fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
                fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
                fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
                if not os.path.exists(fullfn):
                    break
        else:
            fullfn = os.path.join(path, f"{file_decoration}.{extension}")
            fullfn_without_extension = os.path.join(path, file_decoration)
    else:
        fullfn = os.path.join(path, f"{forced_filename}.{extension}")
        fullfn_without_extension = os.path.join(path, forced_filename)

    pnginfo = existing_info or {}
    if info is not None:
        pnginfo[pnginfo_section_name] = info

    params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
    script_callbacks.before_image_saved_callback(params)

    image = params.image
    fullfn = params.filename
    info = params.pnginfo.get(pnginfo_section_name, None)
    fullfn_without_extension, extension = os.path.splitext(params.filename)

    def exif_bytes():
        return piexif.dump({
@@ -510,12 +507,20 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
            },
        })

    if extension.lower() in ("jpg", "jpeg", "webp"):
    if extension.lower() == '.png':
        pnginfo_data = PngImagePlugin.PngInfo()
        for k, v in params.pnginfo.items():
            pnginfo_data.add_text(k, str(v))

        image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

    elif extension.lower() in (".jpg", ".jpeg", ".webp"):
        image.save(fullfn, quality=opts.jpeg_quality)

        if opts.enable_pnginfo and info is not None:
            piexif.insert(exif_bytes(), fullfn)
    else:
        image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
        image.save(fullfn, quality=opts.jpeg_quality)

    target_side_length = 4000
    oversize = image.width > target_side_length or image.height > target_side_length
@@ -538,7 +543,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
    else:
        txt_fullfn = None

    script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn)
    script_callbacks.image_saved_callback(params)

    return fullfn, txt_fullfn


+40 −8
Original line number Diff line number Diff line
@@ -9,15 +9,34 @@ def report_exception(c, job):
    print(traceback.format_exc(), file=sys.stderr)


class ImageSaveParams:
    def __init__(self, image, p, filename, pnginfo):
        self.image = image
        """the PIL image itself"""

        self.p = p
        """p object with processing parameters; either StableDiffusionProcessing or an object with same fields"""

        self.filename = filename
        """name of file that the image would be saved to"""

        self.pnginfo = pnginfo
        """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []


def clear_callbacks():
    callbacks_model_loaded.clear()
    callbacks_ui_tabs.clear()
    callbacks_ui_settings.clear()
    callbacks_before_image_saved.clear()
    callbacks_image_saved.clear()


@@ -49,10 +68,18 @@ def ui_settings_callback():
            report_exception(c, 'ui_settings_callback')


def image_saved_callback(image, p, fullfn, txt_fullfn):
def before_image_saved_callback(params: ImageSaveParams):
    for c in callbacks_image_saved:
        try:
            c.callback(image, p, fullfn, txt_fullfn)
            c.callback(params)
        except Exception:
            report_exception(c, 'before_image_saved_callback')


def image_saved_callback(params: ImageSaveParams):
    for c in callbacks_image_saved:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'image_saved_callback')

@@ -64,7 +91,6 @@ def add_callback(callbacks, fun):
    callbacks.append(ScriptCallback(filename, fun))



def on_model_loaded(callback):
    """register a function to be called when the stable diffusion model is created; the model is
    passed as an argument"""
@@ -90,11 +116,17 @@ def on_ui_settings(callback):
    add_callback(callbacks_ui_settings, callback)


def on_before_image_saved(callback):
    """register a function to be called before an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
    """
    add_callback(callbacks_before_image_saved, callback)


def on_image_saved(callback):
    """register a function to be called after modules.images.save_image is called.
    The callback is called with three arguments:
        - p - procesing object (or a dummy object with same fields if the image is saved using save button)
        - fullfn - image filename
        - txt_fullfn - text file with parameters; may be None
    """register a function to be called after an image is saved to a file.
    The callback is called with one argument:
        - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
    """
    add_callback(callbacks_image_saved, callback)