Commit d74c3810 authored by Jesse Williams's avatar Jesse Williams Committed by AUTOMATIC1111
Browse files

Confirm that options are valid before starting

When using the 'Sampler' or 'Checkpoint' options, if one of the entered
names has a typo, an error will only be thrown once the `draw_xy_grid`
loop reaches that name. This can waste a lot of time for large grids
with a typo near the end of a list, since the script needs to start over
and re-generate any earlier images to finish making the grid.

Also fixing typo in variable name in `draw_xy_grid`.
parent 6f6798dd
Loading
Loading
Loading
Loading
+15 −6
Original line number Diff line number Diff line
@@ -145,7 +145,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]

    first_pocessed = None
    first_processed = None

    state.job_count = len(xs) * len(ys) * p.n_iter

@@ -154,8 +154,8 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
            state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"

            processed = cell(x, y)
            if first_pocessed is None:
                first_pocessed = processed
            if first_processed is None:
                first_processed = processed

            try:
              res.append(processed.images[0])
@@ -166,9 +166,9 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
    if draw_legend:
        grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)

    first_pocessed.images = [grid]
    first_processed.images = [grid]

    return first_pocessed
    return first_processed


re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
@@ -216,7 +216,6 @@ class Script(scripts.Script):
                    m = re_range.fullmatch(val)
                    mc = re_range_count.fullmatch(val)
                    if m is not None:

                        start = int(m.group(1))
                        end = int(m.group(2))+1
                        step = int(m.group(3)) if m.group(3) is not None else 1
@@ -259,6 +258,16 @@ class Script(scripts.Script):

            valslist = [opt.type(x) for x in valslist]
            
            # Confirm options are valid before starting
            if opt.label == "Sampler":
                for sampler_val in valslist:
                    if sampler_val.lower() not in samplers_dict.keys():
                        raise RuntimeError(f"Unknown sampler: {sampler_val}")
            elif opt.label == "Checkpoint name":
                for ckpt_val in valslist:
                    if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
                        raise RuntimeError(f"Checkpoint for {ckpt_val} not found")

            return valslist

        x_opt = axis_options[x_type]