diff --git a/.changeset/fluffy-mice-joke.md b/.changeset/fluffy-mice-joke.md new file mode 100644 index 0000000000..ea27be8a38 --- /dev/null +++ b/.changeset/fluffy-mice-joke.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:Fix the `root_url` logic for streaming files diff --git a/.gitignore b/.gitignore index c79420e85a..51efba9dde 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,7 @@ demo/all_demos/demos/* demo/all_demos/requirements.txt demo/*/config.json demo/annotatedimage_component/*.png +demo/fake_diffusion_with_gif/*.gif # Etc .idea/* diff --git a/demo/fake_diffusion_with_gif/image.gif b/demo/fake_diffusion_with_gif/image.gif deleted file mode 100644 index 04a1d2f383..0000000000 Binary files a/demo/fake_diffusion_with_gif/image.gif and /dev/null differ diff --git a/demo/fake_diffusion_with_gif/run.ipynb b/demo/fake_diffusion_with_gif/run.ipynb index 13922f848c..0df082ac51 100644 --- a/demo/fake_diffusion_with_gif/run.ipynb +++ b/demo/fake_diffusion_with_gif/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: fake_diffusion_with_gif"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/fake_diffusion_with_gif/image.gif"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import numpy as np\n", "import time\n", "import os\n", "from PIL import Image\n", "import requests\n", "from io import BytesIO\n", "\n", "\n", "def create_gif(images):\n", " pil_images = []\n", " for image in images:\n", " if isinstance(image, str):\n", " response = requests.get(image)\n", " image = Image.open(BytesIO(response.content))\n", " else:\n", " image = Image.fromarray((image * 255).astype(np.uint8))\n", " pil_images.append(image)\n", " fp_out = os.path.join(os.path.abspath(''), \"image.gif\")\n", " img = pil_images.pop(0)\n", " img.save(fp=fp_out, format='GIF', append_images=pil_images,\n", " save_all=True, duration=400, loop=0)\n", " return fp_out\n", "\n", "\n", "def fake_diffusion(steps):\n", " rng = np.random.default_rng()\n", " images = []\n", " for _ in range(steps):\n", " time.sleep(1)\n", " image = rng.random((600, 600, 3))\n", " images.append(image)\n", " yield image, gr.Image(visible=False)\n", "\n", " time.sleep(1)\n", " image = \"https://gradio-builds.s3.amazonaws.com/diffusion_image/cute_dog.jpg\"\n", " images.append(image)\n", " gif_path = create_gif(images)\n", "\n", " yield image, gr.Image(value=gif_path, visible=True)\n", "\n", "\n", "demo = gr.Interface(fake_diffusion,\n", " inputs=gr.Slider(1, 10, 3, step=1),\n", " outputs=[\"image\", gr.Image(label=\"All Images\", visible=False)])\n", "demo.queue()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: fake_diffusion_with_gif"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import numpy as np\n", "import time\n", "import os\n", "from PIL import Image\n", "import requests\n", "from io import BytesIO\n", "\n", "\n", "def create_gif(images):\n", " pil_images = []\n", " for image in images:\n", " if isinstance(image, str):\n", " response = requests.get(image)\n", " image = Image.open(BytesIO(response.content))\n", " else:\n", " image = Image.fromarray((image * 255).astype(np.uint8))\n", " pil_images.append(image)\n", " fp_out = os.path.join(os.path.abspath(''), \"image.gif\")\n", " img = pil_images.pop(0)\n", " img.save(fp=fp_out, format='GIF', append_images=pil_images,\n", " save_all=True, duration=400, loop=0)\n", " return fp_out\n", "\n", "\n", "def fake_diffusion(steps):\n", " rng = np.random.default_rng()\n", " images = []\n", " for _ in range(steps):\n", " time.sleep(1)\n", " image = rng.random((600, 600, 3))\n", " images.append(image)\n", " yield image, gr.Image(visible=False)\n", "\n", " time.sleep(1)\n", " image = \"https://gradio-builds.s3.amazonaws.com/diffusion_image/cute_dog.jpg\"\n", " images.append(image)\n", " gif_path = create_gif(images)\n", "\n", " yield image, gr.Image(value=gif_path, visible=True)\n", "\n", "\n", "demo = gr.Interface(fake_diffusion,\n", " inputs=gr.Slider(1, 10, 3, step=1),\n", " outputs=[\"image\", gr.Image(label=\"All Images\", visible=False)])\n", "demo.queue()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/gradio/blocks.py b/gradio/blocks.py index 017b682505..514494f488 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1579,6 +1579,7 @@ Received outputs: in_event_listener: bool = True, simple_format: bool = False, explicit_call: bool = False, + root_path: str | None = None, ) -> dict[str, Any]: """ Processes API calls from the frontend. First preprocesses the data, @@ -1593,6 +1594,7 @@ Received outputs: event_data: data associated with the event trigger itself in_event_listener: whether this API call is being made in response to an event listener explicit_call: whether this call is being made directly by calling the Blocks function, instead of through an event listener or API route + root_path: if provided, the root path of the server. All file URLs will be prefixed with this path. Returns: None """ block_fn = self.fns[fn_index] @@ -1641,6 +1643,8 @@ Received outputs: state, limiter=self.limiter, ) + if root_path is not None: + data = processing_utils.add_root_url(data, root_path, None) data = list(zip(*data)) is_generating, iterator = None, None else: @@ -1673,6 +1677,8 @@ Received outputs: state, limiter=self.limiter, ) + if root_path is not None: + data = processing_utils.add_root_url(data, root_path, None) is_generating, iterator = result["is_generating"], result["iterator"] if is_generating or was_generating: run = id(old_iterator) if was_generating else id(iterator) diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 9d3b474039..c3286daf6f 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -297,7 +297,7 @@ def move_files_to_cache( return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj) -def add_root_url(data: dict, root_url: str, previous_root_url: str | None) -> dict: +def add_root_url(data: dict | list, root_url: str, previous_root_url: str | None): def _add_root_url(file_dict: dict): if previous_root_url and file_dict["url"].startswith(previous_root_url): file_dict["url"] = file_dict["url"][len(previous_root_url) :] diff --git a/gradio/queueing.py b/gradio/queueing.py index bb5292737b..bf6c87d6b1 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -493,13 +493,17 @@ class Queue: username=username, request=None, ) - + assert body.request is not None # noqa: S101 + root_path = route_utils.get_root_url( + request=body.request, route_path="/queue/join", root_path=app.root_path + ) try: output = await route_utils.call_process_api( app=app, body=body, gr_request=gr_request, fn_index_inferred=fn_index_inferred, + root_path=root_path, ) except Exception as error: show_error = app.get_blocks().show_error or isinstance(error, Error) diff --git a/gradio/route_utils.py b/gradio/route_utils.py index 9460d09d48..7f5b37d053 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -232,6 +232,7 @@ async def call_process_api( body: PredictBody, gr_request: Union[Request, list[Request]], fn_index_inferred: int, + root_path: str, ): session_state, iterator = restore_session_state(app=app, body=body) @@ -259,6 +260,7 @@ async def call_process_api( event_data=event_data, in_event_listener=True, simple_format=body.simple_format, + root_path=root_path, ) iterator = output.pop("iterator", None) if event_id is not None: diff --git a/gradio/routes.py b/gradio/routes.py index 8e7d0bfae1..c7d4d41921 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -65,7 +65,6 @@ from gradio.data_classes import ( ) from gradio.exceptions import Error from gradio.oauth import attach_oauth -from gradio.processing_utils import add_root_url from gradio.route_utils import ( # noqa: F401 CustomCORSMiddleware, FileUploadProgress, @@ -609,13 +608,16 @@ class App(FastAPI): username=username, request=request, ) - + root_path = route_utils.get_root_url( + request=request, route_path=f"/api/{api_name}", root_path=app.root_path + ) try: output = await route_utils.call_process_api( app=app, body=body, gr_request=gr_request, fn_index_inferred=fn_index_inferred, + root_path=root_path, ) except BaseException as error: show_error = app.get_blocks().show_error or isinstance(error, Error) @@ -624,10 +626,6 @@ class App(FastAPI): content={"error": str(error) if show_error else None}, status_code=500, ) - root_path = route_utils.get_root_url( - request=request, route_path=f"/api/{api_name}", root_path=app.root_path - ) - output = add_root_url(output, root_path, None) return output @app.post("/call/{api_name}", dependencies=[Depends(login_check)]) @@ -674,7 +672,9 @@ class App(FastAPI): detail="Queue is stopped.", ) - success, event_id = await blocks._queue.push(body, request, username) + success, event_id = await blocks._queue.push( + body=body, request=request, username=username + ) if not success: status_code = ( status.HTTP_503_SERVICE_UNAVAILABLE @@ -724,9 +724,6 @@ class App(FastAPI): process_msg: Callable[[EventMessage], str | None], ): blocks = app.get_blocks() - root_path = route_utils.get_root_url( - request=request, route_path="/queue/data", root_path=app.root_path - ) async def sse_stream(request: fastapi.Request): try: @@ -769,11 +766,6 @@ class App(FastAPI): success=False, ) if message: - if isinstance( - message, - (ProcessGeneratingMessage, ProcessCompletedMessage), - ): - add_root_url(message.output, root_path, None) response = process_msg(message) if response is not None: yield response diff --git a/test/test_blocks.py b/test/test_blocks.py index 39c228abf3..e48a035ad0 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -408,7 +408,7 @@ class TestTempFile: assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 3 def test_no_empty_image_files(self, gradio_temp_dir, connect): - file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files") + file_dir = pathlib.Path(__file__).parent / "test_files" image = str(file_dir / "bus.png") demo = gr.Interface( @@ -1152,6 +1152,23 @@ class TestUpdate: } +@pytest.mark.asyncio +async def test_root_path(): + image_file = pathlib.Path(__file__).parent / "test_files" / "bus.png" + demo = gr.Interface(lambda x: image_file, "textbox", "image") + result = await demo.process_api(fn_index=0, inputs=[""], request=None, state=None) + result_url = result["data"][0]["url"] + assert result_url.startswith("/file=") + assert result_url.endswith("bus.png") + + result = await demo.process_api( + fn_index=0, inputs=[""], request=None, state=None, root_path="abidlabs.hf.space" + ) + result_url = result["data"][0]["url"] + assert result_url.startswith("abidlabs.hf.space/file=") + assert result_url.endswith("bus.png") + + class TestRender: def test_duplicate_error(self): with pytest.raises(DuplicateBlockError): diff --git a/test/test_utils.py b/test/test_utils.py index 383a576ea6..3720838794 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -19,6 +19,7 @@ from gradio.utils import ( check_function_inputs_match, colab_check, delete_none, + diff, download_if_url, get_continuous_fn, get_extension_from_file_path_or_url, @@ -439,3 +440,16 @@ def test_is_in_or_equal(): ) def test_get_extension_from_file_path_or_url(path_or_url, extension): assert get_extension_from_file_path_or_url(path_or_url) == extension + + +@pytest.mark.parametrize( + "old, new, expected_diff", + [ + ({"a": 1, "b": 2}, {"a": 1, "b": 2}, []), + ({}, {"a": 1, "b": 2}, [("add", ["a"], 1), ("add", ["b"], 2)]), + (["a", "b"], {"a": 1, "b": 2}, [("replace", [], {"a": 1, "b": 2})]), + ("abc", "abcdef", [("append", [], "def")]), + ], +) +def test_diff(old, new, expected_diff): + assert diff(old, new) == expected_diff