Commit 5dc0739e authored by Stephen's avatar Stephen Committed by AUTOMATIC1111
Browse files

working mask

parent 9e1a8b77
Loading
Loading
Loading
Loading
+12 −8
Original line number Diff line number Diff line
@@ -33,6 +33,14 @@ class Api:
        self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])

    def __base64_to_image(self, base64_string):
        # if has a comma, deal with prefix
        if "," in base64_string:
            base64_string = base64_string.split(",")[1]
        imgdata = base64.b64decode(base64_string)
        # convert base64 to PIL image
        return Image.open(io.BytesIO(imgdata))

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
        
@@ -74,26 +82,22 @@ class Api:

        mask = img2imgreq.mask
        if mask:
            raise HTTPException(status_code=400, detail="Mask not supported yet") 
            mask = self.__base64_to_image(mask)

        
        populate = img2imgreq.copy(update={ # Override __init__ params
            "sd_model": shared.sd_model, 
            "sampler_index": sampler_index[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True
            "do_not_save_grid": True, 
            "mask": mask
            }
        )
        p = StableDiffusionProcessingImg2Img(**vars(populate))

        imgs = []
        for img in init_images:
            # if has a comma, deal with prefix
            if "," in img:
                img = img.split(",")[1]
            # convert base64 to PIL image
            img = base64.b64decode(img)
            img = Image.open(io.BytesIO(img))
            img = self.__base64_to_image(img)
            imgs = [img] * p.batch_size

        p.init_images = imgs