Start queue when gradio is a sub application (#2319)

* First stab at it

* Use util methos

* lint

* Test

* Fix formatting

* Try out setting predict endpoint from websocket request

* lint

* Fix bug

* Address comments - remove server and port

* Skip in 3.7

* Fix documentation

* Add default 🤦

* docs tweak

* Add back imports that were deleted by wrong linter version

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2022-09-23 16:01:44 -04:00 committed by GitHub
parent 3a4a82634e
commit 9f7dd05b72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 144 additions and 18 deletions

View File

@ -5,13 +5,15 @@ CUSTOM_PATH = "/gradio"
app = FastAPI()
@app.get("/")
def read_main():
return {"message": "This is your main app"}
io = gr.Interface(lambda x: "Hello, " + x + "!", "textbox", "textbox")
gradio_app = gr.routes.App.create_app(io)
app = gr.mount_gradio_app(app, io, path=CUSTOM_PATH)
app.mount(CUSTOM_PATH, gradio_app)
# Run this from the terminal as you would normally start a FastAPI app: `uvicorn run:app` and navigate to http://localhost:8000/gradio in your browser.
# Run this from the terminal as you would normally start a FastAPI app: `uvicorn run:app`
# and navigate to http://localhost:8000/gradio in your browser.

View File

@ -57,6 +57,7 @@ from gradio.interface import Interface, TabbedInterface, close_all
from gradio.ipython_ext import load_ipython_extension
from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs
from gradio.mix import Parallel, Series
from gradio.routes import mount_gradio_app
from gradio.templates import (
Files,
Highlight,

View File

@ -912,6 +912,7 @@ class Blocks(BlockContext):
update_intervals=status_update_rate if status_update_rate != "auto" else 1,
max_size=max_size,
)
self.config = self.get_config_file()
return self
def launch(
@ -1257,3 +1258,9 @@ class Blocks(BlockContext):
no_target=True,
queue=False,
)
def startup_events(self):
"""Events that should be run when the app containing this block starts up."""
if self.enable_queue:
utils.run_coro_in_background(self._queue.start)
utils.run_coro_in_background(self.create_limiter)

View File

@ -14,10 +14,11 @@ from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import Any, List, Optional, Type
from urllib.parse import urlparse
import fastapi
import orjson
import pkg_resources
import requests
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
@ -30,6 +31,7 @@ from starlette.websockets import WebSocket, WebSocketState
import gradio
from gradio import encryptor, utils
from gradio.documentation import document, set_documentation_group
from gradio.exceptions import Error
from gradio.queue import Estimation, Event
@ -308,6 +310,12 @@ class App(FastAPI):
@app.websocket("/queue/join")
async def join_queue(websocket: WebSocket):
if app.blocks._queue.server_path is None:
print(f"WS: {str(websocket.url)}")
app_url = get_server_url_from_ws_url(str(websocket.url))
print(f"Server URL: {app_url}")
app.blocks._queue.set_url(app_url)
await websocket.accept()
event = Event(websocket)
rank = app.blocks._queue.push(event)
@ -335,12 +343,7 @@ class App(FastAPI):
dependencies=[Depends(login_check)],
)
async def startup_events():
from gradio.utils import run_coro_in_background
if app.blocks.enable_queue:
gradio.utils.run_coro_in_background(app.blocks._queue.start)
gradio.utils.run_coro_in_background(app.blocks.create_limiter)
app.blocks.startup_events()
return True
return app
@ -382,3 +385,52 @@ def get_types(cls_set: List[Type]):
types.append(line.split("value (")[1].split(")")[0])
docset.append(doc_lines[1].split(":")[-1])
return docset, types
def get_server_url_from_ws_url(ws_url: str):
ws_url = urlparse(ws_url)
scheme = "http" if ws_url.scheme == "ws" else "https"
port = f":{ws_url.port}" if ws_url.port else ""
return f"{scheme}://{ws_url.hostname}{port}{ws_url.path.replace('queue/join', '')}"
set_documentation_group("routes")
@document()
def mount_gradio_app(
app: fastapi.FastAPI,
blocks: gradio.Blocks,
path: str,
gradio_api_url: Optional[str] = None,
) -> fastapi.FastAPI:
"""Mount a gradio.Blocks to an existing FastAPI application.
Parameters:
app: The parent FastAPI application.
blocks: The blocks object we want to mount to the parent app.
path: The path at which the gradio application will be mounted.
gradio_api_url: The full url at which the gradio app will run. This is only needed if deploying to Huggingface spaces of if the websocket endpoints of your deployed app are on a different network location than the gradio app. If deploying to spaces, set gradio_api_url to 'http://localhost:7860/'
Example:
from fastapi import FastAPI
import gradio as gr
app = FastAPI()
@app.get("/")
def read_main():
return {"message": "This is your main app"}
io = gr.Interface(lambda x: "Hello, " + x + "!", "textbox", "textbox")
app = gr.mount_gradio_app(app, io, path="/gradio")
# Then run `uvicorn run:app` from the terminal and navigate to http://localhost:8000/gradio.
"""
gradio_app = App.create_app(blocks)
@app.on_event("startup")
async def start_queue():
if gradio_app.blocks.enable_queue:
if gradio_api_url:
gradio_app.blocks._queue.set_url(gradio_api_url)
gradio_app.blocks.startup_events()
app.mount(path, gradio_app)
return app

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import ast
import asyncio
import copy
import inspect
@ -13,7 +12,6 @@ import pkgutil
import random
import warnings
from contextlib import contextmanager
from copy import deepcopy
from distutils.version import StrictVersion
from enum import Enum
from numbers import Number
@ -40,7 +38,6 @@ from pydantic import BaseModel, Json, parse_obj_as
import gradio
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio import Blocks, Interface
from gradio.blocks import BlockContext
from gradio.components import Component

