Commit 9f104b53 authored by evshiron's avatar evshiron
Browse files

preview current image when opts.show_progress_every_n_steps is enabled

parent 88f46a5b
Loading
Loading
Loading
Loading
+6 −2
Original line number Original line Diff line number Diff line
import time
import time
import uvicorn
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
import modules.shared as shared
from modules import devices
from modules import devices
from modules.api.models import *
from modules.api.models import *
@@ -187,7 +187,11 @@ class Api:


        progress = min(progress, 1)
        progress = min(progress, 1)


        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict())
        current_image = None
        if shared.state.current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)


    def launch(self, server_name, port):
    def launch(self, server_name, port):
        self.app.include_router(self.router)
        self.app.include_router(self.router)
+1 −0
Original line number Original line Diff line number Diff line
@@ -161,3 +161,4 @@ class ProgressResponse(BaseModel):
    progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
    progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
    eta_relative: float = Field(title="ETA in secs")
    eta_relative: float = Field(title="ETA in secs")
    state: dict = Field(title="State", description="The current state snapshot")
    state: dict = Field(title="State", description="The current state snapshot")
    current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")