Commit e672cfb0 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

rework of callback for #6094

parent 6062c85d
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -39,12 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):

    cols = math.ceil(len(imgs) / rows)

    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
    script_callbacks.image_grid_callback(params)

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')

    for i, img in enumerate(imgs):
        script_callbacks.image_grid_loop_callback(img)
        grid.paste(img, box=(i % cols * w, i // cols * h))
    for i, img in enumerate(params.imgs):
        grid.paste(img, box=(i % params.cols * w, i // params.cols * h))

    return grid

+15 −11
Original line number Diff line number Diff line
@@ -52,8 +52,10 @@ class UiTrainTabParams:


class ImageGridLoopParams:
    def __init__(self, img):
        self.img = img
    def __init__(self, imgs, cols, rows):
        self.imgs = imgs
        self.cols = cols
        self.rows = rows


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
@@ -68,7 +70,7 @@ callback_map = dict(
    callbacks_cfg_denoiser=[],
    callbacks_before_component=[],
    callbacks_after_component=[],
    callbacks_image_grid_loop=[],
    callbacks_image_grid=[],
)


@@ -160,12 +162,14 @@ def after_component_callback(component, **kwargs):
        except Exception:
            report_exception(c, 'after_component_callback')

def image_grid_loop_callback(component, **kwargs):
    for c in callback_map['callbacks_image_grid_loop']:

def image_grid_callback(params: ImageGridLoopParams):
    for c in callback_map['callbacks_image_grid']:
        try:
            c.callback(component, **kwargs)
            c.callback(params)
        except Exception:
            report_exception(c, 'image_grid_loop')
            report_exception(c, 'image_grid')


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
@@ -269,9 +273,9 @@ def on_after_component(callback):
    add_callback(callback_map['callbacks_after_component'], callback)


def on_image_grid_loop(callback):
    """register a function to be called inside the image grid loop.
def on_image_grid(callback):
    """register a function to be called before making an image grid.
    The callback is called with one argument:
       - params: ImageGridLoopParams - parameters to be used inside the image grid loop.
       - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
    """
    add_callback(callback_map['callbacks_image_grid_loop'], callback)
    add_callback(callback_map['callbacks_image_grid'], callback)