Commit 1242ba08 authored by gayshub's avatar gayshub
Browse files

add allow specify the task id and get the location of task in the queue of pending task

parent 4afaaf8a
Loading
Loading
Loading
Loading
+18 −2
Original line number Diff line number Diff line
@@ -33,7 +33,7 @@ from typing import Dict, List, Any
import piexif
import piexif.helper
from contextlib import closing

from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task

def script_name_to_index(name, scripts):
    try:
@@ -337,6 +337,10 @@ class Api:
        return script_args

    def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
        task_id = create_task_id("text2img")
        if txt2imgreq.force_task_id != None:
            task_id = txt2imgreq.force_task_id
            
        script_runner = scripts.scripts_txt2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(False)
@@ -363,6 +367,8 @@ class Api:
        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        add_task_to_queue(task_id)

        with self.queue_lock:
            with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
                p.is_api = True
@@ -372,12 +378,14 @@ class Api:

                try:
                    shared.state.begin(job="scripts_txt2img")
                    start_task(task_id)
                    if selectable_scripts is not None:
                        p.script_args = script_args
                        processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
                    else:
                        p.script_args = tuple(script_args) # Need to pass args as tuple here
                        processed = process_images(p)
                    finish_task(task_id)
                finally:
                    shared.state.end()
                    shared.total_tqdm.clear()
@@ -387,6 +395,10 @@ class Api:
        return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

    def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
        task_id = create_task_id("img2img")
        if img2imgreq.force_task_id != None:
            task_id = img2imgreq.force_task_id

        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")
@@ -423,6 +435,8 @@ class Api:
        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        add_task_to_queue(task_id)

        with self.queue_lock:
            with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
                p.init_images = [decode_base64_to_image(x) for x in init_images]
@@ -433,12 +447,14 @@ class Api:

                try:
                    shared.state.begin(job="scripts_img2img")
                    start_task(task_id)
                    if selectable_scripts is not None:
                        p.script_args = script_args
                        processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
                    else:
                        p.script_args = tuple(script_args) # Need to pass args as tuple here
                        processed = process_images(p)
                    finish_task(task_id)
                finally:
                    shared.state.end()
                    shared.total_tqdm.clear()
@@ -514,7 +530,7 @@ class Api:
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)

    def interrogateapi(self, interrogatereq: models.InterrogateRequest):
        image_b64 = interrogatereq.image
+2 −0
Original line number Diff line number Diff line
@@ -109,6 +109,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
        {"key": "send_images", "type": bool, "default": True},
        {"key": "save_images", "type": bool, "default": False},
        {"key": "alwayson_scripts", "type": dict, "default": {}},
        {"key": "force_task_id", "type": str, "default": None},
    ]
).generate_model()

@@ -126,6 +127,7 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
        {"key": "send_images", "type": bool, "default": True},
        {"key": "save_images", "type": bool, "default": False},
        {"key": "alwayson_scripts", "type": dict, "default": {}},
        {"key": "force_task_id", "type": str, "default": None},
    ]
).generate_model()

+2 −0
Original line number Diff line number Diff line
@@ -1023,6 +1023,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    hr_sampler_name: str = None
    hr_prompt: str = ''
    hr_negative_prompt: str = ''
    force_task_id: str = None

    cached_hr_uc = [None, None]
    cached_hr_c = [None, None]
@@ -1358,6 +1359,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    inpainting_mask_invert: int = 0
    initial_noise_multiplier: float = None
    latent_mask: Image = None
    force_task_id: string = None

    image_mask: Any = field(default=None, init=False)

+19 −2
Original line number Diff line number Diff line
@@ -8,10 +8,13 @@ from pydantic import BaseModel, Field
from modules.shared import opts

import modules.shared as shared

from collections import OrderedDict
import string
import random
from typing import List

current_task = None
pending_tasks = {}
pending_tasks = OrderedDict()
finished_tasks = []
recorded_results = []
recorded_results_limit = 2
@@ -34,6 +37,11 @@ def finish_task(id_task):
    if len(finished_tasks) > 16:
        finished_tasks.pop(0)

def create_task_id(task_type):
    N = 7
    res = ''.join(random.choices(string.ascii_uppercase +
    string.digits, k=N))
    return f"task({task_type}-{res})"

def record_results(id_task, res):
    recorded_results.append((id_task, res))
@@ -44,6 +52,9 @@ def record_results(id_task, res):
def add_task_to_queue(id_job):
    pending_tasks[id_job] = time.time()

class PendingTasksResponse(BaseModel):
    size: int = Field(title="Pending task size")
    tasks: List[str] = Field(title="Pending task ids")

class ProgressRequest(BaseModel):
    id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
@@ -63,8 +74,14 @@ class ProgressResponse(BaseModel):


def setup_progress_api(app):
    app.add_api_route("/internal/pendingTasks", get_pending_tasks, methods=["GET"])
    return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)

def get_pending_tasks():
    pending_tasks_ids = [x for x in pending_tasks]
    pending_len = len(pending_tasks_ids)
    return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)


def progressapi(req: ProgressRequest):
    active = req.id_task == current_task