Commit 88f46a5b authored by evshiron's avatar evshiron
Browse files

update progress response model

parent e9c6c2a5
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ class Api:
        self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
        self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
        self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
        self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
        self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -171,7 +171,7 @@ class Api:
        # copy from check_progress_call of ui.py

        if shared.state.job_count == 0:
            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())

        # avoid dividing zero
        progress = 0.01
@@ -187,7 +187,7 @@ class Api:

        progress = min(progress, 1)

        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict())

    def launch(self, server_name, port):
        self.app.include_router(self.router)
+2 −2
Original line number Diff line number Diff line
import inspect
from click import prompt
from pydantic import BaseModel, Field, Json, create_model
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
@@ -160,4 +160,4 @@ class PNGInfoResponse(BaseModel):
class ProgressResponse(BaseModel):
    progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
    eta_relative: float = Field(title="ETA in secs")
    state: Json = Field(title="State", description="The current state snapshot")
    state: dict = Field(title="State", description="The current state snapshot")
+2 −2
Original line number Diff line number Diff line
@@ -147,7 +147,7 @@ class State:
    def get_job_timestamp(self):
        return datetime.datetime.now().strftime("%Y%m%d%H%M%S")  # shouldn't this return job_timestamp?

    def js(self):
    def dict(self):
        obj = {
            "skipped": self.skipped,
            "interrupted": self.skipped,
@@ -158,7 +158,7 @@ class State:
            "sampling_steps": self.sampling_steps,
        }

        return json.dumps(obj)
        return obj


state = State()