mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Start queue when gradio is a sub application (#2319)
* First stab at it
* Use util methos
* lint
* Test
* Fix formatting
* Try out setting predict endpoint from websocket request
* lint
* Fix bug
* Address comments - remove server and port
* Skip in 3.7
* Fix documentation
* Add default 🤦
* docs tweak
* Add back imports that were deleted by wrong linter version
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
3a4a82634e
commit
9f7dd05b72
@ -5,13 +5,15 @@ CUSTOM_PATH = "/gradio"
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def read_main():
|
||||
return {"message": "This is your main app"}
|
||||
|
||||
|
||||
io = gr.Interface(lambda x: "Hello, " + x + "!", "textbox", "textbox")
|
||||
gradio_app = gr.routes.App.create_app(io)
|
||||
app = gr.mount_gradio_app(app, io, path=CUSTOM_PATH)
|
||||
|
||||
app.mount(CUSTOM_PATH, gradio_app)
|
||||
|
||||
# Run this from the terminal as you would normally start a FastAPI app: `uvicorn run:app` and navigate to http://localhost:8000/gradio in your browser.
|
||||
# Run this from the terminal as you would normally start a FastAPI app: `uvicorn run:app`
|
||||
# and navigate to http://localhost:8000/gradio in your browser.
|
||||
|
@ -57,6 +57,7 @@ from gradio.interface import Interface, TabbedInterface, close_all
|
||||
from gradio.ipython_ext import load_ipython_extension
|
||||
from gradio.layouts import Accordion, Box, Column, Group, Row, Tab, TabItem, Tabs
|
||||
from gradio.mix import Parallel, Series
|
||||
from gradio.routes import mount_gradio_app
|
||||
from gradio.templates import (
|
||||
Files,
|
||||
Highlight,
|
||||
|
@ -912,6 +912,7 @@ class Blocks(BlockContext):
|
||||
update_intervals=status_update_rate if status_update_rate != "auto" else 1,
|
||||
max_size=max_size,
|
||||
)
|
||||
self.config = self.get_config_file()
|
||||
return self
|
||||
|
||||
def launch(
|
||||
@ -1257,3 +1258,9 @@ class Blocks(BlockContext):
|
||||
no_target=True,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
def startup_events(self):
|
||||
"""Events that should be run when the app containing this block starts up."""
|
||||
if self.enable_queue:
|
||||
utils.run_coro_in_background(self._queue.start)
|
||||
utils.run_coro_in_background(self.create_limiter)
|
||||
|
@ -14,10 +14,11 @@ from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
import orjson
|
||||
import pkg_resources
|
||||
import requests
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
||||
@ -30,6 +31,7 @@ from starlette.websockets import WebSocket, WebSocketState
|
||||
|
||||
import gradio
|
||||
from gradio import encryptor, utils
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
from gradio.exceptions import Error
|
||||
from gradio.queue import Estimation, Event
|
||||
|
||||
@ -308,6 +310,12 @@ class App(FastAPI):
|
||||
|
||||
@app.websocket("/queue/join")
|
||||
async def join_queue(websocket: WebSocket):
|
||||
if app.blocks._queue.server_path is None:
|
||||
print(f"WS: {str(websocket.url)}")
|
||||
app_url = get_server_url_from_ws_url(str(websocket.url))
|
||||
print(f"Server URL: {app_url}")
|
||||
app.blocks._queue.set_url(app_url)
|
||||
|
||||
await websocket.accept()
|
||||
event = Event(websocket)
|
||||
rank = app.blocks._queue.push(event)
|
||||
@ -335,12 +343,7 @@ class App(FastAPI):
|
||||
dependencies=[Depends(login_check)],
|
||||
)
|
||||
async def startup_events():
|
||||
from gradio.utils import run_coro_in_background
|
||||
|
||||
if app.blocks.enable_queue:
|
||||
gradio.utils.run_coro_in_background(app.blocks._queue.start)
|
||||
gradio.utils.run_coro_in_background(app.blocks.create_limiter)
|
||||
|
||||
app.blocks.startup_events()
|
||||
return True
|
||||
|
||||
return app
|
||||
@ -382,3 +385,52 @@ def get_types(cls_set: List[Type]):
|
||||
types.append(line.split("value (")[1].split(")")[0])
|
||||
docset.append(doc_lines[1].split(":")[-1])
|
||||
return docset, types
|
||||
|
||||
|
||||
def get_server_url_from_ws_url(ws_url: str):
|
||||
ws_url = urlparse(ws_url)
|
||||
scheme = "http" if ws_url.scheme == "ws" else "https"
|
||||
port = f":{ws_url.port}" if ws_url.port else ""
|
||||
return f"{scheme}://{ws_url.hostname}{port}{ws_url.path.replace('queue/join', '')}"
|
||||
|
||||
|
||||
set_documentation_group("routes")
|
||||
|
||||
|
||||
@document()
|
||||
def mount_gradio_app(
|
||||
app: fastapi.FastAPI,
|
||||
blocks: gradio.Blocks,
|
||||
path: str,
|
||||
gradio_api_url: Optional[str] = None,
|
||||
) -> fastapi.FastAPI:
|
||||
"""Mount a gradio.Blocks to an existing FastAPI application.
|
||||
|
||||
Parameters:
|
||||
app: The parent FastAPI application.
|
||||
blocks: The blocks object we want to mount to the parent app.
|
||||
path: The path at which the gradio application will be mounted.
|
||||
gradio_api_url: The full url at which the gradio app will run. This is only needed if deploying to Huggingface spaces of if the websocket endpoints of your deployed app are on a different network location than the gradio app. If deploying to spaces, set gradio_api_url to 'http://localhost:7860/'
|
||||
Example:
|
||||
from fastapi import FastAPI
|
||||
import gradio as gr
|
||||
app = FastAPI()
|
||||
@app.get("/")
|
||||
def read_main():
|
||||
return {"message": "This is your main app"}
|
||||
io = gr.Interface(lambda x: "Hello, " + x + "!", "textbox", "textbox")
|
||||
app = gr.mount_gradio_app(app, io, path="/gradio")
|
||||
# Then run `uvicorn run:app` from the terminal and navigate to http://localhost:8000/gradio.
|
||||
"""
|
||||
|
||||
gradio_app = App.create_app(blocks)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def start_queue():
|
||||
if gradio_app.blocks.enable_queue:
|
||||
if gradio_api_url:
|
||||
gradio_app.blocks._queue.set_url(gradio_api_url)
|
||||
gradio_app.blocks.startup_events()
|
||||
|
||||
app.mount(path, gradio_app)
|
||||
return app
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
@ -13,7 +12,6 @@ import pkgutil
|
||||
import random
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from distutils.version import StrictVersion
|
||||
from enum import Enum
|
||||
from numbers import Number
|
||||
@ -40,7 +38,6 @@ from pydantic import BaseModel, Json, parse_obj_as
|
||||
import gradio
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from gradio import Blocks, Interface
|
||||
from gradio.blocks import BlockContext
|
||||
from gradio.components import Component
|
||||
|
||||
|
@ -108,7 +108,8 @@ demo.launch(auth=same_auth)
|
||||
|
||||
## Mounting Within Another FastAPI App
|
||||
|
||||
In some cases, you might have an existing FastAPI app, and you'd like to add a path for a Gradio demo. You can do this by easily using the `gradio.routes.App.create_app()` function, which creates a FastAPI app (but does not launch it), and then adding it to your existing FastAPI app with `FastAPI.mount()`.
|
||||
In some cases, you might have an existing FastAPI app, and you'd like to add a path for a Gradio demo.
|
||||
You can easily do this with `gradio.mount_gradio_app()`.
|
||||
|
||||
Here's a complete example:
|
||||
|
@ -357,6 +357,7 @@ class TestCallFunction:
|
||||
)
|
||||
|
||||
demo.queue()
|
||||
assert demo.config["enable_queue"]
|
||||
|
||||
output = await demo.call_function(0, [3])
|
||||
assert output["prediction"] == 0
|
||||
|
@ -1,8 +1,12 @@
|
||||
"""Contains tests for networking.py and app.py"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
@ -214,5 +218,45 @@ class TestAuthenticatedRoutes(unittest.TestCase):
|
||||
close_all()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8),
|
||||
reason="Mocks don't work with async context managers in 3.7",
|
||||
)
|
||||
@patch("gradio.routes.get_server_url_from_ws_url", return_value="foo_url")
|
||||
async def test_queue_join_routes_sets_url_if_none_set(mock_get_url):
|
||||
io = Interface(lambda x: x, "text", "text").queue()
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
io._queue.server_path = None
|
||||
async with websockets.connect(
|
||||
f"{io.local_url.replace('http', 'ws')}queue/join"
|
||||
) as ws:
|
||||
completed = False
|
||||
while not completed:
|
||||
msg = json.loads(await ws.recv())
|
||||
if msg["msg"] == "send_data":
|
||||
await ws.send(json.dumps({"data": ["foo"], "fn_index": 0}))
|
||||
completed = msg["msg"] == "process_completed"
|
||||
assert io._queue.server_path == "foo_url"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ws_url,answer",
|
||||
[
|
||||
("ws://127.0.0.1:7861/queue/join", "http://127.0.0.1:7861/"),
|
||||
(
|
||||
"ws://127.0.0.1:7861/gradio/gradio/gradio/queue/join",
|
||||
"http://127.0.0.1:7861/gradio/gradio/gradio/",
|
||||
),
|
||||
(
|
||||
"wss://huggingface.co.tech/path/queue/join",
|
||||
"https://huggingface.co.tech/path/",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_server_url_from_ws_url(ws_url, answer):
|
||||
assert routes.get_server_url_from_ws_url(ws_url) == answer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -116,16 +116,18 @@ export const fn =
|
||||
var ws_protocol = ws_endpoint.startsWith("https") ? "wss:" : "ws:";
|
||||
if (is_space) {
|
||||
const SPACE_REGEX = /embed\/(.*)\/\+/g;
|
||||
var ws_path = Array.from(ws_endpoint.matchAll(SPACE_REGEX))[0][1];
|
||||
var ws_path = Array.from(
|
||||
ws_endpoint.matchAll(SPACE_REGEX)
|
||||
)[0][1].concat("/");
|
||||
var ws_host = "spaces.huggingface.tech/";
|
||||
} else {
|
||||
var ws_path = location.pathname === "/" ? "" : location.pathname;
|
||||
var ws_path = location.pathname === "/" ? "/" : location.pathname;
|
||||
var ws_host =
|
||||
BUILD_MODE === "dev" || location.origin === "http://localhost:3000"
|
||||
? BACKEND_URL.replace("http://", "").slice(0, -1)
|
||||
: location.host;
|
||||
}
|
||||
const WS_ENDPOINT = `${ws_protocol}//${ws_host}${ws_path}/queue/join`;
|
||||
const WS_ENDPOINT = `${ws_protocol}//${ws_host}${ws_path}queue/join`;
|
||||
|
||||
var websocket = new WebSocket(WS_ENDPOINT);
|
||||
ws_map.set(fn_index, websocket);
|
||||
|
@ -73,6 +73,10 @@
|
||||
<a class="px-4 block thin-link" href="#{{ component['name'].lower() }}">{{ component['name'] }}</a>
|
||||
{% endfor %}
|
||||
<a class="thin-link px-4 block" href="#update">Update</a>
|
||||
<a class="link px-4 my-2 block" href="#routes">Routes
|
||||
{% for component in docs["routes"] %}
|
||||
<a class="px-4 block thin-link" href="#{{ component['name'].lower() }}">{{ component['name'] }}</a>
|
||||
{% endfor %}
|
||||
</div>
|
||||
<div class="flex flex-col">
|
||||
<p class="bg-gradient-to-r from-orange-100 to-orange-50 border border-orange-200 px-4 py-1 rounded-full text-orange-800 mb-1">
|
||||
@ -202,6 +206,21 @@
|
||||
{% endwith %}
|
||||
</div>
|
||||
</section>
|
||||
<section id="routes" class="pt-2 flex flex-col gap-10">
|
||||
<div>
|
||||
<h2 id="routes-header"
|
||||
class="text-4xl font-light mb-2 pt-2 text-orange-500">Routes</h2>
|
||||
<p class="mt-8 text-lg">
|
||||
Gradio includes some helper functions for exposing and interacting with the FastAPI app
|
||||
used to run your demo.
|
||||
</p>
|
||||
</div>
|
||||
{% for component in docs["routes"] %}
|
||||
{% with obj=component, parent="gradio" %}
|
||||
{% include "docs/obj_doc_template.html" %}
|
||||
{% endwith %}
|
||||
{% endfor %}
|
||||
</section>
|
||||
</div>
|
||||
</main>
|
||||
<script src="/assets/prism.js"></script>
|
||||
|
Loading…
x
Reference in New Issue
Block a user