mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-19 12:00:39 +08:00
async-function-support (#1190)
* async-function-support - add async function support to Blocks * async-function-support - resolve conflicts * async-function-support - add error to Interface for async functions * async-function-support - add test * async-function-support - add test packages * async-function-support - add test packages
This commit is contained in:
parent
5fc00b4567
commit
a88c017f87
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import getpass
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@ -8,6 +9,8 @@ import warnings
|
||||
import webbrowser
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from gradio import encryptor, networking, queueing, strings, utils
|
||||
from gradio.context import Context
|
||||
from gradio.deprecation import check_deprecated_parameters
|
||||
@ -291,7 +294,7 @@ class Blocks(BlockContext):
|
||||
if Context.block is not None:
|
||||
Context.block.children.extend(self.children)
|
||||
|
||||
def process_api(
|
||||
async def process_api(
|
||||
self,
|
||||
data: PredictBody,
|
||||
username: str = None,
|
||||
@ -321,7 +324,11 @@ class Blocks(BlockContext):
|
||||
else:
|
||||
processed_input = raw_input
|
||||
start = time.time()
|
||||
predictions = block_fn.fn(*processed_input)
|
||||
|
||||
if inspect.iscoroutinefunction(block_fn.fn):
|
||||
predictions = await block_fn.fn(*processed_input)
|
||||
else:
|
||||
predictions = await run_in_threadpool(block_fn.fn, *processed_input)
|
||||
duration = time.time() - start
|
||||
block_fn.total_runtime += duration
|
||||
block_fn.total_runs += 1
|
||||
|
@ -165,6 +165,10 @@ class Interface(Blocks):
|
||||
analytics_enabled=analytics_enabled, mode="interface", **kwargs
|
||||
)
|
||||
|
||||
if inspect.iscoroutinefunction(fn):
|
||||
raise NotImplementedError(
|
||||
"Async functions are not currently supported within interfaces. Please use Blocks API."
|
||||
)
|
||||
self.interface_type = self.InterfaceTypes.STANDARD
|
||||
if (inputs is None or inputs == []) and (outputs is None or outputs == []):
|
||||
raise ValueError("Must provide at least one of `inputs` or `outputs`")
|
||||
|
@ -16,13 +16,12 @@ import requests
|
||||
import uvicorn
|
||||
|
||||
from gradio import queueing
|
||||
from gradio.routes import create_app
|
||||
from gradio.routes import App
|
||||
from gradio.tunneling import create_tunnel
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio.blocks import Blocks
|
||||
|
||||
|
||||
# By default, the local server will try to open on localhost, port 7860.
|
||||
# If that is not available, then it will try 7861, 7862, ... 7959.
|
||||
INITIAL_PORT_VALUE = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
|
||||
@ -139,8 +138,7 @@ def start_server(
|
||||
else:
|
||||
path_to_local_server = "http://{}:{}/".format(url_host_name, port)
|
||||
|
||||
app = create_app()
|
||||
app = configure_app(app, blocks)
|
||||
app = App.create_app(blocks)
|
||||
|
||||
if app.blocks.enable_queue:
|
||||
if blocks.auth is not None or app.blocks.encrypt:
|
||||
|
412
gradio/routes.py
412
gradio/routes.py
@ -10,22 +10,22 @@ import secrets
|
||||
import traceback
|
||||
import urllib
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
import orjson
|
||||
import pkg_resources
|
||||
import uvicorn
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from jinja2.exceptions import TemplateNotFound
|
||||
from pydantic import BaseModel
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from gradio import encryptor, queueing, utils
|
||||
import gradio
|
||||
from gradio import encryptor, queueing
|
||||
|
||||
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
|
||||
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "templates/frontend/static")
|
||||
@ -78,199 +78,229 @@ class PredictBody(BaseModel):
|
||||
###########
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(default_response_class=ORJSONResponse)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.state_holder = {}
|
||||
class App(FastAPI):
|
||||
"""
|
||||
FastAPI App Wrapper
|
||||
"""
|
||||
|
||||
@app.get("/user")
|
||||
@app.get("/user/")
|
||||
def get_current_user(request: Request) -> Optional[str]:
|
||||
token = request.cookies.get("access-token")
|
||||
return app.tokens.get(token)
|
||||
def __init__(self, **kwargs):
|
||||
self.tokens = None
|
||||
self.auth = None
|
||||
self.blocks: Optional[gradio.Blocks] = None
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@app.get("/login_check")
|
||||
@app.get("/login_check/")
|
||||
def login_check(user: str = Depends(get_current_user)):
|
||||
if app.auth is None or not (user is None):
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
|
||||
)
|
||||
|
||||
@app.get("/token")
|
||||
@app.get("/token/")
|
||||
def get_token(request: Request) -> Optional[str]:
|
||||
token = request.cookies.get("access-token")
|
||||
return {"token": token, "user": app.tokens.get(token)}
|
||||
|
||||
@app.post("/login")
|
||||
@app.post("/login/")
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
username, password = form_data.username, form_data.password
|
||||
if (
|
||||
not callable(app.auth)
|
||||
and username in app.auth
|
||||
and app.auth[username] == password
|
||||
) or (callable(app.auth) and app.auth.__call__(username, password)):
|
||||
token = secrets.token_urlsafe(16)
|
||||
app.tokens[token] = username
|
||||
response = RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
|
||||
response.set_cookie(key="access-token", value=token, httponly=True)
|
||||
return response
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Incorrect credentials.")
|
||||
|
||||
###############
|
||||
# Main Routes
|
||||
###############
|
||||
|
||||
@app.head("/", response_class=HTMLResponse)
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def main(request: Request, user: str = Depends(get_current_user)):
|
||||
if app.auth is None or not (user is None):
|
||||
config = app.blocks.config
|
||||
else:
|
||||
config = {
|
||||
"auth_required": True,
|
||||
"auth_message": app.blocks.auth_message,
|
||||
}
|
||||
|
||||
try:
|
||||
return templates.TemplateResponse(
|
||||
"frontend/index.html", {"request": request, "config": config}
|
||||
)
|
||||
except TemplateNotFound:
|
||||
raise ValueError(
|
||||
"Did you install Gradio from source files? You need to build "
|
||||
"the frontend by running /scripts/build_frontend.sh"
|
||||
)
|
||||
|
||||
@app.get("/config/", dependencies=[Depends(login_check)])
|
||||
@app.get("/config", dependencies=[Depends(login_check)])
|
||||
def get_config():
|
||||
return app.blocks.config
|
||||
|
||||
@app.get("/static/{path:path}")
|
||||
def static_resource(path: str):
|
||||
if app.blocks.share:
|
||||
return RedirectResponse(GRADIO_STATIC_ROOT + path)
|
||||
else:
|
||||
static_file = safe_join(STATIC_PATH_LIB, path)
|
||||
if static_file is not None:
|
||||
return FileResponse(static_file)
|
||||
raise HTTPException(status_code=404, detail="Static file not found")
|
||||
|
||||
@app.get("/assets/{path:path}")
|
||||
def build_resource(path: str):
|
||||
if app.blocks.share:
|
||||
return RedirectResponse(GRADIO_BUILD_ROOT + path)
|
||||
else:
|
||||
build_file = safe_join(BUILD_PATH_LIB, path)
|
||||
if build_file is not None:
|
||||
return FileResponse(build_file)
|
||||
raise HTTPException(status_code=404, detail="Build file not found")
|
||||
|
||||
@app.get("/favicon.ico")
|
||||
async def favicon():
|
||||
if app.blocks.favicon_path is None:
|
||||
return static_resource("img/logo.svg")
|
||||
else:
|
||||
return FileResponse(app.blocks.favicon_path)
|
||||
|
||||
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
|
||||
def file(path):
|
||||
if (
|
||||
app.blocks.encrypt
|
||||
and isinstance(app.blocks.examples, str)
|
||||
and path.startswith(app.blocks.examples)
|
||||
):
|
||||
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
|
||||
encrypted_data = encrypted_file.read()
|
||||
file_data = encryptor.decrypt(app.blocks.encryption_key, encrypted_data)
|
||||
return FileResponse(
|
||||
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
|
||||
)
|
||||
else:
|
||||
if Path(app.cwd).resolve() in Path(path).resolve().parents:
|
||||
return FileResponse(Path(path).resolve())
|
||||
|
||||
@app.get("/api", response_class=HTMLResponse) # Needed for Spaces
|
||||
@app.get("/api/", response_class=HTMLResponse)
|
||||
def api_docs(request: Request):
|
||||
inputs = [type(inp) for inp in app.blocks.input_components]
|
||||
outputs = [type(out) for out in app.blocks.output_components]
|
||||
input_types_doc, input_types = get_types(inputs, "input")
|
||||
output_types_doc, output_types = get_types(outputs, "output")
|
||||
input_names = [inp.get_block_name() for inp in app.blocks.input_components]
|
||||
output_names = [out.get_block_name() for out in app.blocks.output_components]
|
||||
if app.blocks.examples is not None:
|
||||
sample_inputs = app.blocks.examples[0]
|
||||
else:
|
||||
sample_inputs = [
|
||||
inp.generate_sample() for inp in app.blocks.input_components
|
||||
]
|
||||
docs = {
|
||||
"inputs": input_names,
|
||||
"outputs": output_names,
|
||||
"len_inputs": len(inputs),
|
||||
"len_outputs": len(outputs),
|
||||
"inputs_lower": [name.lower() for name in input_names],
|
||||
"outputs_lower": [name.lower() for name in output_names],
|
||||
"input_types": input_types,
|
||||
"output_types": output_types,
|
||||
"input_types_doc": input_types_doc,
|
||||
"output_types_doc": output_types_doc,
|
||||
"sample_inputs": sample_inputs,
|
||||
"auth": app.blocks.auth,
|
||||
"local_login_url": urllib.parse.urljoin(app.blocks.local_url, "login"),
|
||||
"local_api_url": urllib.parse.urljoin(app.blocks.local_url, "api/predict"),
|
||||
}
|
||||
return templates.TemplateResponse("api_docs.html", {"request": request, **docs})
|
||||
|
||||
@app.post("/api/predict/", dependencies=[Depends(login_check)])
|
||||
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
|
||||
if hasattr(body, "session_hash"):
|
||||
if body.session_hash not in app.state_holder:
|
||||
app.state_holder[body.session_hash] = {
|
||||
_id: getattr(block, "value", None)
|
||||
for _id, block in app.blocks.blocks.items()
|
||||
if getattr(block, "stateful", False)
|
||||
}
|
||||
session_state = app.state_holder[body.session_hash]
|
||||
else:
|
||||
session_state = {}
|
||||
try:
|
||||
output = await run_in_threadpool(
|
||||
app.blocks.process_api,
|
||||
body,
|
||||
username,
|
||||
session_state,
|
||||
)
|
||||
except BaseException as error:
|
||||
if app.blocks.show_error:
|
||||
traceback.print_exc()
|
||||
return JSONResponse(content={"error": str(error)}, status_code=500)
|
||||
def configure_app(self, blocks: gradio.Blocks) -> None:
|
||||
auth = blocks.auth
|
||||
if auth is not None:
|
||||
if not callable(auth):
|
||||
self.auth = {account[0]: account[1] for account in auth}
|
||||
else:
|
||||
raise error
|
||||
return output
|
||||
self.auth = auth
|
||||
else:
|
||||
self.auth = None
|
||||
self.blocks = blocks
|
||||
self.cwd = os.getcwd()
|
||||
self.favicon_path = blocks.favicon_path
|
||||
self.tokens = {}
|
||||
|
||||
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
|
||||
async def queue_push(body: QueuePushBody):
|
||||
job_hash, queue_position = queueing.push(body)
|
||||
return {"hash": job_hash, "queue_position": queue_position}
|
||||
@staticmethod
|
||||
def create_app(blocks: gradio.Blocks) -> FastAPI:
|
||||
app = App(default_response_class=ORJSONResponse)
|
||||
app.configure_app(blocks)
|
||||
|
||||
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
|
||||
async def queue_status(body: QueueStatusBody):
|
||||
status, data = queueing.get_status(body.hash)
|
||||
return {"status": status, "data": data}
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.state_holder = {}
|
||||
|
||||
return app
|
||||
@app.get("/user")
|
||||
@app.get("/user/")
|
||||
def get_current_user(request: Request) -> Optional[str]:
|
||||
token = request.cookies.get("access-token")
|
||||
return app.tokens.get(token)
|
||||
|
||||
@app.get("/login_check")
|
||||
@app.get("/login_check/")
|
||||
def login_check(user: str = Depends(get_current_user)):
|
||||
if app.auth is None or not (user is None):
|
||||
return
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
|
||||
)
|
||||
|
||||
@app.get("/token")
|
||||
@app.get("/token/")
|
||||
def get_token(request: Request) -> dict:
|
||||
token = request.cookies.get("access-token")
|
||||
return {"token": token, "user": app.tokens.get(token)}
|
||||
|
||||
@app.post("/login")
|
||||
@app.post("/login/")
|
||||
def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
username, password = form_data.username, form_data.password
|
||||
if (
|
||||
not callable(app.auth)
|
||||
and username in app.auth
|
||||
and app.auth[username] == password
|
||||
) or (callable(app.auth) and app.auth.__call__(username, password)):
|
||||
token = secrets.token_urlsafe(16)
|
||||
app.tokens[token] = username
|
||||
response = RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
|
||||
response.set_cookie(key="access-token", value=token, httponly=True)
|
||||
return response
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Incorrect credentials.")
|
||||
|
||||
###############
|
||||
# Main Routes
|
||||
###############
|
||||
|
||||
@app.head("/", response_class=HTMLResponse)
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def main(request: Request, user: str = Depends(get_current_user)):
|
||||
if app.auth is None or not (user is None):
|
||||
config = app.blocks.config
|
||||
else:
|
||||
config = {
|
||||
"auth_required": True,
|
||||
"auth_message": app.blocks.auth_message,
|
||||
}
|
||||
|
||||
try:
|
||||
return templates.TemplateResponse(
|
||||
"frontend/index.html", {"request": request, "config": config}
|
||||
)
|
||||
except TemplateNotFound:
|
||||
raise ValueError(
|
||||
"Did you install Gradio from source files? You need to build "
|
||||
"the frontend by running /scripts/build_frontend.sh"
|
||||
)
|
||||
|
||||
@app.get("/config/", dependencies=[Depends(login_check)])
|
||||
@app.get("/config", dependencies=[Depends(login_check)])
|
||||
def get_config():
|
||||
return app.blocks.config
|
||||
|
||||
@app.get("/static/{path:path}")
|
||||
def static_resource(path: str):
|
||||
if app.blocks.share:
|
||||
return RedirectResponse(GRADIO_STATIC_ROOT + path)
|
||||
else:
|
||||
static_file = safe_join(STATIC_PATH_LIB, path)
|
||||
if static_file is not None:
|
||||
return FileResponse(static_file)
|
||||
raise HTTPException(status_code=404, detail="Static file not found")
|
||||
|
||||
@app.get("/assets/{path:path}")
|
||||
def build_resource(path: str):
|
||||
if app.blocks.share:
|
||||
return RedirectResponse(GRADIO_BUILD_ROOT + path)
|
||||
else:
|
||||
build_file = safe_join(BUILD_PATH_LIB, path)
|
||||
if build_file is not None:
|
||||
return FileResponse(build_file)
|
||||
raise HTTPException(status_code=404, detail="Build file not found")
|
||||
|
||||
@app.get("/favicon.ico")
|
||||
async def favicon():
|
||||
if app.blocks.favicon_path is None:
|
||||
return static_resource("img/logo.svg")
|
||||
else:
|
||||
return FileResponse(app.blocks.favicon_path)
|
||||
|
||||
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
|
||||
def file(path):
|
||||
if (
|
||||
app.blocks.encrypt
|
||||
and isinstance(app.blocks.examples, str)
|
||||
and path.startswith(app.blocks.examples)
|
||||
):
|
||||
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
|
||||
encrypted_data = encrypted_file.read()
|
||||
file_data = encryptor.decrypt(app.blocks.encryption_key, encrypted_data)
|
||||
return FileResponse(
|
||||
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
|
||||
)
|
||||
else:
|
||||
if Path(app.cwd).resolve() in Path(path).resolve().parents:
|
||||
return FileResponse(Path(path).resolve())
|
||||
|
||||
@app.get("/api", response_class=HTMLResponse) # Needed for Spaces
|
||||
@app.get("/api/", response_class=HTMLResponse)
|
||||
def api_docs(request: Request):
|
||||
inputs = [type(inp) for inp in app.blocks.input_components]
|
||||
outputs = [type(out) for out in app.blocks.output_components]
|
||||
input_types_doc, input_types = get_types(inputs, "input")
|
||||
output_types_doc, output_types = get_types(outputs, "output")
|
||||
input_names = [inp.get_block_name() for inp in app.blocks.input_components]
|
||||
output_names = [
|
||||
out.get_block_name() for out in app.blocks.output_components
|
||||
]
|
||||
if app.blocks.examples is not None:
|
||||
sample_inputs = app.blocks.examples[0]
|
||||
else:
|
||||
sample_inputs = [
|
||||
inp.generate_sample() for inp in app.blocks.input_components
|
||||
]
|
||||
docs = {
|
||||
"inputs": input_names,
|
||||
"outputs": output_names,
|
||||
"len_inputs": len(inputs),
|
||||
"len_outputs": len(outputs),
|
||||
"inputs_lower": [name.lower() for name in input_names],
|
||||
"outputs_lower": [name.lower() for name in output_names],
|
||||
"input_types": input_types,
|
||||
"output_types": output_types,
|
||||
"input_types_doc": input_types_doc,
|
||||
"output_types_doc": output_types_doc,
|
||||
"sample_inputs": sample_inputs,
|
||||
"auth": app.blocks.auth,
|
||||
"local_login_url": urllib.parse.urljoin(app.blocks.local_url, "login"),
|
||||
"local_api_url": urllib.parse.urljoin(
|
||||
app.blocks.local_url, "api/predict"
|
||||
),
|
||||
}
|
||||
return templates.TemplateResponse(
|
||||
"api_docs.html", {"request": request, **docs}
|
||||
)
|
||||
|
||||
@app.post("/api/predict/", dependencies=[Depends(login_check)])
|
||||
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
|
||||
if hasattr(body, "session_hash"):
|
||||
if body.session_hash not in app.state_holder:
|
||||
app.state_holder[body.session_hash] = {
|
||||
_id: getattr(block, "value", None)
|
||||
for _id, block in app.blocks.blocks.items()
|
||||
if getattr(block, "stateful", False)
|
||||
}
|
||||
session_state = app.state_holder[body.session_hash]
|
||||
else:
|
||||
session_state = {}
|
||||
try:
|
||||
output = await app.blocks.process_api(body, username, session_state)
|
||||
except BaseException as error:
|
||||
if app.blocks.show_error:
|
||||
traceback.print_exc()
|
||||
return JSONResponse(content={"error": str(error)}, status_code=500)
|
||||
else:
|
||||
raise error
|
||||
return output
|
||||
|
||||
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
|
||||
async def queue_push(body: QueuePushBody):
|
||||
job_hash, queue_position = queueing.push(body)
|
||||
return {"hash": job_hash, "queue_position": queue_position}
|
||||
|
||||
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
|
||||
async def queue_status(body: QueueStatusBody):
|
||||
status, data = queueing.get_status(body.hash)
|
||||
return {"status": status, "data": data}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
########
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Don't forget to run bash scripts/create_test_requirements.sh from unix or wsl when you update this file.
|
||||
asyncio
|
||||
IPython
|
||||
comet_ml
|
||||
coverage
|
||||
@ -12,6 +13,7 @@ pytest
|
||||
wandb
|
||||
huggingface_hub
|
||||
pytest-cov
|
||||
pytest-asyncio
|
||||
black
|
||||
isort
|
||||
flake8
|
||||
|
@ -14,6 +14,10 @@ asttokens==2.0.5
|
||||
# via stack-data
|
||||
astunparse==1.6.3
|
||||
# via tensorflow
|
||||
asyncio==3.4.3
|
||||
# via -r requirements.in
|
||||
atomicwrites==1.4.0
|
||||
# via pytest
|
||||
attrs==21.4.0
|
||||
# via
|
||||
# jsonschema
|
||||
@ -48,6 +52,12 @@ cloudpickle==2.0.0
|
||||
# via
|
||||
# mlflow
|
||||
# shap
|
||||
colorama==0.4.4
|
||||
# via
|
||||
# click
|
||||
# ipython
|
||||
# pytest
|
||||
# tqdm
|
||||
comet-ml==3.25.0
|
||||
# via -r requirements.in
|
||||
configobj==5.0.6
|
||||
@ -110,8 +120,6 @@ grpcio==1.43.0
|
||||
# via
|
||||
# tensorboard
|
||||
# tensorflow
|
||||
gunicorn==20.1.0
|
||||
# via mlflow
|
||||
h5py==3.6.0
|
||||
# via tensorflow
|
||||
huggingface-hub==0.4.0
|
||||
@ -216,8 +224,6 @@ pathspec==0.9.0
|
||||
# via black
|
||||
pathtools==0.1.2
|
||||
# via wandb
|
||||
pexpect==4.8.0
|
||||
# via ipython
|
||||
pickleshare==0.7.5
|
||||
# via ipython
|
||||
pillow==9.0.1
|
||||
@ -244,8 +250,6 @@ protobuf==3.19.4
|
||||
# wandb
|
||||
psutil==5.9.0
|
||||
# via wandb
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.2
|
||||
# via stack-data
|
||||
py==1.11.0
|
||||
@ -273,7 +277,10 @@ pyrsistent==0.18.1
|
||||
pytest==7.0.0
|
||||
# via
|
||||
# -r requirements.in
|
||||
# pytest-asyncio
|
||||
# pytest-cov
|
||||
pytest-asyncio==0.18.3
|
||||
# via -r requirements.in
|
||||
pytest-cov==3.0.0
|
||||
# via -r requirements.in
|
||||
python-dateutil==2.8.2
|
||||
@ -424,6 +431,8 @@ urllib3[secure]==1.26.8
|
||||
# requests
|
||||
# selenium
|
||||
# sentry-sdk
|
||||
waitress==2.1.1
|
||||
# via mlflow
|
||||
wandb==0.12.10
|
||||
# via -r requirements.in
|
||||
wcwidth==0.2.5
|
||||
|
@ -1,9 +1,16 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
import gradio as gr
|
||||
from gradio.routes import PredictBody
|
||||
from gradio.test_data.blocks_configs import XRAY_CONFIG
|
||||
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
class TestBlocks(unittest.TestCase):
|
||||
def test_xray(self):
|
||||
@ -53,6 +60,25 @@ class TestBlocks(unittest.TestCase):
|
||||
demo.load(fake_func, [], [textbox])
|
||||
self.assertEqual(XRAY_CONFIG, demo.get_config_file())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function(self):
|
||||
async def wait():
|
||||
await asyncio.sleep(0.01)
|
||||
return True
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
text = gr.components.Textbox()
|
||||
button = gr.components.Button()
|
||||
button.click(wait, [text], [text])
|
||||
|
||||
body = PredictBody(data=1, fn_index=0)
|
||||
start = time.time()
|
||||
result = await demo.process_api(body)
|
||||
end = time.time()
|
||||
difference = end - start
|
||||
assert difference >= 0.01
|
||||
assert result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user