Processes avatar_images for gr.Chatbot and icon for gr.Button correctly, so that respective files are moved to cache (#6379)

* format

* add changeset

* add changeset

* add changeset

* whoops fix

* notebook

* tests

* refactor

* refactor

* format

* added test

* fix test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2023-11-11 23:13:27 -08:00 committed by GitHub
parent d31d8c6ad8
commit de998b2812
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 96 additions and 47 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
fix:Processes `avatar_images` for `gr.Chatbot` and `icon` for `gr.Button` correctly, so that respective files are moved to cache

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], 
[chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch(allowed_paths=[\"avatar.png\"])\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], 
[chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -51,4 +51,4 @@ with gr.Blocks() as demo:
demo.queue()
if __name__ == "__main__":
demo.launch(allowed_paths=["avatar.png"])
demo.launch()

View File

@ -8,6 +8,7 @@ import os
import random
import secrets
import sys
import tempfile
import threading
import time
import warnings
@ -117,6 +118,13 @@ class Block:
self.is_rendered: bool = False
self._constructor_args: dict
self.state_session_capacity = 10000
self.temp_files: set[str] = set()
self.GRADIO_CACHE = str(
Path(
os.environ.get("GRADIO_TEMP_DIR")
or str(Path(tempfile.gettempdir()) / "gradio")
).resolve()
)
if render:
self.render()
@ -225,6 +233,36 @@ class Block:
kwargs[parameter.name] = props[parameter.name]
return kwargs
def move_resource_to_block_cache(
    self, url_or_file_path: str | Path | None
) -> str | None:
    """Moves a file or downloads a file from a url to a block's cache directory,
    adds it to the block's temp_files, and returns the path to the file in cache.
    This ensures that the file is accessible to the Block and can be served to users.

    Parameters:
        url_or_file_path: an HTTP(S) URL, a local file path (str or Path), or None.
    Returns:
        The absolute path of the cached copy of the resource, or None if the
        input was None.
    """
    if url_or_file_path is None:
        return None
    if isinstance(url_or_file_path, Path):
        url_or_file_path = str(url_or_file_path)
    if client_utils.is_http_url_like(url_or_file_path):
        # Remote resource: download it into this block's cache directory.
        temp_file_path = processing_utils.save_url_to_cache(
            url_or_file_path, cache_dir=self.GRADIO_CACHE
        )
        self.temp_files.add(temp_file_path)
    else:
        url_or_file_path = str(utils.abspath(url_or_file_path))
        if not utils.is_in_or_equal(url_or_file_path, self.GRADIO_CACHE):
            temp_file_path = processing_utils.save_file_to_cache(
                url_or_file_path, cache_dir=self.GRADIO_CACHE
            )
        else:
            # Already inside the cache directory; no copy needed.
            temp_file_path = url_or_file_path
        self.temp_files.add(temp_file_path)
    return temp_file_path
class BlockContext(Block):
def __init__(

View File

@ -7,9 +7,7 @@ from __future__ import annotations
import abc
import hashlib
import json
import os
import sys
import tempfile
import warnings
from abc import ABC, abstractmethod
from enum import Enum
@ -166,13 +164,6 @@ class Component(ComponentBase, Block):
self._selectable = False
if not hasattr(self, "data_model"):
self.data_model: type[GradioDataModel] | None = None
self.temp_files: set[str] = set()
self.GRADIO_CACHE = str(
Path(
os.environ.get("GRADIO_TEMP_DIR")
or str(Path(tempfile.gettempdir()) / "gradio")
).resolve()
)
Block.__init__(
self,

View File

@ -68,9 +68,9 @@ class Button(Component):
scale=scale,
min_width=min_width,
)
self.icon = self.move_resource_to_block_cache(icon)
self.variant = variant
self.size = size
self.icon = icon
self.link = link
@property

View File

@ -9,7 +9,7 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
from gradio_client import utils as client_utils
from gradio_client.documentation import document, set_documentation_group
from gradio import utils
from gradio import processing_utils, utils
from gradio.components.base import Component
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.events import Events
@ -101,7 +101,6 @@ class Chatbot(Component):
if latex_delimiters is None:
latex_delimiters = [{"left": "$$", "right": "$$", "display": True}]
self.latex_delimiters = latex_delimiters
self.avatar_images = avatar_images or (None, None)
self.show_share_button = (
(utils.get_space() is not None)
if show_share_button is None
@ -113,7 +112,6 @@ class Chatbot(Component):
self.bubble_full_width = bubble_full_width
self.line_breaks = line_breaks
self.layout = layout
super().__init__(
label=label,
every=every,
@ -127,6 +125,14 @@ class Chatbot(Component):
render=render,
value=value,
)
self.avatar_images: list[str | None] = [None, None]
if avatar_images is None:
pass
else:
self.avatar_images = [
processing_utils.move_resource_to_block_cache(avatar_images[0], self),
processing_utils.move_resource_to_block_cache(avatar_images[1], self),
]
def _preprocess_chat_messages(
self, chat_message: str | FileMessage | None

View File

@ -11,7 +11,7 @@ from gradio_client.documentation import document, set_documentation_group
from PIL import Image as _Image # using _ to minimize namespace pollution
import gradio.image_utils as image_utils
from gradio import processing_utils, utils
from gradio import utils
from gradio.components.base import Component, StreamingInput
from gradio.data_classes import FileData
from gradio.events import Events
@ -171,7 +171,7 @@ class Image(StreamingInput, Component):
def as_example(self, input_data: str | Path | None) -> str | None:
if input_data is None:
return None
return processing_utils.move_resource_to_block_cache(input_data, self)
return self.move_resource_to_block_cache(input_data)
def example_inputs(self) -> Any:
return "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"

View File

@ -339,4 +339,4 @@ class Video(Component):
def as_example(self, input_data: str | Path | None) -> str | None:
if input_data is None:
return None
return processing_utils.move_resource_to_block_cache(input_data, self)
return self.move_resource_to_block_cache(input_data)

View File

@ -20,7 +20,7 @@ from PIL import Image, ImageOps, PngImagePlugin
from gradio import wasm_utils
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.utils import abspath, is_in_or_equal
from gradio.utils import abspath
with warnings.catch_warnings():
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
@ -236,31 +236,13 @@ def save_base64_to_cache(
return full_temp_file_path
def move_resource_to_block_cache(url_or_file_path: str | Path, block: Component) -> str:
"""Moves a file or downloads a file from a url to a block's cache directory, adds
it to the block's temp_files, and returns the path to the file in cache. This
ensures that the file is accessible to the Block and can be served to users.
def move_resource_to_block_cache(
url_or_file_path: str | Path | None, block: Component
) -> str | None:
"""This method has been replaced by Block.move_resource_to_block_cache(), but is
left here for backwards compatibility for any custom components created in Gradio 4.2.0 or earlier.
"""
if isinstance(url_or_file_path, Path):
url_or_file_path = str(url_or_file_path)
if client_utils.is_http_url_like(url_or_file_path):
temp_file_path = save_url_to_cache(
url_or_file_path, cache_dir=block.GRADIO_CACHE
)
block.temp_files.add(temp_file_path)
else:
url_or_file_path = str(abspath(url_or_file_path))
if not is_in_or_equal(url_or_file_path, block.GRADIO_CACHE):
temp_file_path = save_file_to_cache(
url_or_file_path, cache_dir=block.GRADIO_CACHE
)
else:
temp_file_path = url_or_file_path
block.temp_files.add(temp_file_path)
return temp_file_path
return block.move_resource_to_block_cache(url_or_file_path)
def move_files_to_cache(data: Any, block: Component, postprocess: bool = False):
@ -284,6 +266,7 @@ def move_files_to_cache(data: Any, block: Component, postprocess: bool = False):
temp_file_path = payload.url
else:
temp_file_path = move_resource_to_block_cache(payload.path, block)
assert temp_file_path is not None
payload.path = temp_file_path
return payload.model_dump()

View File

@ -6,6 +6,7 @@ black
boto3
coverage
fastapi>=0.101.0
gradio_pdf==0.0.3
httpx
huggingface_hub
pydantic

View File

@ -45,6 +45,8 @@ filelock==3.7.1
# via
# huggingface-hub
# transformers
gradio_pdf==0.0.3
# via -r requirements.in
h11==0.12.0
# via httpcore
httpcore==0.15.0

View File

@ -790,7 +790,7 @@ class TestAudio:
def test_default_value_postprocess(self):
x_wav = deepcopy(media_data.BASE64_AUDIO)
audio = gr.Audio(value=x_wav["path"])
assert processing_utils.is_in_or_equal(audio.value["path"], audio.GRADIO_CACHE)
assert utils.is_in_or_equal(audio.value["path"], audio.GRADIO_CACHE)
def test_in_interface(self):
def reverse_audio(audio):
@ -1874,7 +1874,7 @@ class TestChatbot:
"likeable": False,
"rtl": False,
"show_copy_button": False,
"avatar_images": (None, None),
"avatar_images": [None, None],
"sanitize_html": True,
"render_markdown": True,
"bubble_full_width": True,
@ -1882,6 +1882,12 @@ class TestChatbot:
"layout": None,
}
def test_avatar_images_are_moved_to_cache(self):
    """A local avatar image passed to gr.Chatbot is copied into the Gradio cache."""
    bot = gr.Chatbot(avatar_images=("test/test_files/bus.png", None))
    user_avatar, assistant_avatar = bot.avatar_images
    assert user_avatar
    assert utils.is_in_or_equal(user_avatar, bot.GRADIO_CACHE)
    assert assistant_avatar is None
class TestJSON:
def test_component_functions(self):
@ -2663,7 +2669,7 @@ def test_component_class_ids():
def test_constructor_args():
assert gr.Textbox(max_lines=314).constructor_args == {"max_lines": 314}
assert gr.LoginButton(icon="F00.svg", value="Log in please").constructor_args == {
"icon": "F00.svg",
assert gr.LoginButton(visible=False, value="Log in please").constructor_args == {
"visible": False,
"value": "Log in please",
}

View File

@ -0,0 +1,17 @@
"""
This suite of tests is designed to ensure compatibility between the current version of Gradio
with custom components created using the previous version of Gradio.
"""
from pathlib import Path
from gradio_pdf import PDF
def test_processing_utils_backwards_compatibility():
    """A custom component built against an older Gradio can still cache example files."""
    component = PDF()
    cached = component.as_example("test/test_files/sample_file.pdf")
    assert cached
    cached_path = Path(cached)
    assert cached_path.exists()
    assert cached_path.name == "sample_file.pdf"