Merge pull request #472 from gradio-app/concurrency

Fixed concurrency issue
This commit is contained in:
Abubakar Abid 2022-01-20 17:14:54 -06:00 committed by GitHub
commit 8efa2d8059
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 9 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from fastapi import FastAPI, Request, Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
from fastapi.security import OAuth2PasswordRequestForm
@ -37,6 +38,7 @@ app.add_middleware(
allow_headers=["*"],
)
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
@ -178,23 +180,28 @@ async def predict(
if body.get("example_id") != None:
example_id = body["example_id"]
if app.interface.cache_examples:
prediction = load_from_cache(app.interface, example_id)
prediction = await run_in_threadpool(
load_from_cache, app.interface, example_id)
durations = None
else:
prediction, durations = process_example(app.interface, example_id)
prediction, durations = await run_in_threadpool(
process_example, app.interface, example_id)
else:
raw_input = body["data"]
if app.interface.show_error:
try:
prediction, durations = app.interface.process(raw_input)
prediction, durations = await run_in_threadpool(
app.interface.process, raw_input)
except BaseException as error:
traceback.print_exc()
return JSONResponse(content={"error": str(error)},
status_code=500)
else:
prediction, durations = app.interface.process(raw_input)
prediction, durations = await run_in_threadpool(
app.interface.process, raw_input)
if app.interface.allow_flagging == "auto":
flag_index = app.interface.flagging_callback.flag(
flag_index = await run_in_threadpool(
app.interface.flagging_callback.flag,
app.interface, raw_input, prediction,
flag_option="" if app.interface.flagging_options else None,
username=username)
@ -216,7 +223,8 @@ async def flag(
await utils.log_feature_analytics(app.interface.ip_address, 'flag')
body = await request.json()
data = body['data']
app.interface.flagging_callback.flag(
await run_in_threadpool(
app.interface.flagging_callback.flag,
app.interface, data['input_data'], data['output_data'],
flag_option=data.get("flag_option"), flag_index=data.get("flag_index"),
username=username)
@ -229,8 +237,8 @@ async def interpret(request: Request):
await utils.log_feature_analytics(app.interface.ip_address, 'interpret')
body = await request.json()
raw_input = body["data"]
interpretation_scores, alternative_outputs = app.interface.interpret(
raw_input)
interpretation_scores, alternative_outputs = await run_in_threadpool(
app.interface.interpret, raw_input)
return {
"interpretation_scores": interpretation_scores,
"alternative_outputs": alternative_outputs

View File

@ -147,7 +147,7 @@ def start_server(
app.queue_thread.start()
if interface.save_to is not None: # Used for selenium tests
interface.save_to["port"] = port
config = uvicorn.Config(app=app, port=port, host=server_name,
log_level="warning")
server = Server(config=config)