View File

@ -108,7 +108,8 @@ demo.launch(auth=same_auth)
## Mounting Within Another FastAPI App
In some cases, you might have an existing FastAPI app, and you'd like to add a path for a Gradio demo. You can do this by easily using the `gradio.routes.App.create_app()` function, which creates a FastAPI app (but does not launch it), and then adding it to your existing FastAPI app with `FastAPI.mount()`.
In some cases, you might have an existing FastAPI app, and you'd like to add a path for a Gradio demo.
You can easily do this with `gradio.mount_gradio_app()`.
Here's a complete example:

View File

@ -357,6 +357,7 @@ class TestCallFunction:
)
demo.queue()
assert demo.config["enable_queue"]
output = await demo.call_function(0, [3])
assert output["prediction"] == 0

View File

@ -1,8 +1,12 @@
"""Contains tests for networking.py and app.py"""
import json
import os
import sys
import unittest
from unittest.mock import patch
import pytest
import websockets
from fastapi import FastAPI
from fastapi.testclient import TestClient
@ -214,5 +218,45 @@ class TestAuthenticatedRoutes(unittest.TestCase):
close_all()
@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mocks don't work with async context managers in 3.7",
)
@patch("gradio.routes.get_server_url_from_ws_url", return_value="foo_url")
async def test_queue_join_routes_sets_url_if_none_set(mock_get_url):
io = Interface(lambda x: x, "text", "text").queue()
app, _, _ = io.launch(prevent_thread_lock=True)
io._queue.server_path = None
async with websockets.connect(
f"{io.local_url.replace('http', 'ws')}queue/join"
) as ws:
completed = False
while not completed:
msg = json.loads(await ws.recv())
if msg["msg"] == "send_data":
await ws.send(json.dumps({"data": ["foo"], "fn_index": 0}))
completed = msg["msg"] == "process_completed"
assert io._queue.server_path == "foo_url"
@pytest.mark.parametrize(
"ws_url,answer",
[
("ws://127.0.0.1:7861/queue/join", "http://127.0.0.1:7861/"),
(
"ws://127.0.0.1:7861/gradio/gradio/gradio/queue/join",
"http://127.0.0.1:7861/gradio/gradio/gradio/",
),
(
"wss://huggingface.co.tech/path/queue/join",
"https://huggingface.co.tech/path/",
),
],
)
def test_get_server_url_from_ws_url(ws_url, answer):
assert routes.get_server_url_from_ws_url(ws_url) == answer
if __name__ == "__main__":
unittest.main()

View File

@ -116,16 +116,18 @@ export const fn =
var ws_protocol = ws_endpoint.startsWith("https") ? "wss:" : "ws:";
if (is_space) {
const SPACE_REGEX = /embed\/(.*)\/\+/g;
var ws_path = Array.from(ws_endpoint.matchAll(SPACE_REGEX))[0][1];
var ws_path = Array.from(
ws_endpoint.matchAll(SPACE_REGEX)
)[0][1].concat("/");
var ws_host = "spaces.huggingface.tech/";
} else {
var ws_path = location.pathname === "/" ? "" : location.pathname;
var ws_path = location.pathname === "/" ? "/" : location.pathname;
var ws_host =
BUILD_MODE === "dev" || location.origin === "http://localhost:3000"
? BACKEND_URL.replace("http://", "").slice(0, -1)
: location.host;
}
const WS_ENDPOINT = `${ws_protocol}//${ws_host}${ws_path}/queue/join`;
const WS_ENDPOINT = `${ws_protocol}//${ws_host}${ws_path}queue/join`;
var websocket = new WebSocket(WS_ENDPOINT);
ws_map.set(fn_index, websocket);

View File

@ -73,6 +73,10 @@
<a class="px-4 block thin-link" href="#{{ component['name'].lower() }}">{{ component['name'] }}</a>
{% endfor %}
<a class="thin-link px-4 block" href="#update">Update</a>
<a class="link px-4 my-2 block" href="#routes">Routes
{% for component in docs["routes"] %}
<a class="px-4 block thin-link" href="#{{ component['name'].lower() }}">{{ component['name'] }}</a>
{% endfor %}
</div>
<div class="flex flex-col">
<p class="bg-gradient-to-r from-orange-100 to-orange-50 border border-orange-200 px-4 py-1 rounded-full text-orange-800 mb-1">
@ -202,6 +206,21 @@
{% endwith %}
</div>
</section>
<section id="routes" class="pt-2 flex flex-col gap-10">
<div>
<h2 id="routes-header"
class="text-4xl font-light mb-2 pt-2 text-orange-500">Routes</h2>
<p class="mt-8 text-lg">
Gradio includes some helper functions for exposing and interacting with the FastAPI app
used to run your demo.
</p>
</div>
{% for component in docs["routes"] %}
{% with obj=component, parent="gradio" %}
{% include "docs/obj_doc_template.html" %}
{% endwith %}
{% endfor %}
</section>
</div>
</main>
<script src="/assets/prism.js"></script>