2
0
mirror of https://github.com/gradio-app/gradio.git synced 2025-03-19 12:00:39 +08:00

async-function-support ()

* async-function-support

- add async function support to Blocks

* async-function-support

- resolve conflicts

* async-function-support

- add error to Interface for async functions

* async-function-support

- add test

* async-function-support

- add test packages

* async-function-support

- add test packages
This commit is contained in:
Ömer Faruk Özdemir 2022-05-11 23:10:50 +03:00 committed by GitHub
parent 5fc00b4567
commit a88c017f87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 279 additions and 203 deletions

@ -1,6 +1,7 @@
from __future__ import annotations
import getpass
import inspect
import os
import sys
import time
@ -8,6 +9,8 @@ import warnings
import webbrowser
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from fastapi.concurrency import run_in_threadpool
from gradio import encryptor, networking, queueing, strings, utils
from gradio.context import Context
from gradio.deprecation import check_deprecated_parameters
@ -291,7 +294,7 @@ class Blocks(BlockContext):
if Context.block is not None:
Context.block.children.extend(self.children)
def process_api(
async def process_api(
self,
data: PredictBody,
username: str = None,
@ -321,7 +324,11 @@ class Blocks(BlockContext):
else:
processed_input = raw_input
start = time.time()
predictions = block_fn.fn(*processed_input)
if inspect.iscoroutinefunction(block_fn.fn):
predictions = await block_fn.fn(*processed_input)
else:
predictions = await run_in_threadpool(block_fn.fn, *processed_input)
duration = time.time() - start
block_fn.total_runtime += duration
block_fn.total_runs += 1

@ -165,6 +165,10 @@ class Interface(Blocks):
analytics_enabled=analytics_enabled, mode="interface", **kwargs
)
if inspect.iscoroutinefunction(fn):
raise NotImplementedError(
"Async functions are not currently supported within interfaces. Please use Blocks API."
)
self.interface_type = self.InterfaceTypes.STANDARD
if (inputs is None or inputs == []) and (outputs is None or outputs == []):
raise ValueError("Must provide at least one of `inputs` or `outputs`")

@ -16,13 +16,12 @@ import requests
import uvicorn
from gradio import queueing
from gradio.routes import create_app
from gradio.routes import App
from gradio.tunneling import create_tunnel
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.blocks import Blocks
# By default, the local server will try to open on localhost, port 7860.
# If that is not available, then it will try 7861, 7862, ... 7959.
INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
@ -139,8 +138,7 @@ def start_server(
else:
path_to_local_server = "http://{}:{}/".format(url_host_name, port)
app = create_app()
app = configure_app(app, blocks)
app = App.create_app(blocks)
if app.blocks.enable_queue:
if blocks.auth is not None or app.blocks.encrypt:

@ -10,22 +10,22 @@ import secrets
import traceback
import urllib
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, List, Optional, Type
import orjson
import pkg_resources
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from jinja2.exceptions import TemplateNotFound
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
from starlette.responses import RedirectResponse
from gradio import encryptor, queueing, utils
import gradio
from gradio import encryptor, queueing
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "templates/frontend/static")
@ -78,199 +78,229 @@ class PredictBody(BaseModel):
###########
def create_app() -> FastAPI:
app = FastAPI(default_response_class=ORJSONResponse)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
app.state_holder = {}
class App(FastAPI):
"""
FastAPI App Wrapper
"""
@app.get("/user")
@app.get("/user/")
def get_current_user(request: Request) -> Optional[str]:
token = request.cookies.get("access-token")
return app.tokens.get(token)
def __init__(self, **kwargs):
self.tokens = None
self.auth = None
self.blocks: Optional[gradio.Blocks] = None
super().__init__(**kwargs)
@app.get("/login_check")
@app.get("/login_check/")
def login_check(user: str = Depends(get_current_user)):
if app.auth is None or not (user is None):
return
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
)
@app.get("/token")
@app.get("/token/")
def get_token(request: Request) -> Optional[str]:
token = request.cookies.get("access-token")
return {"token": token, "user": app.tokens.get(token)}
@app.post("/login")
@app.post("/login/")
def login(form_data: OAuth2PasswordRequestForm = Depends()):
username, password = form_data.username, form_data.password
if (
not callable(app.auth)
and username in app.auth
and app.auth[username] == password
) or (callable(app.auth) and app.auth.__call__(username, password)):
token = secrets.token_urlsafe(16)
app.tokens[token] = username
response = RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
response.set_cookie(key="access-token", value=token, httponly=True)
return response
else:
raise HTTPException(status_code=400, detail="Incorrect credentials.")
###############
# Main Routes
###############
@app.head("/", response_class=HTMLResponse)
@app.get("/", response_class=HTMLResponse)
def main(request: Request, user: str = Depends(get_current_user)):
if app.auth is None or not (user is None):
config = app.blocks.config
else:
config = {
"auth_required": True,
"auth_message": app.blocks.auth_message,
}
try:
return templates.TemplateResponse(
"frontend/index.html", {"request": request, "config": config}
)
except TemplateNotFound:
raise ValueError(
"Did you install Gradio from source files? You need to build "
"the frontend by running /scripts/build_frontend.sh"
)
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config():
return app.blocks.config
@app.get("/static/{path:path}")
def static_resource(path: str):
if app.blocks.share:
return RedirectResponse(GRADIO_STATIC_ROOT + path)
else:
static_file = safe_join(STATIC_PATH_LIB, path)
if static_file is not None:
return FileResponse(static_file)
raise HTTPException(status_code=404, detail="Static file not found")
@app.get("/assets/{path:path}")
def build_resource(path: str):
if app.blocks.share:
return RedirectResponse(GRADIO_BUILD_ROOT + path)
else:
build_file = safe_join(BUILD_PATH_LIB, path)
if build_file is not None:
return FileResponse(build_file)
raise HTTPException(status_code=404, detail="Build file not found")
@app.get("/favicon.ico")
async def favicon():
if app.blocks.favicon_path is None:
return static_resource("img/logo.svg")
else:
return FileResponse(app.blocks.favicon_path)
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
def file(path):
if (
app.blocks.encrypt
and isinstance(app.blocks.examples, str)
and path.startswith(app.blocks.examples)
):
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
encrypted_data = encrypted_file.read()
file_data = encryptor.decrypt(app.blocks.encryption_key, encrypted_data)
return FileResponse(
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
)
else:
if Path(app.cwd).resolve() in Path(path).resolve().parents:
return FileResponse(Path(path).resolve())
@app.get("/api", response_class=HTMLResponse) # Needed for Spaces
@app.get("/api/", response_class=HTMLResponse)
def api_docs(request: Request):
inputs = [type(inp) for inp in app.blocks.input_components]
outputs = [type(out) for out in app.blocks.output_components]
input_types_doc, input_types = get_types(inputs, "input")
output_types_doc, output_types = get_types(outputs, "output")
input_names = [inp.get_block_name() for inp in app.blocks.input_components]
output_names = [out.get_block_name() for out in app.blocks.output_components]
if app.blocks.examples is not None:
sample_inputs = app.blocks.examples[0]
else:
sample_inputs = [
inp.generate_sample() for inp in app.blocks.input_components
]
docs = {
"inputs": input_names,
"outputs": output_names,
"len_inputs": len(inputs),
"len_outputs": len(outputs),
"inputs_lower": [name.lower() for name in input_names],
"outputs_lower": [name.lower() for name in output_names],
"input_types": input_types,
"output_types": output_types,
"input_types_doc": input_types_doc,
"output_types_doc": output_types_doc,
"sample_inputs": sample_inputs,
"auth": app.blocks.auth,
"local_login_url": urllib.parse.urljoin(app.blocks.local_url, "login"),
"local_api_url": urllib.parse.urljoin(app.blocks.local_url, "api/predict"),
}
return templates.TemplateResponse("api_docs.html", {"request": request, **docs})
@app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
app.state_holder[body.session_hash] = {
_id: getattr(block, "value", None)
for _id, block in app.blocks.blocks.items()
if getattr(block, "stateful", False)
}
session_state = app.state_holder[body.session_hash]
else:
session_state = {}
try:
output = await run_in_threadpool(
app.blocks.process_api,
body,
username,
session_state,
)
except BaseException as error:
if app.blocks.show_error:
traceback.print_exc()
return JSONResponse(content={"error": str(error)}, status_code=500)
def configure_app(self, blocks: gradio.Blocks) -> None:
auth = blocks.auth
if auth is not None:
if not callable(auth):
self.auth = {account[0]: account[1] for account in auth}
else:
raise error
return output
self.auth = auth
else:
self.auth = None
self.blocks = blocks
self.cwd = os.getcwd()
self.favicon_path = blocks.favicon_path
self.tokens = {}
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
async def queue_push(body: QueuePushBody):
job_hash, queue_position = queueing.push(body)
return {"hash": job_hash, "queue_position": queue_position}
@staticmethod
def create_app(blocks: gradio.Blocks) -> FastAPI:
app = App(default_response_class=ORJSONResponse)
app.configure_app(blocks)
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
async def queue_status(body: QueueStatusBody):
status, data = queueing.get_status(body.hash)
return {"status": status, "data": data}
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
app.state_holder = {}
return app
@app.get("/user")
@app.get("/user/")
def get_current_user(request: Request) -> Optional[str]:
token = request.cookies.get("access-token")
return app.tokens.get(token)
@app.get("/login_check")
@app.get("/login_check/")
def login_check(user: str = Depends(get_current_user)):
if app.auth is None or not (user is None):
return
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
)
@app.get("/token")
@app.get("/token/")
def get_token(request: Request) -> dict:
token = request.cookies.get("access-token")
return {"token": token, "user": app.tokens.get(token)}
@app.post("/login")
@app.post("/login/")
def login(form_data: OAuth2PasswordRequestForm = Depends()):
username, password = form_data.username, form_data.password
if (
not callable(app.auth)
and username in app.auth
and app.auth[username] == password
) or (callable(app.auth) and app.auth.__call__(username, password)):
token = secrets.token_urlsafe(16)
app.tokens[token] = username
response = RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
response.set_cookie(key="access-token", value=token, httponly=True)
return response
else:
raise HTTPException(status_code=400, detail="Incorrect credentials.")
###############
# Main Routes
###############
@app.head("/", response_class=HTMLResponse)
@app.get("/", response_class=HTMLResponse)
def main(request: Request, user: str = Depends(get_current_user)):
if app.auth is None or not (user is None):
config = app.blocks.config
else:
config = {
"auth_required": True,
"auth_message": app.blocks.auth_message,
}
try:
return templates.TemplateResponse(
"frontend/index.html", {"request": request, "config": config}
)
except TemplateNotFound:
raise ValueError(
"Did you install Gradio from source files? You need to build "
"the frontend by running /scripts/build_frontend.sh"
)
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config():
return app.blocks.config
@app.get("/static/{path:path}")
def static_resource(path: str):
if app.blocks.share:
return RedirectResponse(GRADIO_STATIC_ROOT + path)
else:
static_file = safe_join(STATIC_PATH_LIB, path)
if static_file is not None:
return FileResponse(static_file)
raise HTTPException(status_code=404, detail="Static file not found")
@app.get("/assets/{path:path}")
def build_resource(path: str):
if app.blocks.share:
return RedirectResponse(GRADIO_BUILD_ROOT + path)
else:
build_file = safe_join(BUILD_PATH_LIB, path)
if build_file is not None:
return FileResponse(build_file)
raise HTTPException(status_code=404, detail="Build file not found")
@app.get("/favicon.ico")
async def favicon():
if app.blocks.favicon_path is None:
return static_resource("img/logo.svg")
else:
return FileResponse(app.blocks.favicon_path)
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
def file(path):
if (
app.blocks.encrypt
and isinstance(app.blocks.examples, str)
and path.startswith(app.blocks.examples)
):
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
encrypted_data = encrypted_file.read()
file_data = encryptor.decrypt(app.blocks.encryption_key, encrypted_data)
return FileResponse(
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
)
else:
if Path(app.cwd).resolve() in Path(path).resolve().parents:
return FileResponse(Path(path).resolve())
@app.get("/api", response_class=HTMLResponse) # Needed for Spaces
@app.get("/api/", response_class=HTMLResponse)
def api_docs(request: Request):
inputs = [type(inp) for inp in app.blocks.input_components]
outputs = [type(out) for out in app.blocks.output_components]
input_types_doc, input_types = get_types(inputs, "input")
output_types_doc, output_types = get_types(outputs, "output")
input_names = [inp.get_block_name() for inp in app.blocks.input_components]
output_names = [
out.get_block_name() for out in app.blocks.output_components
]
if app.blocks.examples is not None:
sample_inputs = app.blocks.examples[0]
else:
sample_inputs = [
inp.generate_sample() for inp in app.blocks.input_components
]
docs = {
"inputs": input_names,
"outputs": output_names,
"len_inputs": len(inputs),
"len_outputs": len(outputs),
"inputs_lower": [name.lower() for name in input_names],
"outputs_lower": [name.lower() for name in output_names],
"input_types": input_types,
"output_types": output_types,
"input_types_doc": input_types_doc,
"output_types_doc": output_types_doc,
"sample_inputs": sample_inputs,
"auth": app.blocks.auth,
"local_login_url": urllib.parse.urljoin(app.blocks.local_url, "login"),
"local_api_url": urllib.parse.urljoin(
app.blocks.local_url, "api/predict"
),
}
return templates.TemplateResponse(
"api_docs.html", {"request": request, **docs}
)
@app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
app.state_holder[body.session_hash] = {
_id: getattr(block, "value", None)
for _id, block in app.blocks.blocks.items()
if getattr(block, "stateful", False)
}
session_state = app.state_holder[body.session_hash]
else:
session_state = {}
try:
output = await app.blocks.process_api(body, username, session_state)
except BaseException as error:
if app.blocks.show_error:
traceback.print_exc()
return JSONResponse(content={"error": str(error)}, status_code=500)
else:
raise error
return output
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
async def queue_push(body: QueuePushBody):
job_hash, queue_position = queueing.push(body)
return {"hash": job_hash, "queue_position": queue_position}
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
async def queue_status(body: QueueStatusBody):
status, data = queueing.get_status(body.hash)
return {"status": status, "data": data}
return app
########

