Unverified Commit ba110bf0 authored by Artem Kotov's avatar Artem Kotov Committed by GitHub
Browse files

fallback to original file retrieving; skip img if mask not found

usage of `shared.walk_files` breaks controlnet extension
images are processed in different order 
which leads to unmatched img file used for img2img and img file used for controlnet 
(if no folder is specified for control net
or the same as img2img input dir used for it)
parent 49f4b4be
Loading
Loading
Loading
Loading
+7 −12
Original line number Diff line number Diff line
@@ -17,12 +17,11 @@ import modules.scripts
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
    processing.fix_seed(p)

    allowed_extensions = [ext for ext, f in Image.registered_extensions().items() if f in Image.OPEN]
    images = list(shared.walk_files(input_dir, allowed_extensions=allowed_extensions))
    images = shared.listfiles(input_dir)

    is_inpaint_batch = False
    if inpaint_mask_dir:
        inpaint_masks = list(shared.walk_files(inpaint_mask_dir, allowed_extensions=allowed_extensions))
        inpaint_masks = shared.listfiles(inpaint_mask_dir)
        is_inpaint_batch = len(inpaint_masks) > 0
    if is_inpaint_batch:
        print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
@@ -59,22 +58,18 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
        p.init_images = [img] * p.batch_size

        image_path = Path(image)
        if image_path.parent == Path(input_dir):
            image_subpath = ""
        else:
            image_subpath = str(image_path.parent).replace(f"{input_dir}/", "")

        if is_inpaint_batch:
            # try to find corresponding mask for an image using simple filename matching
            if len(inpaint_masks) == 1:
                mask_image_path = inpaint_masks[0]
            else:
                # try to find corresponding mask for an image using simple filename matching
                mask_image_dir = Path(inpaint_mask_dir).joinpath(image_subpath)
                mask_image_dir = Path(inpaint_mask_dir)
                masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))

                if len(masks_found) == 0:
                    raise ValueError(f"Mask is not found for {image_path} in {mask_image_dir}")
                    print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
                    continue

                # it should contain only 1 matching mask
                # otherwise user has many masks with the same name but different extensions
@@ -95,10 +90,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
                filename = f"{left}-{n}{right}"

            if not save_normally:
                os.makedirs(os.path.join(output_dir, image_subpath), exist_ok=True)
                os.makedirs(output_dir, exist_ok=True)
                if processed_image.mode == 'RGBA':
                    processed_image = processed_image.convert("RGB")
                processed_image.save(os.path.join(output_dir, image_subpath, filename))
                processed_image.save(os.path.join(output_dir, filename))


def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):