mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-25 12:10:31 +08:00
Processes avatar_images
for gr.Chatbot
and icon
for gr.Button
correctly, so that respective files are moved to cache (#6379)
* format * add changeset * add changeset * add changeset * whoops fix * notebook * tests * refactor * refactor * format * added test * fix test --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
d31d8c6ad8
commit
de998b2812
5
.changeset/spotty-cameras-shop.md
Normal file
5
.changeset/spotty-cameras-shop.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Processes `avatar_images` for `gr.Chatbot` and `icon` for `gr.Button` correctly, so that respective files are moved to cache
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch(allowed_paths=[\"avatar.png\"])\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -51,4 +51,4 @@ with gr.Blocks() as demo:
|
||||
|
||||
demo.queue()
|
||||
if __name__ == "__main__":
|
||||
demo.launch(allowed_paths=["avatar.png"])
|
||||
demo.launch()
|
||||
|
@ -8,6 +8,7 @@ import os
|
||||
import random
|
||||
import secrets
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
@ -117,6 +118,13 @@ class Block:
|
||||
self.is_rendered: bool = False
|
||||
self._constructor_args: dict
|
||||
self.state_session_capacity = 10000
|
||||
self.temp_files: set[str] = set()
|
||||
self.GRADIO_CACHE = str(
|
||||
Path(
|
||||
os.environ.get("GRADIO_TEMP_DIR")
|
||||
or str(Path(tempfile.gettempdir()) / "gradio")
|
||||
).resolve()
|
||||
)
|
||||
|
||||
if render:
|
||||
self.render()
|
||||
@ -225,6 +233,36 @@ class Block:
|
||||
kwargs[parameter.name] = props[parameter.name]
|
||||
return kwargs
|
||||
|
||||
def move_resource_to_block_cache(
|
||||
self, url_or_file_path: str | Path | None
|
||||
) -> str | None:
|
||||
"""Moves a file or downloads a file from a url to a block's cache directory, adds
|
||||
to to the block's temp_files, and returns the path to the file in cache. This
|
||||
ensures that the file is accessible to the Block and can be served to users.
|
||||
"""
|
||||
if url_or_file_path is None:
|
||||
return None
|
||||
if isinstance(url_or_file_path, Path):
|
||||
url_or_file_path = str(url_or_file_path)
|
||||
|
||||
if client_utils.is_http_url_like(url_or_file_path):
|
||||
temp_file_path = processing_utils.save_url_to_cache(
|
||||
url_or_file_path, cache_dir=self.GRADIO_CACHE
|
||||
)
|
||||
|
||||
self.temp_files.add(temp_file_path)
|
||||
else:
|
||||
url_or_file_path = str(utils.abspath(url_or_file_path))
|
||||
if not utils.is_in_or_equal(url_or_file_path, self.GRADIO_CACHE):
|
||||
temp_file_path = processing_utils.save_file_to_cache(
|
||||
url_or_file_path, cache_dir=self.GRADIO_CACHE
|
||||
)
|
||||
else:
|
||||
temp_file_path = url_or_file_path
|
||||
self.temp_files.add(temp_file_path)
|
||||
|
||||
return temp_file_path
|
||||
|
||||
|
||||
class BlockContext(Block):
|
||||
def __init__(
|
||||
|
@ -7,9 +7,7 @@ from __future__ import annotations
|
||||
import abc
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@ -166,13 +164,6 @@ class Component(ComponentBase, Block):
|
||||
self._selectable = False
|
||||
if not hasattr(self, "data_model"):
|
||||
self.data_model: type[GradioDataModel] | None = None
|
||||
self.temp_files: set[str] = set()
|
||||
self.GRADIO_CACHE = str(
|
||||
Path(
|
||||
os.environ.get("GRADIO_TEMP_DIR")
|
||||
or str(Path(tempfile.gettempdir()) / "gradio")
|
||||
).resolve()
|
||||
)
|
||||
|
||||
Block.__init__(
|
||||
self,
|
||||
|
@ -68,9 +68,9 @@ class Button(Component):
|
||||
scale=scale,
|
||||
min_width=min_width,
|
||||
)
|
||||
self.icon = self.move_resource_to_block_cache(icon)
|
||||
self.variant = variant
|
||||
self.size = size
|
||||
self.icon = icon
|
||||
self.link = link
|
||||
|
||||
@property
|
||||
|
@ -9,7 +9,7 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
from gradio import utils
|
||||
from gradio import processing_utils, utils
|
||||
from gradio.components.base import Component
|
||||
from gradio.data_classes import FileData, GradioModel, GradioRootModel
|
||||
from gradio.events import Events
|
||||
@ -101,7 +101,6 @@ class Chatbot(Component):
|
||||
if latex_delimiters is None:
|
||||
latex_delimiters = [{"left": "$$", "right": "$$", "display": True}]
|
||||
self.latex_delimiters = latex_delimiters
|
||||
self.avatar_images = avatar_images or (None, None)
|
||||
self.show_share_button = (
|
||||
(utils.get_space() is not None)
|
||||
if show_share_button is None
|
||||
@ -113,7 +112,6 @@ class Chatbot(Component):
|
||||
self.bubble_full_width = bubble_full_width
|
||||
self.line_breaks = line_breaks
|
||||
self.layout = layout
|
||||
|
||||
super().__init__(
|
||||
label=label,
|
||||
every=every,
|
||||
@ -127,6 +125,14 @@ class Chatbot(Component):
|
||||
render=render,
|
||||
value=value,
|
||||
)
|
||||
self.avatar_images: list[str | None] = [None, None]
|
||||
if avatar_images is None:
|
||||
pass
|
||||
else:
|
||||
self.avatar_images = [
|
||||
processing_utils.move_resource_to_block_cache(avatar_images[0], self),
|
||||
processing_utils.move_resource_to_block_cache(avatar_images[1], self),
|
||||
]
|
||||
|
||||
def _preprocess_chat_messages(
|
||||
self, chat_message: str | FileMessage | None
|
||||
|
@ -11,7 +11,7 @@ from gradio_client.documentation import document, set_documentation_group
|
||||
from PIL import Image as _Image # using _ to minimize namespace pollution
|
||||
|
||||
import gradio.image_utils as image_utils
|
||||
from gradio import processing_utils, utils
|
||||
from gradio import utils
|
||||
from gradio.components.base import Component, StreamingInput
|
||||
from gradio.data_classes import FileData
|
||||
from gradio.events import Events
|
||||
@ -171,7 +171,7 @@ class Image(StreamingInput, Component):
|
||||
def as_example(self, input_data: str | Path | None) -> str | None:
|
||||
if input_data is None:
|
||||
return None
|
||||
return processing_utils.move_resource_to_block_cache(input_data, self)
|
||||
return self.move_resource_to_block_cache(input_data)
|
||||
|
||||
def example_inputs(self) -> Any:
|
||||
return "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
|
@ -339,4 +339,4 @@ class Video(Component):
|
||||
def as_example(self, input_data: str | Path | None) -> str | None:
|
||||
if input_data is None:
|
||||
return None
|
||||
return processing_utils.move_resource_to_block_cache(input_data, self)
|
||||
return self.move_resource_to_block_cache(input_data)
|
||||
|
@ -20,7 +20,7 @@ from PIL import Image, ImageOps, PngImagePlugin
|
||||
|
||||
from gradio import wasm_utils
|
||||
from gradio.data_classes import FileData, GradioModel, GradioRootModel
|
||||
from gradio.utils import abspath, is_in_or_equal
|
||||
from gradio.utils import abspath
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
|
||||
@ -236,31 +236,13 @@ def save_base64_to_cache(
|
||||
return full_temp_file_path
|
||||
|
||||
|
||||
def move_resource_to_block_cache(url_or_file_path: str | Path, block: Component) -> str:
|
||||
"""Moves a file or downloads a file from a url to a block's cache directory, adds
|
||||
to to the block's temp_files, and returns the path to the file in cache. This
|
||||
ensures that the file is accessible to the Block and can be served to users.
|
||||
def move_resource_to_block_cache(
|
||||
url_or_file_path: str | Path | None, block: Component
|
||||
) -> str | None:
|
||||
"""This method has been replaced by Block.move_resource_to_block_cache(), but is
|
||||
left here for backwards compatibility for any custom components created in Gradio 4.2.0 or earlier.
|
||||
"""
|
||||
if isinstance(url_or_file_path, Path):
|
||||
url_or_file_path = str(url_or_file_path)
|
||||
|
||||
if client_utils.is_http_url_like(url_or_file_path):
|
||||
temp_file_path = save_url_to_cache(
|
||||
url_or_file_path, cache_dir=block.GRADIO_CACHE
|
||||
)
|
||||
|
||||
block.temp_files.add(temp_file_path)
|
||||
else:
|
||||
url_or_file_path = str(abspath(url_or_file_path))
|
||||
if not is_in_or_equal(url_or_file_path, block.GRADIO_CACHE):
|
||||
temp_file_path = save_file_to_cache(
|
||||
url_or_file_path, cache_dir=block.GRADIO_CACHE
|
||||
)
|
||||
else:
|
||||
temp_file_path = url_or_file_path
|
||||
block.temp_files.add(temp_file_path)
|
||||
|
||||
return temp_file_path
|
||||
return block.move_resource_to_block_cache(url_or_file_path)
|
||||
|
||||
|
||||
def move_files_to_cache(data: Any, block: Component, postprocess: bool = False):
|
||||
@ -284,6 +266,7 @@ def move_files_to_cache(data: Any, block: Component, postprocess: bool = False):
|
||||
temp_file_path = payload.url
|
||||
else:
|
||||
temp_file_path = move_resource_to_block_cache(payload.path, block)
|
||||
assert temp_file_path is not None
|
||||
payload.path = temp_file_path
|
||||
return payload.model_dump()
|
||||
|
||||
|
@ -6,6 +6,7 @@ black
|
||||
boto3
|
||||
coverage
|
||||
fastapi>=0.101.0
|
||||
gradio_pdf==0.0.3
|
||||
httpx
|
||||
huggingface_hub
|
||||
pydantic
|
||||
|
@ -45,6 +45,8 @@ filelock==3.7.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
gradio_pdf==0.0.3
|
||||
# via -r requirements.in
|
||||
h11==0.12.0
|
||||
# via httpcore
|
||||
httpcore==0.15.0
|
||||
|
@ -790,7 +790,7 @@ class TestAudio:
|
||||
def test_default_value_postprocess(self):
|
||||
x_wav = deepcopy(media_data.BASE64_AUDIO)
|
||||
audio = gr.Audio(value=x_wav["path"])
|
||||
assert processing_utils.is_in_or_equal(audio.value["path"], audio.GRADIO_CACHE)
|
||||
assert utils.is_in_or_equal(audio.value["path"], audio.GRADIO_CACHE)
|
||||
|
||||
def test_in_interface(self):
|
||||
def reverse_audio(audio):
|
||||
@ -1874,7 +1874,7 @@ class TestChatbot:
|
||||
"likeable": False,
|
||||
"rtl": False,
|
||||
"show_copy_button": False,
|
||||
"avatar_images": (None, None),
|
||||
"avatar_images": [None, None],
|
||||
"sanitize_html": True,
|
||||
"render_markdown": True,
|
||||
"bubble_full_width": True,
|
||||
@ -1882,6 +1882,12 @@ class TestChatbot:
|
||||
"layout": None,
|
||||
}
|
||||
|
||||
def test_avatar_images_are_moved_to_cache(self):
|
||||
chatbot = gr.Chatbot(avatar_images=("test/test_files/bus.png", None))
|
||||
assert chatbot.avatar_images[0]
|
||||
assert utils.is_in_or_equal(chatbot.avatar_images[0], chatbot.GRADIO_CACHE)
|
||||
assert chatbot.avatar_images[1] is None
|
||||
|
||||
|
||||
class TestJSON:
|
||||
def test_component_functions(self):
|
||||
@ -2663,7 +2669,7 @@ def test_component_class_ids():
|
||||
|
||||
def test_constructor_args():
|
||||
assert gr.Textbox(max_lines=314).constructor_args == {"max_lines": 314}
|
||||
assert gr.LoginButton(icon="F00.svg", value="Log in please").constructor_args == {
|
||||
"icon": "F00.svg",
|
||||
assert gr.LoginButton(visible=False, value="Log in please").constructor_args == {
|
||||
"visible": False,
|
||||
"value": "Log in please",
|
||||
}
|
||||
|
17
test/test_custom_component_compatibility.py
Normal file
17
test/test_custom_component_compatibility.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""
|
||||
This suite of tests is designed to ensure compatibility between the current version of Gradio
|
||||
with custom components created using the previous version of Gradio.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
from gradio_pdf import PDF
|
||||
|
||||
|
||||
def test_processing_utils_backwards_compatibility():
|
||||
pdf_component = PDF()
|
||||
cached_pdf_file = pdf_component.as_example("test/test_files/sample_file.pdf")
|
||||
assert (
|
||||
cached_pdf_file
|
||||
and Path(cached_pdf_file).exists()
|
||||
and Path(cached_pdf_file).name == "sample_file.pdf"
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user