mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-27 02:30:17 +08:00
Merge pull request #472 from gradio-app/concurrency
Fixed concurrency issue
This commit is contained in:
commit
8efa2d8059
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from fastapi import FastAPI, Request, Depends, HTTPException, status
|
from fastapi import FastAPI, Request, Depends, HTTPException, status
|
||||||
|
from fastapi.concurrency import run_in_threadpool
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
|
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
@ -37,6 +38,7 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||||
|
|
||||||
|
|
||||||
@ -178,23 +180,28 @@ async def predict(
|
|||||||
if body.get("example_id") != None:
|
if body.get("example_id") != None:
|
||||||
example_id = body["example_id"]
|
example_id = body["example_id"]
|
||||||
if app.interface.cache_examples:
|
if app.interface.cache_examples:
|
||||||
prediction = load_from_cache(app.interface, example_id)
|
prediction = await run_in_threadpool(
|
||||||
|
load_from_cache, app.interface, example_id)
|
||||||
durations = None
|
durations = None
|
||||||
else:
|
else:
|
||||||
prediction, durations = process_example(app.interface, example_id)
|
prediction, durations = await run_in_threadpool(
|
||||||
|
process_example, app.interface, example_id)
|
||||||
else:
|
else:
|
||||||
raw_input = body["data"]
|
raw_input = body["data"]
|
||||||
if app.interface.show_error:
|
if app.interface.show_error:
|
||||||
try:
|
try:
|
||||||
prediction, durations = app.interface.process(raw_input)
|
prediction, durations = await run_in_threadpool(
|
||||||
|
app.interface.process, raw_input)
|
||||||
except BaseException as error:
|
except BaseException as error:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return JSONResponse(content={"error": str(error)},
|
return JSONResponse(content={"error": str(error)},
|
||||||
status_code=500)
|
status_code=500)
|
||||||
else:
|
else:
|
||||||
prediction, durations = app.interface.process(raw_input)
|
prediction, durations = await run_in_threadpool(
|
||||||
|
app.interface.process, raw_input)
|
||||||
if app.interface.allow_flagging == "auto":
|
if app.interface.allow_flagging == "auto":
|
||||||
flag_index = app.interface.flagging_callback.flag(
|
flag_index = await run_in_threadpool(
|
||||||
|
app.interface.flagging_callback.flag,
|
||||||
app.interface, raw_input, prediction,
|
app.interface, raw_input, prediction,
|
||||||
flag_option="" if app.interface.flagging_options else None,
|
flag_option="" if app.interface.flagging_options else None,
|
||||||
username=username)
|
username=username)
|
||||||
@ -216,7 +223,8 @@ async def flag(
|
|||||||
await utils.log_feature_analytics(app.interface.ip_address, 'flag')
|
await utils.log_feature_analytics(app.interface.ip_address, 'flag')
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
data = body['data']
|
data = body['data']
|
||||||
app.interface.flagging_callback.flag(
|
await run_in_threadpool(
|
||||||
|
app.interface.flagging_callback.flag,
|
||||||
app.interface, data['input_data'], data['output_data'],
|
app.interface, data['input_data'], data['output_data'],
|
||||||
flag_option=data.get("flag_option"), flag_index=data.get("flag_index"),
|
flag_option=data.get("flag_option"), flag_index=data.get("flag_index"),
|
||||||
username=username)
|
username=username)
|
||||||
@ -229,8 +237,8 @@ async def interpret(request: Request):
|
|||||||
await utils.log_feature_analytics(app.interface.ip_address, 'interpret')
|
await utils.log_feature_analytics(app.interface.ip_address, 'interpret')
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
raw_input = body["data"]
|
raw_input = body["data"]
|
||||||
interpretation_scores, alternative_outputs = app.interface.interpret(
|
interpretation_scores, alternative_outputs = await run_in_threadpool(
|
||||||
raw_input)
|
app.interface.interpret, raw_input)
|
||||||
return {
|
return {
|
||||||
"interpretation_scores": interpretation_scores,
|
"interpretation_scores": interpretation_scores,
|
||||||
"alternative_outputs": alternative_outputs
|
"alternative_outputs": alternative_outputs
|
||||||
|
@ -147,7 +147,7 @@ def start_server(
|
|||||||
app.queue_thread.start()
|
app.queue_thread.start()
|
||||||
if interface.save_to is not None: # Used for selenium tests
|
if interface.save_to is not None: # Used for selenium tests
|
||||||
interface.save_to["port"] = port
|
interface.save_to["port"] = port
|
||||||
|
|
||||||
config = uvicorn.Config(app=app, port=port, host=server_name,
|
config = uvicorn.Config(app=app, port=port, host=server_name,
|
||||||
log_level="warning")
|
log_level="warning")
|
||||||
server = Server(config=config)
|
server = Server(config=config)
|
||||||
|
Loading…
Reference in New Issue
Block a user