diff --git a/gradio/app.py b/gradio/app.py index 0cd1b37d81..40a0a56552 100644 --- a/gradio/app.py +++ b/gradio/app.py @@ -2,6 +2,7 @@ from __future__ import annotations from fastapi import FastAPI, Request, Depends, HTTPException, status +from fastapi.concurrency import run_in_threadpool from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, HTMLResponse, FileResponse from fastapi.security import OAuth2PasswordRequestForm @@ -37,6 +38,7 @@ app.add_middleware( allow_headers=["*"], ) + templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) @@ -178,23 +180,28 @@ async def predict( if body.get("example_id") != None: example_id = body["example_id"] if app.interface.cache_examples: - prediction = load_from_cache(app.interface, example_id) + prediction = await run_in_threadpool( + load_from_cache, app.interface, example_id) durations = None else: - prediction, durations = process_example(app.interface, example_id) + prediction, durations = await run_in_threadpool( + process_example, app.interface, example_id) else: raw_input = body["data"] if app.interface.show_error: try: - prediction, durations = app.interface.process(raw_input) + prediction, durations = await run_in_threadpool( + app.interface.process, raw_input) except BaseException as error: traceback.print_exc() return JSONResponse(content={"error": str(error)}, status_code=500) else: - prediction, durations = app.interface.process(raw_input) + prediction, durations = await run_in_threadpool( + app.interface.process, raw_input) if app.interface.allow_flagging == "auto": - flag_index = app.interface.flagging_callback.flag( + flag_index = await run_in_threadpool( + app.interface.flagging_callback.flag, app.interface, raw_input, prediction, flag_option="" if app.interface.flagging_options else None, username=username) @@ -216,7 +223,8 @@ async def flag( await utils.log_feature_analytics(app.interface.ip_address, 'flag') body = await request.json() data = body['data'] - app.interface.flagging_callback.flag( + await run_in_threadpool( + app.interface.flagging_callback.flag, app.interface, data['input_data'], data['output_data'], flag_option=data.get("flag_option"), flag_index=data.get("flag_index"), username=username) @@ -229,8 +237,8 @@ async def interpret(request: Request): await utils.log_feature_analytics(app.interface.ip_address, 'interpret') body = await request.json() raw_input = body["data"] - interpretation_scores, alternative_outputs = app.interface.interpret( - raw_input) + interpretation_scores, alternative_outputs = await run_in_threadpool( + app.interface.interpret, raw_input) return { "interpretation_scores": interpretation_scores, "alternative_outputs": alternative_outputs diff --git a/gradio/networking.py b/gradio/networking.py index 932ea7beb2..33b1a512c4 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -147,7 +147,7 @@ def start_server( app.queue_thread.start() if interface.save_to is not None: # Used for selenium tests interface.save_to["port"] = port - + config = uvicorn.Config(app=app, port=port, host=server_name, log_level="warning") server = Server(config=config)