@ -1,4 +1,5 @@
# Don't forget to run bash scripts/create_test_requirements.sh from unix or wsl when you update this file.
asyncio
IPython
comet_ml
coverage
@ -12,6 +13,7 @@ pytest
wandb
huggingface_hub
pytest-cov
pytest-asyncio
black
isort
flake8

@ -14,6 +14,10 @@ asttokens==2.0.5
# via stack-data
astunparse==1.6.3
# via tensorflow
asyncio==3.4.3
# via -r requirements.in
atomicwrites==1.4.0
# via pytest
attrs==21.4.0
# via
# jsonschema
@ -48,6 +52,12 @@ cloudpickle==2.0.0
# via
# mlflow
# shap
colorama==0.4.4
# via
# click
# ipython
# pytest
# tqdm
comet-ml==3.25.0
# via -r requirements.in
configobj==5.0.6
@ -110,8 +120,6 @@ grpcio==1.43.0
# via
# tensorboard
# tensorflow
gunicorn==20.1.0
# via mlflow
h5py==3.6.0
# via tensorflow
huggingface-hub==0.4.0
@ -216,8 +224,6 @@ pathspec==0.9.0
# via black
pathtools==0.1.2
# via wandb
pexpect==4.8.0
# via ipython
pickleshare==0.7.5
# via ipython
pillow==9.0.1
@ -244,8 +250,6 @@ protobuf==3.19.4
# wandb
psutil==5.9.0
# via wandb
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.2
# via stack-data
py==1.11.0
@ -273,7 +277,10 @@ pyrsistent==0.18.1
pytest==7.0.0
# via
# -r requirements.in
# pytest-asyncio
# pytest-cov
pytest-asyncio==0.18.3
# via -r requirements.in
pytest-cov==3.0.0
# via -r requirements.in
python-dateutil==2.8.2
@ -424,6 +431,8 @@ urllib3[secure]==1.26.8
# requests
# selenium
# sentry-sdk
waitress==2.1.1
# via mlflow
wandb==0.12.10
# via -r requirements.in
wcwidth==0.2.5

@ -1,9 +1,16 @@
import asyncio
import random
import time
import unittest
import pytest
import gradio as gr
from gradio.routes import PredictBody
from gradio.test_data.blocks_configs import XRAY_CONFIG
pytest_plugins = ("pytest_asyncio",)
class TestBlocks(unittest.TestCase):
def test_xray(self):
@ -53,6 +60,25 @@ class TestBlocks(unittest.TestCase):
demo.load(fake_func, [], [textbox])
self.assertEqual(XRAY_CONFIG, demo.get_config_file())
@pytest.mark.asyncio
async def test_async_function(self):
async def wait():
await asyncio.sleep(0.01)
return True
with gr.Blocks() as demo:
text = gr.components.Textbox()
button = gr.components.Button()
button.click(wait, [text], [text])
body = PredictBody(data=1, fn_index=0)
start = time.time()
result = await demo.process_api(body)
end = time.time()
difference = end - start
assert difference >= 0.01
assert result
if __name__ == "__main__":
unittest.main()