Fix the root_url logic for streaming files (#7614)

* fix

* add changeset

* ignore

* add changeset

* changes

* changes

* linting

* demo

* changes

* blocks

* add changeset

* diff

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-03-05 15:21:45 -08:00 committed by GitHub
parent e340894b1c
commit 355ed666d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 60 additions and 19 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Fix the `root_url` logic for streaming files

1
.gitignore vendored
View File

@ -43,6 +43,7 @@ demo/all_demos/demos/*
demo/all_demos/requirements.txt
demo/*/config.json
demo/annotatedimage_component/*.png
demo/fake_diffusion_with_gif/*.gif
# Etc
.idea/*

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 MiB

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: fake_diffusion_with_gif"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/fake_diffusion_with_gif/image.gif"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import numpy as np\n", "import time\n", "import os\n", "from PIL import Image\n", "import requests\n", "from io import BytesIO\n", "\n", "\n", "def create_gif(images):\n", " pil_images = []\n", " for image in images:\n", " if isinstance(image, str):\n", " response = requests.get(image)\n", " image = Image.open(BytesIO(response.content))\n", " else:\n", " image = Image.fromarray((image * 255).astype(np.uint8))\n", " pil_images.append(image)\n", " fp_out = os.path.join(os.path.abspath(''), \"image.gif\")\n", " img = pil_images.pop(0)\n", " img.save(fp=fp_out, format='GIF', append_images=pil_images,\n", " save_all=True, duration=400, loop=0)\n", " return fp_out\n", "\n", "\n", "def fake_diffusion(steps):\n", " rng = np.random.default_rng()\n", " images = []\n", " for _ in range(steps):\n", " time.sleep(1)\n", " image = rng.random((600, 600, 3))\n", " images.append(image)\n", " yield image, gr.Image(visible=False)\n", "\n", " time.sleep(1)\n", " image = \"https://gradio-builds.s3.amazonaws.com/diffusion_image/cute_dog.jpg\"\n", " images.append(image)\n", " gif_path = create_gif(images)\n", "\n", " yield image, gr.Image(value=gif_path, visible=True)\n", "\n", "\n", "demo = gr.Interface(fake_diffusion,\n", " inputs=gr.Slider(1, 10, 3, step=1),\n", " outputs=[\"image\", gr.Image(label=\"All Images\", visible=False)])\n", "demo.queue()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: fake_diffusion_with_gif"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import numpy as np\n", "import time\n", "import os\n", "from PIL import Image\n", "import requests\n", "from io import BytesIO\n", "\n", "\n", "def create_gif(images):\n", " pil_images = []\n", " for image in images:\n", " if isinstance(image, str):\n", " response = requests.get(image)\n", " image = Image.open(BytesIO(response.content))\n", " else:\n", " image = Image.fromarray((image * 255).astype(np.uint8))\n", " pil_images.append(image)\n", " fp_out = os.path.join(os.path.abspath(''), \"image.gif\")\n", " img = pil_images.pop(0)\n", " img.save(fp=fp_out, format='GIF', append_images=pil_images,\n", " save_all=True, duration=400, loop=0)\n", " return fp_out\n", "\n", "\n", "def fake_diffusion(steps):\n", " rng = np.random.default_rng()\n", " images = []\n", " for _ in range(steps):\n", " time.sleep(1)\n", " image = rng.random((600, 600, 3))\n", " images.append(image)\n", " yield image, gr.Image(visible=False)\n", "\n", " time.sleep(1)\n", " image = \"https://gradio-builds.s3.amazonaws.com/diffusion_image/cute_dog.jpg\"\n", " images.append(image)\n", " gif_path = create_gif(images)\n", "\n", " yield image, gr.Image(value=gif_path, visible=True)\n", "\n", "\n", "demo = gr.Interface(fake_diffusion,\n", " inputs=gr.Slider(1, 10, 3, step=1),\n", " outputs=[\"image\", gr.Image(label=\"All Images\", visible=False)])\n", "demo.queue()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -1579,6 +1579,7 @@ Received outputs:
in_event_listener: bool = True,
simple_format: bool = False,
explicit_call: bool = False,
root_path: str | None = None,
) -> dict[str, Any]:
"""
Processes API calls from the frontend. First preprocesses the data,
@ -1593,6 +1594,7 @@ Received outputs:
event_data: data associated with the event trigger itself
in_event_listener: whether this API call is being made in response to an event listener
explicit_call: whether this call is being made directly by calling the Blocks function, instead of through an event listener or API route
root_path: if provided, the root path of the server. All file URLs will be prefixed with this path.
Returns: None
"""
block_fn = self.fns[fn_index]
@ -1641,6 +1643,8 @@ Received outputs:
state,
limiter=self.limiter,
)
if root_path is not None:
data = processing_utils.add_root_url(data, root_path, None)
data = list(zip(*data))
is_generating, iterator = None, None
else:
@ -1673,6 +1677,8 @@ Received outputs:
state,
limiter=self.limiter,
)
if root_path is not None:
data = processing_utils.add_root_url(data, root_path, None)
is_generating, iterator = result["is_generating"], result["iterator"]
if is_generating or was_generating:
run = id(old_iterator) if was_generating else id(iterator)

View File

@ -297,7 +297,7 @@ def move_files_to_cache(
return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj)
def add_root_url(data: dict, root_url: str, previous_root_url: str | None) -> dict:
def add_root_url(data: dict | list, root_url: str, previous_root_url: str | None):
def _add_root_url(file_dict: dict):
if previous_root_url and file_dict["url"].startswith(previous_root_url):
file_dict["url"] = file_dict["url"][len(previous_root_url) :]

View File

@ -493,13 +493,17 @@ class Queue:
username=username,
request=None,
)
assert body.request is not None # noqa: S101
root_path = route_utils.get_root_url(
request=body.request, route_path="/queue/join", root_path=app.root_path
)
try:
output = await route_utils.call_process_api(
app=app,
body=body,
gr_request=gr_request,
fn_index_inferred=fn_index_inferred,
root_path=root_path,
)
except Exception as error:
show_error = app.get_blocks().show_error or isinstance(error, Error)

View File

@ -232,6 +232,7 @@ async def call_process_api(
body: PredictBody,
gr_request: Union[Request, list[Request]],
fn_index_inferred: int,
root_path: str,
):
session_state, iterator = restore_session_state(app=app, body=body)
@ -259,6 +260,7 @@ async def call_process_api(
event_data=event_data,
in_event_listener=True,
simple_format=body.simple_format,
root_path=root_path,
)
iterator = output.pop("iterator", None)
if event_id is not None:

View File

@ -65,7 +65,6 @@ from gradio.data_classes import (
)
from gradio.exceptions import Error
from gradio.oauth import attach_oauth
from gradio.processing_utils import add_root_url
from gradio.route_utils import ( # noqa: F401
CustomCORSMiddleware,
FileUploadProgress,
@ -609,13 +608,16 @@ class App(FastAPI):
username=username,
request=request,
)
root_path = route_utils.get_root_url(
request=request, route_path=f"/api/{api_name}", root_path=app.root_path
)
try:
output = await route_utils.call_process_api(
app=app,
body=body,
gr_request=gr_request,
fn_index_inferred=fn_index_inferred,
root_path=root_path,
)
except BaseException as error:
show_error = app.get_blocks().show_error or isinstance(error, Error)
@ -624,10 +626,6 @@ class App(FastAPI):
content={"error": str(error) if show_error else None},
status_code=500,
)
root_path = route_utils.get_root_url(
request=request, route_path=f"/api/{api_name}", root_path=app.root_path
)
output = add_root_url(output, root_path, None)
return output
@app.post("/call/{api_name}", dependencies=[Depends(login_check)])
@ -674,7 +672,9 @@ class App(FastAPI):
detail="Queue is stopped.",
)
success, event_id = await blocks._queue.push(body, request, username)
success, event_id = await blocks._queue.push(
body=body, request=request, username=username
)
if not success:
status_code = (
status.HTTP_503_SERVICE_UNAVAILABLE
@ -724,9 +724,6 @@ class App(FastAPI):
process_msg: Callable[[EventMessage], str | None],
):
blocks = app.get_blocks()
root_path = route_utils.get_root_url(
request=request, route_path="/queue/data", root_path=app.root_path
)
async def sse_stream(request: fastapi.Request):
try:
@ -769,11 +766,6 @@ class App(FastAPI):
success=False,
)
if message:
if isinstance(
message,
(ProcessGeneratingMessage, ProcessCompletedMessage),
):
add_root_url(message.output, root_path, None)
response = process_msg(message)
if response is not None:
yield response

View File

@ -408,7 +408,7 @@ class TestTempFile:
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 3
def test_no_empty_image_files(self, gradio_temp_dir, connect):
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
file_dir = pathlib.Path(__file__).parent / "test_files"
image = str(file_dir / "bus.png")
demo = gr.Interface(
@ -1152,6 +1152,23 @@ class TestUpdate:
}
@pytest.mark.asyncio
async def test_root_path():
image_file = pathlib.Path(__file__).parent / "test_files" / "bus.png"
demo = gr.Interface(lambda x: image_file, "textbox", "image")
result = await demo.process_api(fn_index=0, inputs=[""], request=None, state=None)
result_url = result["data"][0]["url"]
assert result_url.startswith("/file=")
assert result_url.endswith("bus.png")
result = await demo.process_api(
fn_index=0, inputs=[""], request=None, state=None, root_path="abidlabs.hf.space"
)
result_url = result["data"][0]["url"]
assert result_url.startswith("abidlabs.hf.space/file=")
assert result_url.endswith("bus.png")
class TestRender:
def test_duplicate_error(self):
with pytest.raises(DuplicateBlockError):

View File

@ -19,6 +19,7 @@ from gradio.utils import (
check_function_inputs_match,
colab_check,
delete_none,
diff,
download_if_url,
get_continuous_fn,
get_extension_from_file_path_or_url,
@ -439,3 +440,16 @@ def test_is_in_or_equal():
)
def test_get_extension_from_file_path_or_url(path_or_url, extension):
assert get_extension_from_file_path_or_url(path_or_url) == extension
@pytest.mark.parametrize(
"old, new, expected_diff",
[
({"a": 1, "b": 2}, {"a": 1, "b": 2}, []),
({}, {"a": 1, "b": 2}, [("add", ["a"], 1), ("add", ["b"], 2)]),
(["a", "b"], {"a": 1, "b": 2}, [("replace", [], {"a": 1, "b": 2})]),
("abc", "abcdef", [("append", [], "def")]),
],
)
def test_diff(old, new, expected_diff):
assert diff(old, new) == expected_diff