mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
server closing & threading
This commit is contained in:
parent
cd3f146c29
commit
68384d0661
@ -17,7 +17,7 @@ from typing import List, Optional, Type, TYPE_CHECKING
|
||||
import urllib
|
||||
import uvicorn
|
||||
|
||||
from gradio import utils
|
||||
from gradio import utils, queueing
|
||||
from gradio.process_examples import load_from_cache, process_example
|
||||
|
||||
|
||||
@ -221,31 +221,21 @@ async def interpret(request: Request):
|
||||
}
|
||||
|
||||
|
||||
|
||||
# @app.route("/shutdown", methods=['GET'])
|
||||
# def shutdown():
|
||||
# shutdown_func = request.environ.get('werkzeug.server.shutdown')
|
||||
# if shutdown_func is None:
|
||||
# raise RuntimeError('Not running werkzeug')
|
||||
# shutdown_func()
|
||||
# return "Shutting down..."
|
||||
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
|
||||
async def queue_push(request: Request):
|
||||
body = await request.json()
|
||||
data = body["data"]
|
||||
action = body["action"]
|
||||
job_hash, queue_position = queueing.push({"data": data}, action)
|
||||
return {"hash": job_hash, "queue_position": queue_position}
|
||||
|
||||
|
||||
# @app.route("/api/queue/push/", methods=["POST"])
|
||||
# #@login_check
|
||||
# def queue_push():
|
||||
# data = request.json["data"]
|
||||
# action = request.json["action"]
|
||||
# job_hash, queue_position = queueing.push({"data": data}, action)
|
||||
# return {"hash": job_hash, "queue_position": queue_position}
|
||||
|
||||
|
||||
# @app.route("/api/queue/status/", methods=["POST"])
|
||||
# #@login_check
|
||||
# def queue_status():
|
||||
# hash = request.json['hash']
|
||||
# status, data = queueing.get_status(hash)
|
||||
# return {"status": status, "data": data}
|
||||
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
|
||||
async def queue_status(request: Request):
|
||||
body = await request.json()
|
||||
hash = body['hash']
|
||||
status, data = queueing.get_status(hash)
|
||||
return {"status": status, "data": data}
|
||||
|
||||
|
||||
########
|
||||
|
@ -440,19 +440,20 @@ class Interface:
|
||||
) -> List[Any]:
|
||||
return interpretation.run_interpret(self, raw_input)
|
||||
|
||||
def run_until_interrupted(
|
||||
def block_thread(
|
||||
self,
|
||||
thread: threading.Thread,
|
||||
path_to_local_server: str
|
||||
) -> None:
|
||||
"""Block main thread until interrupted by user."""
|
||||
try:
|
||||
while True:
|
||||
time.sleep(0.5)
|
||||
time.sleep(0.1)
|
||||
except (KeyboardInterrupt, OSError):
|
||||
print("Keyboard interruption in main thread... closing server.")
|
||||
thread.keep_running = False
|
||||
self.server.close()
|
||||
# Hit the server one more time to close it
|
||||
networking.url_ok(path_to_local_server)
|
||||
networking.url_ok(self.local_url)
|
||||
if self.enable_queue:
|
||||
queueing.close()
|
||||
|
||||
@ -545,15 +546,14 @@ class Interface:
|
||||
if self.cache_examples:
|
||||
cache_interface_examples(self)
|
||||
|
||||
server_port, path_to_local_server, app, thread, server = networking.start_server(
|
||||
server_port, path_to_local_server, app, server = networking.start_server(
|
||||
self, server_name, server_port, self.auth)
|
||||
|
||||
self.local_url = path_to_local_server
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
self.server = server
|
||||
self.server_app = app
|
||||
self.server_thread = thread
|
||||
self.server = server
|
||||
|
||||
utils.launch_counter()
|
||||
|
||||
@ -631,15 +631,14 @@ class Interface:
|
||||
|
||||
utils.show_tip(self)
|
||||
|
||||
# Run server perpetually under certain circumstances
|
||||
# Block main thread if debug==True
|
||||
if debug or int(os.getenv('GRADIO_DEBUG', 0)) == 1:
|
||||
while True:
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.1)
|
||||
self.block_thread()
|
||||
# Block main thread if running in a script to stop script from exiting
|
||||
is_in_interactive_mode = bool(
|
||||
getattr(sys, 'ps1', sys.flags.interactive))
|
||||
if not prevent_thread_lock and not is_in_interactive_mode:
|
||||
self.run_until_interrupted(thread, path_to_local_server)
|
||||
self.block_thread()
|
||||
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
@ -651,8 +650,7 @@ class Interface:
|
||||
Closes the Interface that was launched and frees the port.
|
||||
"""
|
||||
try:
|
||||
self.server.shutdown()
|
||||
self.server_thread.join()
|
||||
self.server.close()
|
||||
if verbose:
|
||||
print("Closing server running on port: {}".format(
|
||||
self.server_port))
|
||||
|
@ -3,6 +3,7 @@ Defines helper methods useful for setting up ports, launching servers, and
|
||||
creating tunnels.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import fastapi
|
||||
import http
|
||||
import json
|
||||
@ -32,6 +33,21 @@ LOCALHOST_NAME = os.getenv('GRADIO_SERVER_NAME', "127.0.0.1")
|
||||
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
|
||||
|
||||
|
||||
class Server(uvicorn.Server):
|
||||
def install_signal_handlers(self):
|
||||
pass
|
||||
|
||||
def run_in_thread(self):
|
||||
self.thread = threading.Thread(target=self.run, daemon=True)
|
||||
self.thread.start()
|
||||
while not self.started:
|
||||
time.sleep(1e-3)
|
||||
|
||||
def close(self):
|
||||
self.should_exit = True
|
||||
self.thread.join()
|
||||
|
||||
|
||||
def get_first_available_port(
|
||||
initial: int,
|
||||
final: int) -> int:
|
||||
@ -118,18 +134,18 @@ def start_server(
|
||||
app.auth = None
|
||||
app.interface = interface
|
||||
app.cwd = os.getcwd()
|
||||
# if app.interface.enable_queue:
|
||||
# if auth is not None or app.interface.encrypt:
|
||||
# raise ValueError("Cannot queue with encryption or authentication enabled.")
|
||||
# queueing.init()
|
||||
# app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,))
|
||||
# app.queue_thread.start()
|
||||
if app.interface.enable_queue:
|
||||
if auth is not None or app.interface.encrypt:
|
||||
raise ValueError("Cannot queue with encryption or authentication enabled.")
|
||||
queueing.init()
|
||||
app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,))
|
||||
app.queue_thread.start()
|
||||
app.tokens = {}
|
||||
app_kwargs = {"app": app, "port": port, "host": server_name,
|
||||
"log_level": "warning"}
|
||||
thread = threading.Thread(target=uvicorn.run, kwargs=app_kwargs)
|
||||
thread.start()
|
||||
return port, path_to_local_server, app, thread, None
|
||||
config = uvicorn.Config(app=app, port=port, host=server_name,
|
||||
log_level="warning")
|
||||
server = Server(config=config)
|
||||
server.run_in_thread()
|
||||
return port, path_to_local_server, app, server
|
||||
|
||||
|
||||
def setup_tunnel(local_server_port: int, endpoint: str) -> str:
|
||||
|
@ -4,7 +4,6 @@ This module defines various classes that can serve as the `output` to an interfa
|
||||
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
|
||||
"""
|
||||
|
||||
from posixpath import basename
|
||||
from gradio.component import Component
|
||||
import numpy as np
|
||||
import json
|
||||
@ -18,7 +17,6 @@ import pandas as pd
|
||||
import PIL
|
||||
from types import ModuleType
|
||||
from ffmpy import FFmpeg
|
||||
import requests
|
||||
|
||||
|
||||
class OutputComponent(Component):
|
||||
|
Loading…
Reference in New Issue
Block a user