mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-25 12:10:31 +08:00
Makes some fixes related to TempFiles (#3523)
* temp file fixes * changes * fixing tests * formatting * fix * fix chatbot processing * tests * get tests to pass * fix code demo * changelog * fix multimodal
This commit is contained in:
parent
b9f0822510
commit
88afd684be
@ -8,7 +8,7 @@ No changes to highlight.
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
No changes to highlight.
|
||||
- Fixes `Chatbot` and `Image` components so that files passed during processing are added to a directory where they can be served from, by [@abidlabs](https://github.com/abidlabs) in [PR 3523](https://github.com/gradio-app/gradio/pull/3523)
|
||||
|
||||
## Documentation Changes:
|
||||
|
||||
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: code"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/code/file.css"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "from time import sleep\n", "\n", "\n", "css_file = os.path.join(os.path.abspath(''), \"file.css\")\n", "\n", "\n", "def set_lang(language):\n", " print(language)\n", " return gr.Code.update(language=language)\n", "\n", "\n", "def set_lang_from_path():\n", " sleep(1)\n", " return gr.Code.update(css_file, language=\"css\")\n", "\n", "\n", "def code(language, code):\n", " return gr.Code.update(code, language=language)\n", "\n", "\n", "io = gr.Interface(lambda x: x, \"code\", \"code\")\n", "\n", "with gr.Blocks() as demo:\n", " lang = gr.Dropdown(value=\"python\", choices=gr.Code.languages)\n", " with gr.Row():\n", " code_in = gr.Code(language=\"python\", label=\"Input\")\n", " code_out = gr.Code(label=\"Ouput\")\n", " btn = gr.Button(\"Run\")\n", " btn_two = gr.Button(\"Load File\")\n", "\n", " lang.change(set_lang, inputs=lang, outputs=code_in)\n", " btn.click(code, inputs=[lang, code_in], outputs=code_out)\n", " btn_two.click(set_lang_from_path, inputs=None, outputs=code_out)\n", " io.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: code"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/code/file.css"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "from time import sleep\n", "\n", "\n", "css_file = os.path.join(os.path.abspath(''), \"file.css\")\n", "\n", "\n", "def set_lang(language):\n", " print(language)\n", " return gr.Code.update(language=language)\n", "\n", "\n", "def set_lang_from_path():\n", " sleep(1)\n", " return gr.Code.update((css_file, ), language=\"css\")\n", "\n", "\n", "def code(language, code):\n", " return gr.Code.update(code, language=language)\n", "\n", "\n", "io = gr.Interface(lambda x: x, \"code\", \"code\")\n", "\n", "with gr.Blocks() as demo:\n", " lang = gr.Dropdown(value=\"python\", choices=gr.Code.languages)\n", " with gr.Row():\n", " code_in = gr.Code(language=\"python\", label=\"Input\")\n", " code_out = gr.Code(label=\"Ouput\")\n", " btn = gr.Button(\"Run\")\n", " btn_two = gr.Button(\"Load File\")\n", "\n", " lang.change(set_lang, inputs=lang, outputs=code_in)\n", " btn.click(code, inputs=[lang, code_in], outputs=code_out)\n", " btn_two.click(set_lang_from_path, inputs=None, outputs=code_out)\n", " io.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -13,7 +13,7 @@ def set_lang(language):
|
||||
|
||||
def set_lang_from_path():
|
||||
sleep(1)
|
||||
return gr.Code.update(css_file, language="css")
|
||||
return gr.Code.update((css_file, ), language="css")
|
||||
|
||||
|
||||
def code(language, code):
|
||||
|
@ -89,7 +89,7 @@ class Block:
|
||||
Context.block.add(self)
|
||||
if Context.root_block is not None:
|
||||
Context.root_block.blocks[self._id] = self
|
||||
if isinstance(self, components.TempFileManager):
|
||||
if isinstance(self, components.IOComponent):
|
||||
Context.root_block.temp_file_sets.append(self.temp_files)
|
||||
return self
|
||||
|
||||
|
@ -4,26 +4,33 @@ each component. These demos are located in the `demo` directory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import operator
|
||||
import random
|
||||
import secrets
|
||||
import shutil
|
||||
import tempfile
|
||||
import urllib.request
|
||||
import uuid
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple, Type
|
||||
|
||||
import aiofiles
|
||||
import altair as alt
|
||||
import matplotlib.figure
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import PIL.ImageOps
|
||||
import requests
|
||||
from fastapi import UploadFile
|
||||
from ffmpy import FFmpeg
|
||||
from pandas.api.types import is_numeric_dtype
|
||||
from PIL import Image as _Image # using _ to minimize namespace pollution
|
||||
@ -49,7 +56,6 @@ from gradio.events import (
|
||||
)
|
||||
from gradio.interpretation import NeighborInterpretable, TokenInterpretable
|
||||
from gradio.layouts import Column, Form, Row
|
||||
from gradio.processing_utils import TempFileManager
|
||||
from gradio.serializing import (
|
||||
FileSerializable,
|
||||
ImgSerializable,
|
||||
@ -177,6 +183,9 @@ class IOComponent(Component, Serializable):
|
||||
every: float | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.temp_files: Set[str] = set()
|
||||
self.DEFAULT_TEMP_DIR = tempfile.gettempdir()
|
||||
|
||||
Component.__init__(
|
||||
self, elem_id=elem_id, elem_classes=elem_classes, visible=visible, **kwargs
|
||||
)
|
||||
@ -198,6 +207,120 @@ class IOComponent(Component, Serializable):
|
||||
if callable(load_fn):
|
||||
self.attach_load_event(load_fn, every)
|
||||
|
||||
def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
|
||||
sha1.update(chunk)
|
||||
return sha1.hexdigest()
|
||||
|
||||
def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
remote = urllib.request.urlopen(url)
|
||||
max_file_size = 100 * 1024 * 1024 # 100MB
|
||||
total_read = 0
|
||||
while True:
|
||||
data = remote.read(chunk_num_blocks * sha1.block_size)
|
||||
total_read += chunk_num_blocks * sha1.block_size
|
||||
if not data or total_read > max_file_size:
|
||||
break
|
||||
sha1.update(data)
|
||||
return sha1.hexdigest()
|
||||
|
||||
def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
|
||||
data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
|
||||
sha1.update(data.encode("utf-8"))
|
||||
return sha1.hexdigest()
|
||||
|
||||
def make_temp_copy_if_needed(self, file_path: str) -> str:
|
||||
"""Returns a temporary file path for a copy of the given file path if it does
|
||||
not already exist. Otherwise returns the path to the existing temp file."""
|
||||
temp_dir = self.hash_file(file_path)
|
||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
f.name = utils.strip_invalid_filename_characters(Path(file_path).name)
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
|
||||
|
||||
if not Path(full_temp_file_path).exists():
|
||||
shutil.copy2(file_path, full_temp_file_path)
|
||||
|
||||
self.temp_files.add(full_temp_file_path)
|
||||
return full_temp_file_path
|
||||
|
||||
async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str:
|
||||
temp_dir = secrets.token_hex(
|
||||
20
|
||||
) # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
|
||||
temp_dir = Path(upload_dir) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
output_file_obj = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
|
||||
if file.filename:
|
||||
file_name = Path(file.filename).name
|
||||
output_file_obj.name = utils.strip_invalid_filename_characters(file_name)
|
||||
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / output_file_obj.name))
|
||||
|
||||
async with aiofiles.open(full_temp_file_path, "wb") as output_file:
|
||||
while True:
|
||||
content = await file.read(100 * 1024 * 1024)
|
||||
if not content:
|
||||
break
|
||||
await output_file.write(content)
|
||||
|
||||
return full_temp_file_path
|
||||
|
||||
def download_temp_copy_if_needed(self, url: str) -> str:
|
||||
"""Downloads a file and makes a temporary file path for a copy if does not already
|
||||
exist. Otherwise returns the path to the existing temp file."""
|
||||
temp_dir = self.hash_url(url)
|
||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
|
||||
f.name = utils.strip_invalid_filename_characters(Path(url).name)
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
|
||||
|
||||
if not Path(full_temp_file_path).exists():
|
||||
with requests.get(url, stream=True) as r:
|
||||
with open(full_temp_file_path, "wb") as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
|
||||
self.temp_files.add(full_temp_file_path)
|
||||
return full_temp_file_path
|
||||
|
||||
def base64_to_temp_file_if_needed(
|
||||
self, base64_encoding: str, file_name: str | None = None
|
||||
) -> str:
|
||||
"""Converts a base64 encoding to a file and returns the path to the file if
|
||||
the file doesn't already exist. Otherwise returns the path to the existing file."""
|
||||
temp_dir = self.hash_base64(base64_encoding)
|
||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
guess_extension = processing_utils.get_extension(base64_encoding)
|
||||
if file_name:
|
||||
file_name = utils.strip_invalid_filename_characters(file_name)
|
||||
elif guess_extension:
|
||||
file_name = "file." + guess_extension
|
||||
else:
|
||||
file_name = "file"
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
f.name = file_name
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
|
||||
|
||||
if not Path(full_temp_file_path).exists():
|
||||
data, _ = processing_utils.decode_base64_to_binary(base64_encoding)
|
||||
with open(full_temp_file_path, "wb") as fb:
|
||||
fb.write(data)
|
||||
|
||||
self.temp_files.add(full_temp_file_path)
|
||||
return full_temp_file_path
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"label": self.label,
|
||||
@ -1565,7 +1688,7 @@ class Image(
|
||||
suffix=("." + fmt.lower() if fmt is not None else ".png"),
|
||||
)
|
||||
im.save(file_obj.name)
|
||||
return file_obj.name
|
||||
return self.make_temp_copy_if_needed(file_obj.name)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type: "
|
||||
@ -1797,7 +1920,6 @@ class Video(
|
||||
Uploadable,
|
||||
IOComponent,
|
||||
FileSerializable,
|
||||
TempFileManager,
|
||||
):
|
||||
"""
|
||||
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
|
||||
@ -1854,7 +1976,6 @@ class Video(
|
||||
self.include_audio = (
|
||||
include_audio if include_audio is not None else source == "upload"
|
||||
)
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
@ -2016,7 +2137,6 @@ class Audio(
|
||||
Uploadable,
|
||||
IOComponent,
|
||||
FileSerializable,
|
||||
TempFileManager,
|
||||
TokenInterpretable,
|
||||
):
|
||||
"""
|
||||
@ -2076,7 +2196,6 @@ class Audio(
|
||||
raise ValueError(
|
||||
"Audio streaming only available if source is 'microphone'."
|
||||
)
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
@ -2326,7 +2445,6 @@ class File(
|
||||
Uploadable,
|
||||
IOComponent,
|
||||
FileSerializable,
|
||||
TempFileManager,
|
||||
):
|
||||
"""
|
||||
Creates a file component that allows uploading generic file (when used as an input) and or displaying generic files (output).
|
||||
@ -2397,7 +2515,6 @@ class File(
|
||||
Uses event data gradio.SelectData to carry `value` referring to name of selected file, and `index` to refer to index.
|
||||
See EventData documentation on how to use this event data.
|
||||
"""
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
@ -3100,9 +3217,7 @@ class Button(Clickable, IOComponent, SimpleSerializable):
|
||||
|
||||
|
||||
@document("style")
|
||||
class UploadButton(
|
||||
Clickable, Uploadable, IOComponent, FileSerializable, TempFileManager
|
||||
):
|
||||
class UploadButton(Clickable, Uploadable, IOComponent, FileSerializable):
|
||||
"""
|
||||
Used to create an upload button, when cicked allows a user to upload files that satisfy the specified file type or generic files (if file_type not set).
|
||||
Preprocessing: passes the uploaded file as a {file-object} or {List[file-object]} depending on `file_count` (or a {bytes}/{List{bytes}} depending on `type`)
|
||||
@ -3147,7 +3262,6 @@ class UploadButton(
|
||||
)
|
||||
self.file_types = file_types
|
||||
self.label = label
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
@ -3855,7 +3969,7 @@ class HTML(Changeable, IOComponent, SimpleSerializable):
|
||||
|
||||
|
||||
@document("style")
|
||||
class Gallery(IOComponent, TempFileManager, FileSerializable, Selectable):
|
||||
class Gallery(IOComponent, FileSerializable, Selectable):
|
||||
"""
|
||||
Used to display a list of images as a gallery that can be scrolled through.
|
||||
Preprocessing: this component does *not* accept input.
|
||||
@ -3892,7 +4006,6 @@ class Gallery(IOComponent, TempFileManager, FileSerializable, Selectable):
|
||||
Uses event data gradio.SelectData to carry `value` referring to caption of selected image, and `index` to refer to index.
|
||||
See EventData documentation on how to use this event data.
|
||||
"""
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
@ -4055,14 +4168,16 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
"""
|
||||
Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images.
|
||||
Preprocessing: this component does *not* accept input.
|
||||
Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed.
|
||||
Postprocessing: expects function to return a {List[List[str | None | Tuple]]}, a list of lists. The inner list should have 2 elements: the user message and the response message. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed.
|
||||
|
||||
Demos: chatbot_simple, chatbot_multimodal
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: List[Tuple[str | None, str | None]] | Callable | None = None,
|
||||
value: List[List[str | Tuple[str] | Tuple[str, str] | None]]
|
||||
| Callable
|
||||
| None = None,
|
||||
color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style()
|
||||
*,
|
||||
label: str | None = None,
|
||||
@ -4116,7 +4231,9 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
|
||||
value: List[List[str | Tuple[str] | Tuple[str, str] | None]]
|
||||
| Literal[_Keywords.NO_VALUE]
|
||||
| None = _Keywords.NO_VALUE,
|
||||
label: str | None = None,
|
||||
show_label: bool | None = None,
|
||||
visible: bool | None = None,
|
||||
@ -4130,24 +4247,57 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
}
|
||||
return updated_config
|
||||
|
||||
def _process_chat_messages(
|
||||
self, chat_message: str | Tuple | List | Dict | None
|
||||
def _preprocess_chat_messages(
|
||||
self, chat_message: str | Dict | None
|
||||
) -> str | Tuple[str] | Tuple[str, str] | None:
|
||||
if chat_message is None:
|
||||
return None
|
||||
elif isinstance(chat_message, dict):
|
||||
if chat_message["alt_text"] is not None:
|
||||
return (chat_message["name"], chat_message["alt_text"])
|
||||
else:
|
||||
return (chat_message["name"],)
|
||||
else: # string
|
||||
return chat_message
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
y: List[List[str | Dict | None] | Tuple[str | Dict | None, str | Dict | None]],
|
||||
) -> List[List[str | Tuple[str] | Tuple[str, str] | None]]:
|
||||
if y is None:
|
||||
return y
|
||||
processed_messages = []
|
||||
for message_pair in y:
|
||||
assert isinstance(
|
||||
message_pair, (tuple, list)
|
||||
), f"Expected a list of lists or list of tuples. Received: {message_pair}"
|
||||
assert (
|
||||
len(message_pair) == 2
|
||||
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
|
||||
processed_messages.append(
|
||||
[
|
||||
self._preprocess_chat_messages(message_pair[0]),
|
||||
self._preprocess_chat_messages(message_pair[1]),
|
||||
]
|
||||
)
|
||||
return processed_messages
|
||||
|
||||
def _postprocess_chat_messages(
|
||||
self, chat_message: str | Tuple | List | None
|
||||
) -> str | Dict | None:
|
||||
if chat_message is None:
|
||||
return None
|
||||
elif isinstance(chat_message, (tuple, list)):
|
||||
mime_type = processing_utils.get_mimetype(chat_message[0])
|
||||
filepath = chat_message[0]
|
||||
mime_type = processing_utils.get_mimetype(filepath)
|
||||
filepath = self.make_temp_copy_if_needed(filepath)
|
||||
return {
|
||||
"name": chat_message[0],
|
||||
"name": filepath,
|
||||
"mime_type": mime_type,
|
||||
"alt_text": chat_message[1] if len(chat_message) > 1 else None,
|
||||
"data": None, # These last two fields are filled in by the frontend
|
||||
"is_file": True,
|
||||
}
|
||||
elif isinstance(
|
||||
chat_message, dict
|
||||
): # This happens for previously processed messages
|
||||
return chat_message
|
||||
elif isinstance(chat_message, str):
|
||||
return self.md.renderInline(chat_message)
|
||||
else:
|
||||
@ -4155,15 +4305,13 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
|
||||
def postprocess(
|
||||
self,
|
||||
y: List[
|
||||
Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None]
|
||||
],
|
||||
) -> List[Tuple[str | Dict | None, str | Dict | None]]:
|
||||
y: List[List[str | Tuple[str] | Tuple[str, str] | None] | Tuple],
|
||||
) -> List[List[str | Dict | None]]:
|
||||
"""
|
||||
Parameters:
|
||||
y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
|
||||
y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
|
||||
Returns:
|
||||
List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information.
|
||||
List of lists representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. Or None if the message is not to be displayed.
|
||||
"""
|
||||
if y is None:
|
||||
return []
|
||||
@ -4176,10 +4324,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
len(message_pair) == 2
|
||||
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
|
||||
processed_messages.append(
|
||||
(
|
||||
self._process_chat_messages(message_pair[0]),
|
||||
self._process_chat_messages(message_pair[1]),
|
||||
)
|
||||
[
|
||||
self._postprocess_chat_messages(message_pair[0]),
|
||||
self._postprocess_chat_messages(message_pair[1]),
|
||||
]
|
||||
)
|
||||
return processed_messages
|
||||
|
||||
@ -4200,9 +4348,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
|
||||
|
||||
|
||||
@document("style")
|
||||
class Model3D(
|
||||
Changeable, Editable, Clearable, IOComponent, FileSerializable, TempFileManager
|
||||
):
|
||||
class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
|
||||
"""
|
||||
Component allows users to upload or view 3D Model files (.obj, .glb, or .gltf).
|
||||
Preprocessing: This component passes the uploaded file as a {str} filepath.
|
||||
@ -4237,7 +4383,6 @@ class Model3D(
|
||||
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
"""
|
||||
self.clear_color = clear_color or [0, 0, 0, 0]
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
@ -5549,7 +5694,7 @@ class Code(Changeable, IOComponent, SimpleSerializable):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: str | None = None,
|
||||
value: str | Tuple[str] | None = None,
|
||||
language: str | None = None,
|
||||
*,
|
||||
label: str | None = None,
|
||||
@ -5604,7 +5749,10 @@ class Code(Changeable, IOComponent, SimpleSerializable):
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
value: str | None | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE,
|
||||
value: str
|
||||
| Tuple[str]
|
||||
| None
|
||||
| Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE,
|
||||
label: str | None = None,
|
||||
show_label: bool | None = None,
|
||||
visible: bool | None = None,
|
||||
|
@ -1,24 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import urllib.request
|
||||
import warnings
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Set, Tuple
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import aiofiles
|
||||
import numpy as np
|
||||
import requests
|
||||
from fastapi import UploadFile
|
||||
from ffmpy import FFmpeg, FFprobe, FFRuntimeError
|
||||
from PIL import Image, ImageOps, PngImagePlugin
|
||||
|
||||
@ -318,132 +313,6 @@ def file_to_json(file_path: str | Path) -> Dict:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
class TempFileManager:
|
||||
"""
|
||||
A class that should be inherited by any Component that needs to manage temporary files.
|
||||
It should be instantiated in the __init__ method of the component.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Set stores all the temporary files created by this component.
|
||||
self.temp_files: Set[str] = set()
|
||||
self.DEFAULT_TEMP_DIR = tempfile.gettempdir()
|
||||
|
||||
def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
|
||||
sha1.update(chunk)
|
||||
return sha1.hexdigest()
|
||||
|
||||
def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
remote = urllib.request.urlopen(url)
|
||||
max_file_size = 100 * 1024 * 1024 # 100MB
|
||||
total_read = 0
|
||||
while True:
|
||||
data = remote.read(chunk_num_blocks * sha1.block_size)
|
||||
total_read += chunk_num_blocks * sha1.block_size
|
||||
if not data or total_read > max_file_size:
|
||||
break
|
||||
sha1.update(data)
|
||||
return sha1.hexdigest()
|
||||
|
||||
def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
|
||||
data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
|
||||
sha1.update(data.encode("utf-8"))
|
||||
return sha1.hexdigest()
|
||||
|
||||
def make_temp_copy_if_needed(self, file_path: str) -> str:
|
||||
"""Returns a temporary file path for a copy of the given file path if it does
|
||||
not already exist. Otherwise returns the path to the existing temp file."""
|
||||
temp_dir = self.hash_file(file_path)
|
||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
f.name = utils.strip_invalid_filename_characters(Path(file_path).name)
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
|
||||
|
||||
if not Path(full_temp_file_path).exists():
|
||||
shutil.copy2(file_path, full_temp_file_path)
|
||||
|
||||
self.temp_files.add(full_temp_file_path)
|
||||
return full_temp_file_path
|
||||
|
||||
async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str:
|
||||
temp_dir = secrets.token_hex(
|
||||
20
|
||||
) # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
|
||||
temp_dir = Path(upload_dir) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
output_file_obj = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
|
||||
if file.filename:
|
||||
file_name = Path(file.filename).name
|
||||
output_file_obj.name = utils.strip_invalid_filename_characters(file_name)
|
||||
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / output_file_obj.name))
|
||||
|
||||
async with aiofiles.open(full_temp_file_path, "wb") as output_file:
|
||||
while True:
|
||||
content = await file.read(100 * 1024 * 1024)
|
||||
if not content:
|
||||
break
|
||||
await output_file.write(content)
|
||||
|
||||
return full_temp_file_path
|
||||
|
||||
def download_temp_copy_if_needed(self, url: str) -> str:
|
||||
"""Downloads a file and makes a temporary file path for a copy if does not already
|
||||
exist. Otherwise returns the path to the existing temp file."""
|
||||
temp_dir = self.hash_url(url)
|
||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
|
||||
f.name = utils.strip_invalid_filename_characters(Path(url).name)
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
|
||||
|
||||
if not Path(full_temp_file_path).exists():
|
||||
with requests.get(url, stream=True) as r:
|
||||
with open(full_temp_file_path, "wb") as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
|
||||
self.temp_files.add(full_temp_file_path)
|
||||
return full_temp_file_path
|
||||
|
||||
def base64_to_temp_file_if_needed(
|
||||
self, base64_encoding: str, file_name: str | None = None
|
||||
) -> str:
|
||||
"""Converts a base64 encoding to a file and returns the path to the file if
|
||||
the file doesn't already exist. Otherwise returns the path to the existing file."""
|
||||
temp_dir = self.hash_base64(base64_encoding)
|
||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
guess_extension = get_extension(base64_encoding)
|
||||
if file_name:
|
||||
file_name = utils.strip_invalid_filename_characters(file_name)
|
||||
elif guess_extension:
|
||||
file_name = "file." + guess_extension
|
||||
else:
|
||||
file_name = "file"
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
f.name = file_name
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
|
||||
|
||||
if not Path(full_temp_file_path).exists():
|
||||
data, _ = decode_base64_to_binary(base64_encoding)
|
||||
with open(full_temp_file_path, "wb") as fb:
|
||||
fb.write(data)
|
||||
|
||||
self.temp_files.add(full_temp_file_path)
|
||||
return full_temp_file_path
|
||||
|
||||
|
||||
def download_tmp_copy_of_file(
|
||||
url_path: str, access_token: str | None = None, dir: str | None = None
|
||||
) -> tempfile._TemporaryFileWrapper:
|
||||
|
@ -46,7 +46,6 @@ from gradio.data_classes import PredictBody, ResetBody
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
from gradio.exceptions import Error
|
||||
from gradio.helpers import EventData
|
||||
from gradio.processing_utils import TempFileManager
|
||||
from gradio.queueing import Estimation, Event
|
||||
from gradio.utils import cancel_tasks, run_coro_in_background, set_task_name
|
||||
|
||||
@ -548,7 +547,7 @@ class App(FastAPI):
|
||||
files: List[UploadFile] = File(...),
|
||||
):
|
||||
output_files = []
|
||||
file_manager = TempFileManager()
|
||||
file_manager = gradio.File()
|
||||
for input_file in files:
|
||||
output_files.append(
|
||||
await file_manager.save_uploaded_file(
|
||||
|
@ -623,6 +623,9 @@ class TestImage:
|
||||
assert image_input.preprocess(img).size == (30, 10)
|
||||
assert image_input.postprocess("test/test_files/bus.png") == img
|
||||
assert image_input.serialize("test/test_files/bus.png") == img
|
||||
image_input = gr.Image(type="filepath")
|
||||
image_temp_filepath = image_input.preprocess(img)
|
||||
assert image_temp_filepath in image_input.temp_files
|
||||
|
||||
image_input = gr.Image(
|
||||
source="upload", tool="editor", type="pil", label="Upload Your Image"
|
||||
@ -1692,49 +1695,66 @@ class TestChatbot:
|
||||
Postprocess, get_config
|
||||
"""
|
||||
chatbot = gr.Chatbot()
|
||||
assert chatbot.postprocess([("You are **cool**", "so are *you*")]) == [
|
||||
("You are <strong>cool</strong>", "so are <em>you</em>")
|
||||
assert chatbot.postprocess([["You are **cool**", "so are *you*"]]) == [
|
||||
["You are <strong>cool</strong>", "so are <em>you</em>"]
|
||||
]
|
||||
|
||||
multimodal_msg = [
|
||||
(("driving.mp4",), "cool video"),
|
||||
(("cantina.wav",), "cool audio"),
|
||||
(("lion.jpg", "A lion"), "cool pic"),
|
||||
[("test/test_files/video_sample.mp4",), "cool video"],
|
||||
[("test/test_files/audio_sample.wav",), "cool audio"],
|
||||
[("test/test_files/bus.png", "A bus"), "cool pic"],
|
||||
]
|
||||
processed_multimodal_msg = [
|
||||
(
|
||||
[
|
||||
{
|
||||
"name": "driving.mp4",
|
||||
"name": "video_sample.mp4",
|
||||
"mime_type": "video/mp4",
|
||||
"alt_text": None,
|
||||
"data": None,
|
||||
"is_file": True,
|
||||
},
|
||||
"cool video",
|
||||
),
|
||||
(
|
||||
],
|
||||
[
|
||||
{
|
||||
"name": "cantina.wav",
|
||||
"name": "audio_sample.wav",
|
||||
"mime_type": "audio/wav",
|
||||
"alt_text": None,
|
||||
"data": None,
|
||||
"is_file": True,
|
||||
},
|
||||
"cool audio",
|
||||
),
|
||||
(
|
||||
],
|
||||
[
|
||||
{
|
||||
"name": "lion.jpg",
|
||||
"mime_type": "image/jpeg",
|
||||
"alt_text": "A lion",
|
||||
"name": "bus.png",
|
||||
"mime_type": "image/png",
|
||||
"alt_text": "A bus",
|
||||
"data": None,
|
||||
"is_file": True,
|
||||
},
|
||||
"cool pic",
|
||||
),
|
||||
],
|
||||
]
|
||||
postprocessed_multimodal_msg = chatbot.postprocess(multimodal_msg)
|
||||
postprocessed_multimodal_msg_base_names = []
|
||||
for x, y in postprocessed_multimodal_msg:
|
||||
if isinstance(x, dict):
|
||||
x["name"] = os.path.basename(x["name"])
|
||||
postprocessed_multimodal_msg_base_names.append([x, y])
|
||||
assert postprocessed_multimodal_msg_base_names == processed_multimodal_msg
|
||||
|
||||
preprocessed_multimodal_msg = chatbot.preprocess(processed_multimodal_msg)
|
||||
multimodal_msg_base_names = []
|
||||
for x, y in multimodal_msg:
|
||||
if isinstance(x, tuple):
|
||||
if len(x) > 1:
|
||||
new_x = (os.path.basename(x[0]), x[1])
|
||||
else:
|
||||
new_x = (os.path.basename(x[0]),)
|
||||
multimodal_msg_base_names.append([new_x, y])
|
||||
assert multimodal_msg_base_names == preprocessed_multimodal_msg
|
||||
|
||||
assert chatbot.postprocess(multimodal_msg) == processed_multimodal_msg
|
||||
assert chatbot.get_config() == {
|
||||
"value": [],
|
||||
"label": None,
|
||||
@ -2606,3 +2626,84 @@ class TestCode:
|
||||
"interactive": None,
|
||||
"root_url": None,
|
||||
}
|
||||
|
||||
|
||||
class TestTempFileManagement:
|
||||
def test_hash_file(self):
|
||||
temp_file_manager = gr.File()
|
||||
h1 = temp_file_manager.hash_file("gradio/test_data/cheetah1.jpg")
|
||||
h2 = temp_file_manager.hash_file("gradio/test_data/cheetah1-copy.jpg")
|
||||
h3 = temp_file_manager.hash_file("gradio/test_data/cheetah2.jpg")
|
||||
assert h1 == h2
|
||||
assert h1 != h3
|
||||
|
||||
@patch("shutil.copy2")
|
||||
def test_make_temp_copy_if_needed(self, mock_copy):
|
||||
temp_file_manager = gr.File()
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
|
||||
try: # Delete if already exists from before this test
|
||||
os.remove(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
|
||||
assert mock_copy.called
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
assert Path(f).name == "cheetah1.jpg"
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed(
|
||||
"gradio/test_data/cheetah1-copy.jpg"
|
||||
)
|
||||
assert len(temp_file_manager.temp_files) == 2
|
||||
assert Path(f).name == "cheetah1-copy.jpg"
|
||||
|
||||
def test_base64_to_temp_file_if_needed(self):
|
||||
temp_file_manager = gr.File()
|
||||
|
||||
base64_file_1 = media_data.BASE64_IMAGE
|
||||
base64_file_2 = media_data.BASE64_AUDIO["data"]
|
||||
|
||||
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
|
||||
try: # Delete if already exists from before this test
|
||||
os.remove(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_2)
|
||||
assert len(temp_file_manager.temp_files) == 2
|
||||
|
||||
for file in temp_file_manager.temp_files:
|
||||
os.remove(file)
|
||||
|
||||
@pytest.mark.flaky
|
||||
@patch("shutil.copyfileobj")
|
||||
def test_download_temp_copy_if_needed(self, mock_copy):
|
||||
temp_file_manager = gr.File()
|
||||
url1 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
|
||||
url2 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg"
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url1)
|
||||
try: # Delete if already exists from before this test
|
||||
os.remove(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url1)
|
||||
assert mock_copy.called
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url1)
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url2)
|
||||
assert len(temp_file_manager.temp_files) == 2
|
||||
|
@ -134,87 +134,6 @@ class TestAudioPreprocessing:
|
||||
assert audio_.dtype == "int16"
|
||||
|
||||
|
||||
class TestTempFileManager:
|
||||
def test_hash_file(self):
|
||||
temp_file_manager = processing_utils.TempFileManager()
|
||||
h1 = temp_file_manager.hash_file("gradio/test_data/cheetah1.jpg")
|
||||
h2 = temp_file_manager.hash_file("gradio/test_data/cheetah1-copy.jpg")
|
||||
h3 = temp_file_manager.hash_file("gradio/test_data/cheetah2.jpg")
|
||||
assert h1 == h2
|
||||
assert h1 != h3
|
||||
|
||||
@patch("shutil.copy2")
|
||||
def test_make_temp_copy_if_needed(self, mock_copy):
|
||||
temp_file_manager = processing_utils.TempFileManager()
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
|
||||
try: # Delete if already exists from before this test
|
||||
os.remove(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
|
||||
assert mock_copy.called
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
assert Path(f).name == "cheetah1.jpg"
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed(
|
||||
"gradio/test_data/cheetah1-copy.jpg"
|
||||
)
|
||||
assert len(temp_file_manager.temp_files) == 2
|
||||
assert Path(f).name == "cheetah1-copy.jpg"
|
||||
|
||||
def test_base64_to_temp_file_if_needed(self):
|
||||
temp_file_manager = processing_utils.TempFileManager()
|
||||
|
||||
base64_file_1 = media_data.BASE64_IMAGE
|
||||
base64_file_2 = media_data.BASE64_AUDIO["data"]
|
||||
|
||||
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
|
||||
try: # Delete if already exists from before this test
|
||||
os.remove(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_2)
|
||||
assert len(temp_file_manager.temp_files) == 2
|
||||
|
||||
for file in temp_file_manager.temp_files:
|
||||
os.remove(file)
|
||||
|
||||
@pytest.mark.flaky
|
||||
@patch("shutil.copyfileobj")
|
||||
def test_download_temp_copy_if_needed(self, mock_copy):
|
||||
temp_file_manager = processing_utils.TempFileManager()
|
||||
url1 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
|
||||
url2 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg"
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url1)
|
||||
try: # Delete if already exists from before this test
|
||||
os.remove(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url1)
|
||||
assert mock_copy.called
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url1)
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url2)
|
||||
assert len(temp_file_manager.temp_files) == 2
|
||||
|
||||
|
||||
class TestOutputPreprocessing:
|
||||
def test_decode_base64_to_binary(self):
|
||||
binary = processing_utils.decode_base64_to_binary(
|
||||
|
Loading…
x
Reference in New Issue
Block a user