import base64
import io
import time
import gradio as gr
from pydantic import BaseModel, Field
from modules.shared import opts
import modules.shared as shared
from collections import OrderedDict
import string
import random
from typing import List
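
# Shared task-tracking state: the id of the task currently being executed,
# pending task ids mapped to their enqueue time, recently finished task ids,
# and a small buffer of recorded results used to restore progress in the UI.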
current_task = None
pending_tasks = OrderedDict()
finished_tasks = []
recorded_results = []
recorded_results_limit = 2
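

# Mark a task as the one currently running and drop it from the pending queue.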
def start_task(id_task):
    global current_task

    current_task = id_task
    pending_tasks.pop(id_task, None)


def finish_task(id_task):
    global current_task

    if current_task == id_task:
        current_task = None

    finished_tasks.append(id_task)
    if len(finished_tasks) > 16:
        finished_tasks.pop(0)
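

# Build a short, human-readable and collision-unlikely task id,
# e.g. "task(txt2img-AB12CDE)" for task_type "txt2img".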
def create_task_id(task_type):
    N = 7
    res = ''.join(random.choices(string.ascii_uppercase +
                                 string.digits, k=N))
    return f"task({task_type}-{res})"


def record_results(id_task, res):
    recorded_results.append((id_task, res))
    if len(recorded_results) > recorded_results_limit:
        recorded_results.pop(0)


def add_task_to_queue(id_job):
    pending_tasks[id_job] = time.time()
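

# Pydantic models describing the payloads of the internal progress API.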
class PendingTasksResponse(BaseModel):
    size: int = Field(title="Pending task size")
    tasks: List[str] = Field(title="Pending task ids")


class ProgressRequest(BaseModel):
    id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
    id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received live preview image")
    live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image")


class ProgressResponse(BaseModel):
    active: bool = Field(title="Whether the task is being worked on right now")
    queued: bool = Field(title="Whether the task is in queue")
    completed: bool = Field(title="Whether the task has already finished")
    progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
    eta: float = Field(default=None, title="ETA in secs")
    live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
    id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
    textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
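

# Register the internal progress endpoints on the FastAPI app.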
def setup_progress_api(app):
    app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"])
    return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)


def get_pending_tasks():
    pending_tasks_ids = list(pending_tasks)
    pending_len = len(pending_tasks_ids)
    return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)
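

# Report the status of a task: its queue position while waiting and, once active,
# overall progress (finished jobs plus the sampling-step fraction of the current job),
# an ETA extrapolated from elapsed time, and optionally a live preview image.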
def progressapi(req: ProgressRequest):
    active = req.id_task == current_task
    queued = req.id_task in pending_tasks
    completed = req.id_task in finished_tasks

    if not active:
        textinfo = "Waiting..."
        if queued:
            sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])
            queue_index = sorted_queued.index(req.id_task)
            textinfo = "In queue: {}/{}".format(queue_index + 1, len(sorted_queued))

        return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo)

    progress = 0

    job_count, job_no = shared.state.job_count, shared.state.job_no
    sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step

    if job_count > 0:
        progress += job_no / job_count
    if sampling_steps > 0 and job_count > 0:
        progress += 1 / job_count * sampling_step / sampling_steps

    progress = min(progress, 1)

    elapsed_since_start = time.time() - shared.state.time_start
    predicted_duration = elapsed_since_start / progress if progress > 0 else None
    eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
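
    # Encode the latest live preview as a base64 data: URI, but only when the client's
    # id_live_preview is out of date, so the same image is not sent twice.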
    live_preview = None
    id_live_preview = req.id_live_preview

    if opts.live_previews_enable and req.live_preview:
        shared.state.set_current_image()
        if shared.state.id_live_preview != req.id_live_preview:
            image = shared.state.current_image
            if image is not None:
                buffered = io.BytesIO()

                if opts.live_previews_image_format == "png":
                    # using optimize for large images takes an enormous amount of time
                    if max(*image.size) <= 256:
                        save_kwargs = {"optimize": True}
                    else:
                        save_kwargs = {"optimize": False, "compress_level": 1}
                else:
                    save_kwargs = {}

                image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
                base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
                live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
                id_live_preview = shared.state.id_live_preview

    return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
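

# Wait until the task is no longer queued or active, then return its recorded result
# if it is still in the small recorded_results buffer; otherwise report that it is gone.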
def restore_progress(id_task):
    while id_task == current_task or id_task in pending_tasks:
        time.sleep(0.1)

    res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
    if res is not None:
        return res

    return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"