server closing & threading

This commit is contained in:
Abubakar Abid 2022-01-03 17:47:45 -05:00
parent cd3f146c29
commit 68384d0661
4 changed files with 53 additions and 51 deletions

View File

@ -17,7 +17,7 @@ from typing import List, Optional, Type, TYPE_CHECKING
import urllib
import uvicorn
from gradio import utils
from gradio import utils, queueing
from gradio.process_examples import load_from_cache, process_example
@ -221,31 +221,21 @@ async def interpret(request: Request):
}
# @app.route("/shutdown", methods=['GET'])
# def shutdown():
# shutdown_func = request.environ.get('werkzeug.server.shutdown')
# if shutdown_func is None:
# raise RuntimeError('Not running werkzeug')
# shutdown_func()
# return "Shutting down..."
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
async def queue_push(request: Request):
body = await request.json()
data = body["data"]
action = body["action"]
job_hash, queue_position = queueing.push({"data": data}, action)
return {"hash": job_hash, "queue_position": queue_position}
# @app.route("/api/queue/push/", methods=["POST"])
# #@login_check
# def queue_push():
# data = request.json["data"]
# action = request.json["action"]
# job_hash, queue_position = queueing.push({"data": data}, action)
# return {"hash": job_hash, "queue_position": queue_position}
# @app.route("/api/queue/status/", methods=["POST"])
# #@login_check
# def queue_status():
# hash = request.json['hash']
# status, data = queueing.get_status(hash)
# return {"status": status, "data": data}
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
async def queue_status(request: Request):
body = await request.json()
hash = body['hash']
status, data = queueing.get_status(hash)
return {"status": status, "data": data}
########

View File

@ -440,19 +440,20 @@ class Interface:
) -> List[Any]:
return interpretation.run_interpret(self, raw_input)
def run_until_interrupted(
def block_thread(
self,
thread: threading.Thread,
path_to_local_server: str
) -> None:
"""Block main thread until interrupted by user."""
try:
while True:
time.sleep(0.5)
time.sleep(0.1)
except (KeyboardInterrupt, OSError):
print("Keyboard interruption in main thread... closing server.")
thread.keep_running = False
self.server.close()
# Hit the server one more time to close it
networking.url_ok(path_to_local_server)
networking.url_ok(self.local_url)
if self.enable_queue:
queueing.close()
@ -545,15 +546,14 @@ class Interface:
if self.cache_examples:
cache_interface_examples(self)
server_port, path_to_local_server, app, thread, server = networking.start_server(
server_port, path_to_local_server, app, server = networking.start_server(
self, server_name, server_port, self.auth)
self.local_url = path_to_local_server
self.server_port = server_port
self.status = "RUNNING"
self.server = server
self.server_app = app
self.server_thread = thread
self.server = server
utils.launch_counter()
@ -631,15 +631,14 @@ class Interface:
utils.show_tip(self)
# Run server perpetually under certain circumstances
# Block main thread if debug==True
if debug or int(os.getenv('GRADIO_DEBUG', 0)) == 1:
while True:
sys.stdout.flush()
time.sleep(0.1)
self.block_thread()
# Block main thread if running in a script to stop script from exiting
is_in_interactive_mode = bool(
getattr(sys, 'ps1', sys.flags.interactive))
if not prevent_thread_lock and not is_in_interactive_mode:
self.run_until_interrupted(thread, path_to_local_server)
self.block_thread()
return app, path_to_local_server, share_url
@ -651,8 +650,7 @@ class Interface:
Closes the Interface that was launched and frees the port.
"""
try:
self.server.shutdown()
self.server_thread.join()
self.server.close()
if verbose:
print("Closing server running on port: {}".format(
self.server_port))

View File

@ -3,6 +3,7 @@ Defines helper methods useful for setting up ports, launching servers, and
creating tunnels.
"""
from __future__ import annotations
import contextlib
import fastapi
import http
import json
@ -32,6 +33,21 @@ LOCALHOST_NAME = os.getenv('GRADIO_SERVER_NAME', "127.0.0.1")
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
class Server(uvicorn.Server):
def install_signal_handlers(self):
pass
def run_in_thread(self):
self.thread = threading.Thread(target=self.run, daemon=True)
self.thread.start()
while not self.started:
time.sleep(1e-3)
def close(self):
self.should_exit = True
self.thread.join()
def get_first_available_port(
initial: int,
final: int) -> int:
@ -118,18 +134,18 @@ def start_server(
app.auth = None
app.interface = interface
app.cwd = os.getcwd()
# if app.interface.enable_queue:
# if auth is not None or app.interface.encrypt:
# raise ValueError("Cannot queue with encryption or authentication enabled.")
# queueing.init()
# app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,))
# app.queue_thread.start()
if app.interface.enable_queue:
if auth is not None or app.interface.encrypt:
raise ValueError("Cannot queue with encryption or authentication enabled.")
queueing.init()
app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,))
app.queue_thread.start()
app.tokens = {}
app_kwargs = {"app": app, "port": port, "host": server_name,
"log_level": "warning"}
thread = threading.Thread(target=uvicorn.run, kwargs=app_kwargs)
thread.start()
return port, path_to_local_server, app, thread, None
config = uvicorn.Config(app=app, port=port, host=server_name,
log_level="warning")
server = Server(config=config)
server.run_in_thread()
return port, path_to_local_server, app, server
def setup_tunnel(local_server_port: int, endpoint: str) -> str:

View File

@ -4,7 +4,6 @@ This module defines various classes that can serve as the `output` to an interfa
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
"""
from posixpath import basename
from gradio.component import Component
import numpy as np
import json
@ -18,7 +17,6 @@ import pandas as pd
import PIL
from types import ModuleType
from ffmpy import FFmpeg
import requests
class OutputComponent(Component):