Unverified Commit 75b67eeb authored by Sena's avatar Sena Committed by GitHub
Browse files

Fix bare base64 not accept

parent 828438b4
Loading
Loading
Loading
Loading
+10 −3
Original line number Original line Diff line number Diff line
@@ -3,6 +3,7 @@ import io
import time
import time
import uvicorn
import uvicorn
from threading import Lock
from threading import Lock
from io import BytesIO
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.security import HTTPBasic, HTTPBasicCredentials
@@ -13,7 +14,7 @@ from modules import sd_samplers, deepbooru
from modules.api.models import *
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras, run_pnginfo
from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list
from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from modules.realesrgan_model import get_realesrgan_models
from typing import List
from typing import List
@@ -133,7 +134,10 @@ class Api:


        mask = img2imgreq.mask
        mask = img2imgreq.mask
        if mask:
        if mask:
            if mask.startswith("data:image/"):
                mask = decode_base64_to_image(mask)
                mask = decode_base64_to_image(mask)
            else:
                mask = Image.open(BytesIO(base64.b64decode(mask)))


        populate = img2imgreq.copy(update={ # Override __init__ params
        populate = img2imgreq.copy(update={ # Override __init__ params
            "sd_model": shared.sd_model,
            "sd_model": shared.sd_model,
@@ -147,7 +151,10 @@ class Api:


        imgs = []
        imgs = []
        for img in init_images:
        for img in init_images:
            if img.startswith("data:image/"):
                img = decode_base64_to_image(img)
                img = decode_base64_to_image(img)
            else:
                img = Image.open(BytesIO(base64.b64decode(img)))
            imgs = [img] * p.batch_size
            imgs = [img] * p.batch_size


        p.init_images = imgs
        p.init_images = imgs