Commit d8b90ac1 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

big rework of progressbar/preview system to allow multiple users to prompts at...

big rework of progressbar/preview system to allow multiple users to prompts at the same time and do not get previews of each other
parent ebfdd7ba
Loading
Loading
Loading
Loading
+159 −90
Original line number Diff line number Diff line
// code related to showing and updating progressbar shown as the image is being made
global_progressbars = {}


galleries = {}
storedGallerySelections = {}
galleryObservers = {}

// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
timeoutIds = {}

function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
    // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
    // every time you use gr.HTML(elem_id='xxx'), so we handle this here
    var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
    var progressbarParent
    if(progressbar){
        progressbarParent = gradioApp().querySelector("#"+id_progressbar)
    } else{
        progressbar = gradioApp().getElementById(id_progressbar)
        progressbarParent = null
function rememberGallerySelection(id_gallery){
    storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
}

    var skip = id_skip ? gradioApp().getElementById(id_skip) : null
    var interrupt = gradioApp().getElementById(id_interrupt)

    if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
        if(progressbar.innerText){
            let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
            if(document.title != newtitle){
                document.title =  newtitle;
            }
        }else{
            let newtitle = 'Stable Diffusion'
            if(document.title != newtitle){
                document.title =  newtitle;
            }
        }
    }

	if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
	    global_progressbars[id_progressbar] = progressbar

        var mutationObserver = new MutationObserver(function(m){
            if(timeoutIds[id_part]) return;

            preview = gradioApp().getElementById(id_preview)
            gallery = gradioApp().getElementById(id_gallery)

            if(preview != null && gallery != null){
                preview.style.width = gallery.clientWidth + "px"
                preview.style.height = gallery.clientHeight + "px"
                if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"

				//only watch gallery if there is a generation process going on
                check_gallery(id_gallery);

                var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
                if(progressDiv){
                    timeoutIds[id_part] = window.setTimeout(function() {
                        timeoutIds[id_part] = null
                        requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
                    }, 500)
                } else{
                    if (skip) {
                        skip.style.display = "none"
                    }
                    interrupt.style.display = "none"
function getGallerySelectedIndex(id_gallery){
    let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
    let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')

                    //disconnect observer once generation finished, so user can close selected image if they want
                    if (galleryObservers[id_gallery]) {
                        galleryObservers[id_gallery].disconnect();
                        galleries[id_gallery] = null;
                    }
                }
            }
     let currentlySelectedIndex = -1
     galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })

        });
        mutationObserver.observe( progressbar, { childList:true, subtree:true })
	}
     return currentlySelectedIndex
}

// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
function check_gallery(id_gallery){
    let gallery = gradioApp().getElementById(id_gallery)
    // if gallery has no change, no need to setting up observer again.
@@ -85,10 +28,16 @@ function check_gallery(id_gallery){
        if(galleryObservers[id_gallery]){
            galleryObservers[id_gallery].disconnect();
        }
        let prevSelectedIndex = selected_gallery_index();

        storedGallerySelections[id_gallery] = -1

        galleryObservers[id_gallery] = new MutationObserver(function (){
            let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
            let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
            let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
            prevSelectedIndex = storedGallerySelections[id_gallery]
            storedGallerySelections[id_gallery] = -1

            if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
                // automatically re-open previously selected index (if exists)
                activeElement = gradioApp().activeElement;
@@ -120,30 +69,150 @@ function check_gallery(id_gallery){
}

onUiUpdate(function(){
    check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
    check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
    check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
    check_gallery('txt2img_gallery')
    check_gallery('img2img_gallery')
})

function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
    btn = gradioApp().getElementById(id_part+"_check_progress");
    if(btn==null) return;
function request(url, data, handler, errorHandler){
    var xhr = new XMLHttpRequest();
    var url = url;
    xhr.open("POST", url, true);
    xhr.setRequestHeader("Content-Type", "application/json");
    xhr.onreadystatechange = function () {
        if (xhr.readyState === 4) {
            if (xhr.status === 200) {
                var js = JSON.parse(xhr.responseText);
                handler(js)
            } else{
                errorHandler()
            }
        }
    };
    var js = JSON.stringify(data);
    xhr.send(js);
}

function pad2(x){
    return x<10 ? '0'+x : x
}

function formatTime(secs){
    if(secs > 3600){
        return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60)
    } else if(secs > 60){
        return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60)
    } else{
        return Math.floor(secs) + "s"
    }
}

function randomId(){
    return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
}

// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
// calls onProgress every time there is a progress update
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
    var dateStart = new Date()
    var wasEverActive = false
    var parentProgressbar = progressbarContainer.parentNode
    var parentGallery = gallery.parentNode

    var divProgress = document.createElement('div')
    divProgress.className='progressDiv'
    var divInner = document.createElement('div')
    divInner.className='progress'

    divProgress.appendChild(divInner)
    parentProgressbar.insertBefore(divProgress, progressbarContainer)

    var livePreview = document.createElement('div')
    livePreview.className='livePreview'
    parentGallery.insertBefore(livePreview, gallery)

    var removeProgressBar = function(){
        parentProgressbar.removeChild(divProgress)
        parentGallery.removeChild(livePreview)
        atEnd()
    }

    var fun = function(id_task, id_live_preview){
        request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
            console.log(res)

            if(res.completed){
                removeProgressBar()
                return
            }

            var rect = progressbarContainer.getBoundingClientRect()

            if(rect.width){
                divProgress.style.width = rect.width + "px";
            }

            progressText = ""

            divInner.style.width = ((res.progress || 0) * 100.0) + '%'

            if(res.progress > 0){
                progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
            }

            if(res.eta){
                progressText += " ETA: " + formatTime(res.eta)
            } else if(res.textinfo){
                progressText += " " + res.textinfo
            }

            divInner.textContent = progressText

            var elapsedFromStart = (new Date() - dateStart) / 1000

    btn.click();
    var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
    var skip = id_skip ? gradioApp().getElementById(id_skip) : null
    var interrupt = gradioApp().getElementById(id_interrupt)
    if(progressDiv && interrupt){
        if (skip) {
            skip.style.display = "block"
            if(res.active) wasEverActive = true;

            if(! res.active && wasEverActive){
                removeProgressBar()
                return
            }
        interrupt.style.display = "block"

            if(elapsedFromStart > 5 && !res.queued && !res.active){
                removeProgressBar()
                return
            }


            if(res.live_preview){
                var img = new Image();
                img.onload = function() {
                    var rect = gallery.getBoundingClientRect()
                    if(rect.width){
                        livePreview.style.width = rect.width + "px"
                        livePreview.style.height = rect.height + "px"
                    }

function requestProgress(id_part){
    btn = gradioApp().getElementById(id_part+"_check_progress_initial");
    if(btn==null) return;
                    livePreview.innerHTML = ''
                    livePreview.appendChild(img)
                    if(livePreview.childElementCount > 2){
                        livePreview.removeChild(livePreview.firstElementChild)
                    }
                }
                img.src = res.live_preview;
            }


            if(onProgress){
                onProgress(res)
            }

            setTimeout(() => {
                fun(id_task, res.id_live_preview);
            }, 500)
        }, function(){
            removeProgressBar()
        })
    }

    btn.click();
    fun(id_task, 0)
}
+11 −2
Original line number Diff line number Diff line



function start_training_textual_inversion(){
    requestProgress('ti')
    gradioApp().querySelector('#ti_error').innerHTML=''

    return args_to_array(arguments)
    var id = randomId()
    requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
        gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
    })

    var res = args_to_array(arguments)

    res[0] = id

    return res
}
+28 −5
Original line number Diff line number Diff line
@@ -126,18 +126,41 @@ function create_submit_args(args){
    return res
}

function showSubmitButtons(tabname, show){
    gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block"
    gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
}

function submit(){
    requestProgress('txt2img')
    rememberGallerySelection('txt2img_gallery')
    showSubmitButtons('txt2img', false)

    var id = randomId()
    requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
        showSubmitButtons('txt2img', true)

    })

    return create_submit_args(arguments)
    var res = create_submit_args(arguments)

    res[0] = id

    return res
}

function submit_img2img(){
    requestProgress('img2img')
    rememberGallerySelection('img2img_gallery')
    showSubmitButtons('img2img', false)

    var id = randomId()
    requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
        showSubmitButtons('img2img', true)
    })

    res = create_submit_args(arguments)
    var res = create_submit_args(arguments)

    res[0] = get_tab_index('mode_img2img')
    res[0] = id
    res[1] = get_tab_index('mode_img2img')

    return res
}
+15 −4
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ import threading
import traceback
import time

from modules import shared
from modules import shared, progress

queue_lock = threading.Lock()

@@ -22,10 +22,21 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):

        shared.state.begin()
        # if the first argument is a string that says "task(...)", it is treated as a job id
        if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
            id_task = args[0]
            progress.add_task_to_queue(id_task)
        else:
            id_task = None

        with queue_lock:
            shared.state.begin()
            progress.start_task(id_task)

            try:
                res = func(*args, **kwargs)
            finally:
                progress.finish_task(id_task)

            shared.state.end()

+3 −3
Original line number Diff line number Diff line
@@ -453,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
    shared.reload_hypernetworks()


def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

@@ -629,7 +629,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,

                description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
                pbar.set_description(description)
                shared.state.textinfo = description
                if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
                    # Before saving, change name to match current checkpoint.
                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
@@ -701,7 +700,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                        torch.cuda.set_rng_state_all(cuda_rng_state)
                    hypernetwork.train()
                    if image is not None:
                        shared.state.current_image = image
                        shared.state.assign_current_image(image)

                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

Loading