Commit 14978420 authored by AUTOMATIC

rework #3722 to not introduce duplicate code

parent 060ee5d3
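
This commit replaces the API-local copies of the GPU-call setup/teardown that #3722 introduced (before_gpu_call/after_gpu_call) with two new methods on the shared State object, State.begin() and State.end(), which the Gradio wrapper and the API handlers now both call. The resulting call sequence, as a sketch (the stand-in work() marks the spot where process_images or the wrapped function runs):

    shared.state.begin()   # reset per-job fields, then devices.torch_gc()
    with queue_lock:       # serialize access to the GPU
        result = work()    # process_images(p) in the API, func(*args, **kwargs) in the UI wrapper
    shared.state.end()     # clear the job fields, then devices.torch_gc()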
modules/api/api.py +13 −30
@@ -9,31 +9,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion
 from modules.sd_samplers import all_samplers
 from modules.extras import run_extras, run_pnginfo
 
-# copy from wrap_gradio_gpu_call of webui.py
-# because queue lock will be acquired in api handlers
-# and time start needs to be set
-# the function has been modified into two parts
-
-def before_gpu_call():
-    devices.torch_gc()
-
-    shared.state.sampling_step = 0
-    shared.state.job_count = -1
-    shared.state.job_no = 0
-    shared.state.job_timestamp = shared.state.get_job_timestamp()
-    shared.state.current_latent = None
-    shared.state.current_image = None
-    shared.state.current_image_sampling_step = 0
-    shared.state.skipped = False
-    shared.state.interrupted = False
-    shared.state.textinfo = None
-    shared.state.time_start = time.time()
-
-def after_gpu_call():
-    shared.state.job = ""
-    shared.state.job_count = 0
-
-    devices.torch_gc()
 
 def upscaler_to_index(name: str):
     try:
@@ -41,8 +16,10 @@ def upscaler_to_index(name: str):
     except:
         raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
 
+
 sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
 
+
 def setUpscalers(req: dict):
     reqDict = vars(req)
     reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
@@ -51,6 +28,7 @@ def setUpscalers(req: dict):
     reqDict.pop('upscaler_2')
     return reqDict
 
+
 class Api:
     def __init__(self, app, queue_lock):
         self.router = APIRouter()
@@ -78,10 +56,13 @@ class Api:
         )
         p = StableDiffusionProcessingTxt2Img(**vars(populate))
         # Override object param
-        before_gpu_call()
+
+        shared.state.begin()
+
         with self.queue_lock:
             processed = process_images(p)
-        after_gpu_call()
+
+        shared.state.end()
 
         b64images = list(map(encode_pil_to_base64, processed.images))
 
@@ -119,11 +100,13 @@
             imgs = [img] * p.batch_size
 
         p.init_images = imgs
-        # Override object param
-        before_gpu_call()
+
+        shared.state.begin()
+
         with self.queue_lock:
             processed = process_images(p)
-        after_gpu_call()
+
+        shared.state.end()
 
         b64images = list(map(encode_pil_to_base64, processed.images))
 
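Both the txt2img and img2img handlers above now follow that sequence. For illustration, any further handler written under the same convention would look roughly like this; the handler and helper names below are hypothetical, only shared.state.begin/end, self.queue_lock, and process_images come from the diff:

    def some_handler(self, req):
        p = build_processing_object(req)    # hypothetical helper
        shared.state.begin()
        with self.queue_lock:
            processed = process_images(p)
        shared.state.end()
        return encode_results(processed)    # hypothetical helper
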
modules/shared.py +19 −3
@@ -144,9 +144,6 @@ class State:
         self.sampling_step = 0
         self.current_image_sampling_step = 0
 
-    def get_job_timestamp(self):
-        return datetime.datetime.now().strftime("%Y%m%d%H%M%S")  # shouldn't this return job_timestamp?
-
     def dict(self):
         obj = {
             "skipped": self.skipped,
@@ -160,6 +157,25 @@
 
         return obj
 
+    def begin(self):
+        self.sampling_step = 0
+        self.job_count = -1
+        self.job_no = 0
+        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+        self.current_latent = None
+        self.current_image = None
+        self.current_image_sampling_step = 0
+        self.skipped = False
+        self.interrupted = False
+        self.textinfo = None
+
+        devices.torch_gc()
+
+    def end(self):
+        self.job = ""
+        self.job_count = 0
+
+        devices.torch_gc()
 
 state = State()
 
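One caveat the new methods leave open: begin() and end() are paired manually, so an exception raised between them (e.g. inside process_images) skips end() and leaves job_count at -1. A try/finally or context manager would make the pairing exception-safe; a minimal sketch of that idea, not part of this commit:

    from contextlib import contextmanager

    @contextmanager
    def job_state(state):
        # hypothetical helper: guarantees end() runs even if the job raises
        state.begin()
        try:
            yield state
        finally:
            state.end()

    # usage:
    # with job_state(shared.state):
    #     ...run the GPU job...
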
webui.py +3 −16
@@ -46,26 +46,13 @@ def wrap_queued_call(func):
 
 
 def wrap_gradio_gpu_call(func, extra_outputs=None):
     def f(*args, **kwargs):
-        devices.torch_gc()
-
-        shared.state.sampling_step = 0
-        shared.state.job_count = -1
-        shared.state.job_no = 0
-        shared.state.job_timestamp = shared.state.get_job_timestamp()
-        shared.state.current_latent = None
-        shared.state.current_image = None
-        shared.state.current_image_sampling_step = 0
-        shared.state.skipped = False
-        shared.state.interrupted = False
-        shared.state.textinfo = None
+
+        shared.state.begin()
 
         with queue_lock:
             res = func(*args, **kwargs)
 
+        shared.state.end()
 
         return res
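
With the duplicate reset logic gone, wrap_gradio_gpu_call reduces to the same begin/lock/call/end shape as the API handlers. The effect of the two methods can be seen directly from the fields they touch (illustrative session, assuming the webui modules are importable):

    from modules import shared

    shared.state.begin()
    print(shared.state.job_count)    # -1: a job is starting, count not yet known
    shared.state.end()
    print(shared.state.job_count)    # 0: no job running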