mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
Merge pull request #472 from gradio-app/concurrency
Fixed concurrency issue
This commit is contained in:
commit
8efa2d8059
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
from fastapi import FastAPI, Request, Depends, HTTPException, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
@ -37,6 +38,7 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||
|
||||
|
||||
@ -178,23 +180,28 @@ async def predict(
|
||||
if body.get("example_id") != None:
|
||||
example_id = body["example_id"]
|
||||
if app.interface.cache_examples:
|
||||
prediction = load_from_cache(app.interface, example_id)
|
||||
prediction = await run_in_threadpool(
|
||||
load_from_cache, app.interface, example_id)
|
||||
durations = None
|
||||
else:
|
||||
prediction, durations = process_example(app.interface, example_id)
|
||||
prediction, durations = await run_in_threadpool(
|
||||
process_example, app.interface, example_id)
|
||||
else:
|
||||
raw_input = body["data"]
|
||||
if app.interface.show_error:
|
||||
try:
|
||||
prediction, durations = app.interface.process(raw_input)
|
||||
prediction, durations = await run_in_threadpool(
|
||||
app.interface.process, raw_input)
|
||||
except BaseException as error:
|
||||
traceback.print_exc()
|
||||
return JSONResponse(content={"error": str(error)},
|
||||
status_code=500)
|
||||
else:
|
||||
prediction, durations = app.interface.process(raw_input)
|
||||
prediction, durations = await run_in_threadpool(
|
||||
app.interface.process, raw_input)
|
||||
if app.interface.allow_flagging == "auto":
|
||||
flag_index = app.interface.flagging_callback.flag(
|
||||
flag_index = await run_in_threadpool(
|
||||
app.interface.flagging_callback.flag,
|
||||
app.interface, raw_input, prediction,
|
||||
flag_option="" if app.interface.flagging_options else None,
|
||||
username=username)
|
||||
@ -216,7 +223,8 @@ async def flag(
|
||||
await utils.log_feature_analytics(app.interface.ip_address, 'flag')
|
||||
body = await request.json()
|
||||
data = body['data']
|
||||
app.interface.flagging_callback.flag(
|
||||
await run_in_threadpool(
|
||||
app.interface.flagging_callback.flag,
|
||||
app.interface, data['input_data'], data['output_data'],
|
||||
flag_option=data.get("flag_option"), flag_index=data.get("flag_index"),
|
||||
username=username)
|
||||
@ -229,8 +237,8 @@ async def interpret(request: Request):
|
||||
await utils.log_feature_analytics(app.interface.ip_address, 'interpret')
|
||||
body = await request.json()
|
||||
raw_input = body["data"]
|
||||
interpretation_scores, alternative_outputs = app.interface.interpret(
|
||||
raw_input)
|
||||
interpretation_scores, alternative_outputs = await run_in_threadpool(
|
||||
app.interface.interpret, raw_input)
|
||||
return {
|
||||
"interpretation_scores": interpretation_scores,
|
||||
"alternative_outputs": alternative_outputs
|
||||
|
@ -147,7 +147,7 @@ def start_server(
|
||||
app.queue_thread.start()
|
||||
if interface.save_to is not None: # Used for selenium tests
|
||||
interface.save_to["port"] = port
|
||||
|
||||
|
||||
config = uvicorn.Config(app=app, port=port, host=server_name,
|
||||
log_level="warning")
|
||||
server = Server(config=config)
|
||||
|
Loading…
Reference in New Issue
Block a user