Makes some fixes related to TempFiles (#3523)

* temp file fixes

* changes

* fixing tests

* formatting

* fix

* fix chatbot processing

* tests

* get tests to pass

* fix code demo

* changelog

* fix multimodal
This commit is contained in:
Abubakar Abid 2023-03-21 09:37:24 -07:00 committed by GitHub
parent b9f0822510
commit 88afd684be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 314 additions and 278 deletions

View File

@ -8,7 +8,7 @@ No changes to highlight.
## Bug Fixes:
No changes to highlight.
- Fixes `Chatbot` and `Image` components so that files passed during processing are added to a directory where they can be served from, by [@abidlabs](https://github.com/abidlabs) in [PR 3523](https://github.com/gradio-app/gradio/pull/3523)
## Documentation Changes:

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: code"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/code/file.css"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "from time import sleep\n", "\n", "\n", "css_file = os.path.join(os.path.abspath(''), \"file.css\")\n", "\n", "\n", "def set_lang(language):\n", " print(language)\n", " return gr.Code.update(language=language)\n", "\n", "\n", "def set_lang_from_path():\n", " sleep(1)\n", " return gr.Code.update(css_file, language=\"css\")\n", "\n", "\n", "def code(language, code):\n", " return gr.Code.update(code, language=language)\n", "\n", "\n", "io = gr.Interface(lambda x: x, \"code\", \"code\")\n", "\n", "with gr.Blocks() as demo:\n", " lang = gr.Dropdown(value=\"python\", choices=gr.Code.languages)\n", " with gr.Row():\n", " code_in = gr.Code(language=\"python\", label=\"Input\")\n", " code_out = gr.Code(label=\"Ouput\")\n", " btn = gr.Button(\"Run\")\n", " btn_two = gr.Button(\"Load File\")\n", "\n", " lang.change(set_lang, inputs=lang, outputs=code_in)\n", " btn.click(code, inputs=[lang, code_in], outputs=code_out)\n", " btn_two.click(set_lang_from_path, inputs=None, outputs=code_out)\n", " io.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: code"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/code/file.css"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "from time import sleep\n", "\n", "\n", "css_file = os.path.join(os.path.abspath(''), \"file.css\")\n", "\n", "\n", "def set_lang(language):\n", " print(language)\n", " return gr.Code.update(language=language)\n", "\n", "\n", "def set_lang_from_path():\n", " sleep(1)\n", " return gr.Code.update((css_file, ), language=\"css\")\n", "\n", "\n", "def code(language, code):\n", " return gr.Code.update(code, language=language)\n", "\n", "\n", "io = gr.Interface(lambda x: x, \"code\", \"code\")\n", "\n", "with gr.Blocks() as demo:\n", " lang = gr.Dropdown(value=\"python\", choices=gr.Code.languages)\n", " with gr.Row():\n", " code_in = gr.Code(language=\"python\", label=\"Input\")\n", " code_out = gr.Code(label=\"Ouput\")\n", " btn = gr.Button(\"Run\")\n", " btn_two = gr.Button(\"Load File\")\n", "\n", " lang.change(set_lang, inputs=lang, outputs=code_in)\n", " btn.click(code, inputs=[lang, code_in], outputs=code_out)\n", " btn_two.click(set_lang_from_path, inputs=None, outputs=code_out)\n", " io.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -13,7 +13,7 @@ def set_lang(language):
def set_lang_from_path():
sleep(1)
return gr.Code.update(css_file, language="css")
return gr.Code.update((css_file, ), language="css")
def code(language, code):

View File

@ -89,7 +89,7 @@ class Block:
Context.block.add(self)
if Context.root_block is not None:
Context.root_block.blocks[self._id] = self
if isinstance(self, components.TempFileManager):
if isinstance(self, components.IOComponent):
Context.root_block.temp_file_sets.append(self.temp_files)
return self

View File

@ -4,26 +4,33 @@ each component. These demos are located in the `demo` directory."""
from __future__ import annotations
import hashlib
import inspect
import json
import math
import operator
import random
import secrets
import shutil
import tempfile
import urllib.request
import uuid
import warnings
from copy import deepcopy
from enum import Enum
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple, Type
import aiofiles
import altair as alt
import matplotlib.figure
import numpy as np
import pandas as pd
import PIL
import PIL.ImageOps
import requests
from fastapi import UploadFile
from ffmpy import FFmpeg
from pandas.api.types import is_numeric_dtype
from PIL import Image as _Image # using _ to minimize namespace pollution
@ -49,7 +56,6 @@ from gradio.events import (
)
from gradio.interpretation import NeighborInterpretable, TokenInterpretable
from gradio.layouts import Column, Form, Row
from gradio.processing_utils import TempFileManager
from gradio.serializing import (
FileSerializable,
ImgSerializable,
@ -177,6 +183,9 @@ class IOComponent(Component, Serializable):
every: float | None = None,
**kwargs,
):
self.temp_files: Set[str] = set()
self.DEFAULT_TEMP_DIR = tempfile.gettempdir()
Component.__init__(
self, elem_id=elem_id, elem_classes=elem_classes, visible=visible, **kwargs
)
@ -198,6 +207,120 @@ class IOComponent(Component, Serializable):
if callable(load_fn):
self.attach_load_event(load_fn, every)
def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
    """
    Computes the SHA-1 hex digest of a file's contents, reading it chunk by chunk.
    Parameters:
        file_path: path of the file to hash.
        chunk_num_blocks: number of SHA-1 blocks to read per chunk.
    Returns:
        The hexadecimal SHA-1 digest of the file contents.
    """
    digest = hashlib.sha1()
    chunk_size = chunk_num_blocks * digest.block_size
    with open(file_path, "rb") as stream:
        while True:
            piece = stream.read(chunk_size)
            if piece == b"":  # end of file reached
                break
            digest.update(piece)
    return digest.hexdigest()
def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
    """
    Computes the SHA-1 hex digest of the content served at a URL, reading it chunk by chunk.
    Hashing stops once the running byte budget exceeds 100MB, so very large resources are
    identified by their leading bytes only.
    Parameters:
        url: URL of the resource to hash.
        chunk_num_blocks: number of SHA-1 blocks to request per read.
    Returns:
        The hexadecimal SHA-1 digest of the bytes that were read.
    """
    sha1 = hashlib.sha1()
    chunk_size = chunk_num_blocks * sha1.block_size
    max_file_size = 100 * 1024 * 1024  # 100MB
    total_read = 0
    # The context manager closes the response (and its connection) even when
    # reading raises or the loop stops early at the size cap.
    with urllib.request.urlopen(url) as remote:
        while True:
            data = remote.read(chunk_size)
            # The budget is charged with the requested chunk size (not len(data))
            # so digests stay identical to those produced before this fix.
            total_read += chunk_size
            if not data or total_read > max_file_size:
                break
            sha1.update(data)
    return sha1.hexdigest()
def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
    """
    Computes the SHA-1 hex digest of a base64 string, feeding it to the hasher in slices.
    Parameters:
        base64_encoding: the base64 text to hash (it is hashed as UTF-8 text, not decoded).
        chunk_num_blocks: number of SHA-1 blocks worth of characters per slice.
    Returns:
        The hexadecimal SHA-1 digest of the UTF-8 encoded string.
    """
    digest = hashlib.sha1()
    step = chunk_num_blocks * digest.block_size
    slices = (
        base64_encoding[start : start + step]
        for start in range(0, len(base64_encoding), step)
    )
    for piece in slices:
        digest.update(piece.encode("utf-8"))
    return digest.hexdigest()
def make_temp_copy_if_needed(self, file_path: str) -> str:
    """
    Returns a temporary file path for a copy of the given file path if it does
    not already exist. Otherwise returns the path to the existing temp file.
    The copy lives in a directory named after the SHA-1 hash of the file contents,
    and keeps the (sanitized) original file name.
    Parameters:
        file_path: path of the file to copy.
    Returns:
        Absolute path (as a string) of the temp copy; it is also recorded in self.temp_files.
    """
    temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_file(file_path)
    temp_dir.mkdir(exist_ok=True, parents=True)

    # Only a file *name* is needed here. Creating a NamedTemporaryFile just to hold
    # the name would leave an extra empty file and an open handle behind on each call.
    file_name = utils.strip_invalid_filename_characters(Path(file_path).name)
    full_temp_file_path = str(utils.abspath(temp_dir / file_name))

    if not Path(full_temp_file_path).exists():
        shutil.copy2(file_path, full_temp_file_path)

    self.temp_files.add(full_temp_file_path)
    return full_temp_file_path
async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str:
    """
    Streams an uploaded file to disk inside a freshly created, randomly named
    sub-directory of `upload_dir`.
    Parameters:
        file: the uploaded file to persist.
        upload_dir: directory under which the random sub-directory is created.
    Returns:
        Absolute path (as a string) of the saved file.
    """
    # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
    temp_dir = Path(upload_dir) / secrets.token_hex(20)
    temp_dir.mkdir(exist_ok=True, parents=True)

    if file.filename:
        # Keep the client's file name (sanitized). No placeholder temp file is created
        # in this case, so no stray empty file or open handle is left behind.
        file_name = utils.strip_invalid_filename_characters(Path(file.filename).name)
        full_temp_file_path = str(utils.abspath(temp_dir / file_name))
    else:
        # No client file name: reserve a unique name in temp_dir and close the
        # placeholder handle right away; the file is overwritten below.
        with tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) as placeholder:
            reserved_path = placeholder.name
        full_temp_file_path = str(utils.abspath(reserved_path))

    async with aiofiles.open(full_temp_file_path, "wb") as output_file:
        while True:
            content = await file.read(100 * 1024 * 1024)  # read at most 100MB at a time
            if not content:
                break
            await output_file.write(content)

    return full_temp_file_path
def download_temp_copy_if_needed(self, url: str) -> str:
    """
    Downloads a file and makes a temporary file path for a copy if does not already
    exist. Otherwise returns the path to the existing temp file.
    The copy lives in a directory named after the SHA-1 hash of the remote content,
    and is named after the (sanitized) last component of the URL.
    Parameters:
        url: URL of the file to download.
    Returns:
        Absolute path (as a string) of the temp copy; it is also recorded in self.temp_files.
    """
    temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_url(url)
    temp_dir.mkdir(exist_ok=True, parents=True)

    # Only a file *name* is needed here. Creating a NamedTemporaryFile just to hold
    # the name would leave an extra empty file and an open handle behind on each call.
    file_name = utils.strip_invalid_filename_characters(Path(url).name)
    full_temp_file_path = str(utils.abspath(temp_dir / file_name))

    if not Path(full_temp_file_path).exists():
        with requests.get(url, stream=True) as response, open(
            full_temp_file_path, "wb"
        ) as destination:
            shutil.copyfileobj(response.raw, destination)

    self.temp_files.add(full_temp_file_path)
    return full_temp_file_path
def base64_to_temp_file_if_needed(
    self, base64_encoding: str, file_name: str | None = None
) -> str:
    """
    Converts a base64 encoding to a file and returns the path to the file if
    the file doesn't already exist. Otherwise returns the path to the existing file.
    The file lives in a directory named after the SHA-1 hash of the base64 string.
    Parameters:
        base64_encoding: the base64 data to write out.
        file_name: optional name for the file; if omitted, "file" plus a guessed extension (if any) is used.
    Returns:
        Absolute path (as a string) of the temp file; it is also recorded in self.temp_files.
    """
    temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_base64(base64_encoding)
    temp_dir.mkdir(exist_ok=True, parents=True)

    guess_extension = processing_utils.get_extension(base64_encoding)
    if file_name:
        file_name = utils.strip_invalid_filename_characters(file_name)
    elif guess_extension:
        file_name = "file." + guess_extension
    else:
        file_name = "file"

    # Only a file *name* is needed here. Creating a NamedTemporaryFile just to hold
    # the name would leave an extra empty file and an open handle behind on each call.
    full_temp_file_path = str(utils.abspath(temp_dir / file_name))

    if not Path(full_temp_file_path).exists():
        data, _ = processing_utils.decode_base64_to_binary(base64_encoding)
        with open(full_temp_file_path, "wb") as fb:
            fb.write(data)

    self.temp_files.add(full_temp_file_path)
    return full_temp_file_path
def get_config(self):
config = {
"label": self.label,
@ -1565,7 +1688,7 @@ class Image(
suffix=("." + fmt.lower() if fmt is not None else ".png"),
)
im.save(file_obj.name)
return file_obj.name
return self.make_temp_copy_if_needed(file_obj.name)
else:
raise ValueError(
"Unknown type: "
@ -1797,7 +1920,6 @@ class Video(
Uploadable,
IOComponent,
FileSerializable,
TempFileManager,
):
"""
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
@ -1854,7 +1976,6 @@ class Video(
self.include_audio = (
include_audio if include_audio is not None else source == "upload"
)
TempFileManager.__init__(self)
IOComponent.__init__(
self,
label=label,
@ -2016,7 +2137,6 @@ class Audio(
Uploadable,
IOComponent,
FileSerializable,
TempFileManager,
TokenInterpretable,
):
"""
@ -2076,7 +2196,6 @@ class Audio(
raise ValueError(
"Audio streaming only available if source is 'microphone'."
)
TempFileManager.__init__(self)
IOComponent.__init__(
self,
label=label,
@ -2326,7 +2445,6 @@ class File(
Uploadable,
IOComponent,
FileSerializable,
TempFileManager,
):
"""
Creates a file component that allows uploading generic file (when used as an input) and or displaying generic files (output).
@ -2397,7 +2515,6 @@ class File(
Uses event data gradio.SelectData to carry `value` referring to name of selected file, and `index` to refer to index.
See EventData documentation on how to use this event data.
"""
TempFileManager.__init__(self)
IOComponent.__init__(
self,
label=label,
@ -3100,9 +3217,7 @@ class Button(Clickable, IOComponent, SimpleSerializable):
@document("style")
class UploadButton(
Clickable, Uploadable, IOComponent, FileSerializable, TempFileManager
):
class UploadButton(Clickable, Uploadable, IOComponent, FileSerializable):
"""
Used to create an upload button, when cicked allows a user to upload files that satisfy the specified file type or generic files (if file_type not set).
Preprocessing: passes the uploaded file as a {file-object} or {List[file-object]} depending on `file_count` (or a {bytes}/{List{bytes}} depending on `type`)
@ -3147,7 +3262,6 @@ class UploadButton(
)
self.file_types = file_types
self.label = label
TempFileManager.__init__(self)
IOComponent.__init__(
self,
label=label,
@ -3855,7 +3969,7 @@ class HTML(Changeable, IOComponent, SimpleSerializable):
@document("style")
class Gallery(IOComponent, TempFileManager, FileSerializable, Selectable):
class Gallery(IOComponent, FileSerializable, Selectable):
"""
Used to display a list of images as a gallery that can be scrolled through.
Preprocessing: this component does *not* accept input.
@ -3892,7 +4006,6 @@ class Gallery(IOComponent, TempFileManager, FileSerializable, Selectable):
Uses event data gradio.SelectData to carry `value` referring to caption of selected image, and `index` to refer to index.
See EventData documentation on how to use this event data.
"""
TempFileManager.__init__(self)
IOComponent.__init__(
self,
label=label,
@ -4055,14 +4168,16 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
"""
Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images.
Preprocessing: this component does *not* accept input.
Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed.
Postprocessing: expects function to return a {List[List[str | None | Tuple]]}, a list of lists. The inner list should have 2 elements: the user message and the response message. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed.
Demos: chatbot_simple, chatbot_multimodal
"""
def __init__(
self,
value: List[Tuple[str | None, str | None]] | Callable | None = None,
value: List[List[str | Tuple[str] | Tuple[str, str] | None]]
| Callable
| None = None,
color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style()
*,
label: str | None = None,
@ -4116,7 +4231,9 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
@staticmethod
def update(
value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
value: List[List[str | Tuple[str] | Tuple[str, str] | None]]
| Literal[_Keywords.NO_VALUE]
| None = _Keywords.NO_VALUE,
label: str | None = None,
show_label: bool | None = None,
visible: bool | None = None,
@ -4130,24 +4247,57 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
}
return updated_config
def _process_chat_messages(
self, chat_message: str | Tuple | List | Dict | None
def _preprocess_chat_messages(
    self, chat_message: str | Dict | None
) -> str | Tuple[str] | Tuple[str, str] | None:
    """
    Converts one chat message from its frontend form back into the form functions receive.
    Parameters:
        chat_message: None, a string, or a media dictionary with "name" and "alt_text" keys.
    Returns:
        None or the string unchanged; for a dictionary, a (name,) tuple, or a (name, alt_text) tuple when alt text is set.
    """
    if not isinstance(chat_message, dict):
        # None and plain strings are passed through untouched
        return chat_message
    alt_text = chat_message["alt_text"]
    if alt_text is None:
        return (chat_message["name"],)
    return (chat_message["name"], alt_text)
def preprocess(
    self,
    y: List[List[str | Dict | None] | Tuple[str | Dict | None, str | Dict | None]],
) -> List[List[str | Tuple[str] | Tuple[str, str] | None]]:
    """
    Parameters:
        y: List of [message, response] pairs (lists or tuples of length 2) as sent by the frontend. Each entry is a string, a media dictionary, or None.
    Returns:
        List of two-element lists with every entry converted by `_preprocess_chat_messages`. If `y` is None, it is returned as is.
    """
    if y is None:
        return y

    def _convert_pair(message_pair):
        # Validate the shape first, then convert both sides of the exchange.
        assert isinstance(
            message_pair, (tuple, list)
        ), f"Expected a list of lists or list of tuples. Received: {message_pair}"
        assert (
            len(message_pair) == 2
        ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
        user_message, bot_message = message_pair[0], message_pair[1]
        return [
            self._preprocess_chat_messages(user_message),
            self._preprocess_chat_messages(bot_message),
        ]

    return [_convert_pair(message_pair) for message_pair in y]
def _postprocess_chat_messages(
self, chat_message: str | Tuple | List | None
) -> str | Dict | None:
if chat_message is None:
return None
elif isinstance(chat_message, (tuple, list)):
mime_type = processing_utils.get_mimetype(chat_message[0])
filepath = chat_message[0]
mime_type = processing_utils.get_mimetype(filepath)
filepath = self.make_temp_copy_if_needed(filepath)
return {
"name": chat_message[0],
"name": filepath,
"mime_type": mime_type,
"alt_text": chat_message[1] if len(chat_message) > 1 else None,
"data": None, # These last two fields are filled in by the frontend
"is_file": True,
}
elif isinstance(
chat_message, dict
): # This happens for previously processed messages
return chat_message
elif isinstance(chat_message, str):
return self.md.renderInline(chat_message)
else:
@ -4155,15 +4305,13 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
def postprocess(
self,
y: List[
Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None]
],
) -> List[Tuple[str | Dict | None, str | Dict | None]]:
y: List[List[str | Tuple[str] | Tuple[str, str] | None] | Tuple],
) -> List[List[str | Dict | None]]:
"""
Parameters:
y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
Returns:
List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information.
List of lists representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. Or None if the message is not to be displayed.
"""
if y is None:
return []
@ -4176,10 +4324,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
len(message_pair) == 2
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
processed_messages.append(
(
self._process_chat_messages(message_pair[0]),
self._process_chat_messages(message_pair[1]),
)
[
self._postprocess_chat_messages(message_pair[0]),
self._postprocess_chat_messages(message_pair[1]),
]
)
return processed_messages
@ -4200,9 +4348,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
@document("style")
class Model3D(
Changeable, Editable, Clearable, IOComponent, FileSerializable, TempFileManager
):
class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
"""
Component allows users to upload or view 3D Model files (.obj, .glb, or .gltf).
Preprocessing: This component passes the uploaded file as a {str} filepath.
@ -4237,7 +4383,6 @@ class Model3D(
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
"""
self.clear_color = clear_color or [0, 0, 0, 0]
TempFileManager.__init__(self)
IOComponent.__init__(
self,
label=label,
@ -5549,7 +5694,7 @@ class Code(Changeable, IOComponent, SimpleSerializable):
def __init__(
self,
value: str | None = None,
value: str | Tuple[str] | None = None,
language: str | None = None,
*,
label: str | None = None,
@ -5604,7 +5749,10 @@ class Code(Changeable, IOComponent, SimpleSerializable):
@staticmethod
def update(
value: str | None | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE,
value: str
| Tuple[str]
| None
| Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE,
label: str | None = None,
show_label: bool | None = None,
visible: bool | None = None,

View File

@ -1,24 +1,19 @@
from __future__ import annotations
import base64
import hashlib
import json
import mimetypes
import os
import secrets
import shutil
import subprocess
import tempfile
import urllib.request
import warnings
from io import BytesIO
from pathlib import Path
from typing import Dict, Set, Tuple
from typing import Dict, Tuple
import aiofiles
import numpy as np
import requests
from fastapi import UploadFile
from ffmpy import FFmpeg, FFprobe, FFRuntimeError
from PIL import Image, ImageOps, PngImagePlugin
@ -318,132 +313,6 @@ def file_to_json(file_path: str | Path) -> Dict:
return json.load(f)
class TempFileManager:
    """
    A class that should be inherited by any Component that needs to manage temporary files.
    It should be instantiated in the __init__ method of the component.
    Files are stored under DEFAULT_TEMP_DIR in a sub-directory named after a hash of their
    content (or a random token for uploads), so identical content maps to the same path.
    """

    def __init__(self) -> None:
        # Set stores all the temporary files created by this component.
        self.temp_files: Set[str] = set()
        self.DEFAULT_TEMP_DIR = tempfile.gettempdir()

    def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
        """
        Computes the SHA-1 hex digest of a file's contents, reading it chunk by chunk.
        Parameters:
            file_path: path of the file to hash.
            chunk_num_blocks: number of SHA-1 blocks to read per chunk.
        Returns:
            The hexadecimal SHA-1 digest of the file contents.
        """
        sha1 = hashlib.sha1()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
                sha1.update(chunk)
        return sha1.hexdigest()

    def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
        """
        Computes the SHA-1 hex digest of the content served at a URL, reading it chunk by chunk.
        Hashing stops once the running byte budget exceeds 100MB.
        Parameters:
            url: URL of the resource to hash.
            chunk_num_blocks: number of SHA-1 blocks to request per read.
        Returns:
            The hexadecimal SHA-1 digest of the bytes that were read.
        """
        sha1 = hashlib.sha1()
        chunk_size = chunk_num_blocks * sha1.block_size
        max_file_size = 100 * 1024 * 1024  # 100MB
        total_read = 0
        # The context manager closes the response (and its connection) even when
        # reading raises or the loop stops early at the size cap.
        with urllib.request.urlopen(url) as remote:
            while True:
                data = remote.read(chunk_size)
                # The budget is charged with the requested chunk size (not len(data))
                # so digests stay identical to those produced before this fix.
                total_read += chunk_size
                if not data or total_read > max_file_size:
                    break
                sha1.update(data)
        return sha1.hexdigest()

    def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
        """
        Computes the SHA-1 hex digest of a base64 string, feeding it to the hasher in slices.
        Parameters:
            base64_encoding: the base64 text to hash (it is hashed as UTF-8 text, not decoded).
            chunk_num_blocks: number of SHA-1 blocks worth of characters per slice.
        Returns:
            The hexadecimal SHA-1 digest of the UTF-8 encoded string.
        """
        sha1 = hashlib.sha1()
        for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
            data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
            sha1.update(data.encode("utf-8"))
        return sha1.hexdigest()

    def make_temp_copy_if_needed(self, file_path: str) -> str:
        """Returns a temporary file path for a copy of the given file path if it does
        not already exist. Otherwise returns the path to the existing temp file."""
        temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_file(file_path)
        temp_dir.mkdir(exist_ok=True, parents=True)

        # Only a file *name* is needed here. Creating a NamedTemporaryFile just to hold
        # the name would leave an extra empty file and an open handle behind on each call.
        file_name = utils.strip_invalid_filename_characters(Path(file_path).name)
        full_temp_file_path = str(utils.abspath(temp_dir / file_name))

        if not Path(full_temp_file_path).exists():
            shutil.copy2(file_path, full_temp_file_path)

        self.temp_files.add(full_temp_file_path)
        return full_temp_file_path

    async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str:
        """
        Streams an uploaded file to disk inside a freshly created, randomly named
        sub-directory of `upload_dir`.
        Parameters:
            file: the uploaded file to persist.
            upload_dir: directory under which the random sub-directory is created.
        Returns:
            Absolute path (as a string) of the saved file.
        """
        # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
        temp_dir = Path(upload_dir) / secrets.token_hex(20)
        temp_dir.mkdir(exist_ok=True, parents=True)

        if file.filename:
            # Keep the client's file name (sanitized). No placeholder temp file is
            # created in this case, so no stray empty file is left behind.
            file_name = utils.strip_invalid_filename_characters(
                Path(file.filename).name
            )
            full_temp_file_path = str(utils.abspath(temp_dir / file_name))
        else:
            # No client file name: reserve a unique name in temp_dir and close the
            # placeholder handle right away; the file is overwritten below.
            with tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) as placeholder:
                reserved_path = placeholder.name
            full_temp_file_path = str(utils.abspath(reserved_path))

        async with aiofiles.open(full_temp_file_path, "wb") as output_file:
            while True:
                content = await file.read(100 * 1024 * 1024)  # at most 100MB per read
                if not content:
                    break
                await output_file.write(content)

        return full_temp_file_path

    def download_temp_copy_if_needed(self, url: str) -> str:
        """Downloads a file and makes a temporary file path for a copy if does not already
        exist. Otherwise returns the path to the existing temp file."""
        temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_url(url)
        temp_dir.mkdir(exist_ok=True, parents=True)

        # Only a file *name* is needed; see make_temp_copy_if_needed.
        file_name = utils.strip_invalid_filename_characters(Path(url).name)
        full_temp_file_path = str(utils.abspath(temp_dir / file_name))

        if not Path(full_temp_file_path).exists():
            with requests.get(url, stream=True) as r, open(
                full_temp_file_path, "wb"
            ) as destination:
                shutil.copyfileobj(r.raw, destination)

        self.temp_files.add(full_temp_file_path)
        return full_temp_file_path

    def base64_to_temp_file_if_needed(
        self, base64_encoding: str, file_name: str | None = None
    ) -> str:
        """Converts a base64 encoding to a file and returns the path to the file if
        the file doesn't already exist. Otherwise returns the path to the existing file."""
        temp_dir = Path(self.DEFAULT_TEMP_DIR) / self.hash_base64(base64_encoding)
        temp_dir.mkdir(exist_ok=True, parents=True)

        guess_extension = get_extension(base64_encoding)
        if file_name:
            file_name = utils.strip_invalid_filename_characters(file_name)
        elif guess_extension:
            file_name = "file." + guess_extension
        else:
            file_name = "file"

        # Only a file *name* is needed; see make_temp_copy_if_needed.
        full_temp_file_path = str(utils.abspath(temp_dir / file_name))

        if not Path(full_temp_file_path).exists():
            data, _ = decode_base64_to_binary(base64_encoding)
            with open(full_temp_file_path, "wb") as fb:
                fb.write(data)

        self.temp_files.add(full_temp_file_path)
        return full_temp_file_path
def download_tmp_copy_of_file(
url_path: str, access_token: str | None = None, dir: str | None = None
) -> tempfile._TemporaryFileWrapper:

View File

@ -46,7 +46,6 @@ from gradio.data_classes import PredictBody, ResetBody
from gradio.documentation import document, set_documentation_group
from gradio.exceptions import Error
from gradio.helpers import EventData
from gradio.processing_utils import TempFileManager
from gradio.queueing import Estimation, Event
from gradio.utils import cancel_tasks, run_coro_in_background, set_task_name
@ -548,7 +547,7 @@ class App(FastAPI):
files: List[UploadFile] = File(...),
):
output_files = []
file_manager = TempFileManager()
file_manager = gradio.File()
for input_file in files:
output_files.append(
await file_manager.save_uploaded_file(

View File

@ -623,6 +623,9 @@ class TestImage:
assert image_input.preprocess(img).size == (30, 10)
assert image_input.postprocess("test/test_files/bus.png") == img
assert image_input.serialize("test/test_files/bus.png") == img
image_input = gr.Image(type="filepath")
image_temp_filepath = image_input.preprocess(img)
assert image_temp_filepath in image_input.temp_files
image_input = gr.Image(
source="upload", tool="editor", type="pil", label="Upload Your Image"
@ -1692,49 +1695,66 @@ class TestChatbot:
Postprocess, get_config
"""
chatbot = gr.Chatbot()
assert chatbot.postprocess([("You are **cool**", "so are *you*")]) == [
("You are <strong>cool</strong>", "so are <em>you</em>")
assert chatbot.postprocess([["You are **cool**", "so are *you*"]]) == [
["You are <strong>cool</strong>", "so are <em>you</em>"]
]
multimodal_msg = [
(("driving.mp4",), "cool video"),
(("cantina.wav",), "cool audio"),
(("lion.jpg", "A lion"), "cool pic"),
[("test/test_files/video_sample.mp4",), "cool video"],
[("test/test_files/audio_sample.wav",), "cool audio"],
[("test/test_files/bus.png", "A bus"), "cool pic"],
]
processed_multimodal_msg = [
(
[
{
"name": "driving.mp4",
"name": "video_sample.mp4",
"mime_type": "video/mp4",
"alt_text": None,
"data": None,
"is_file": True,
},
"cool video",
),
(
],
[
{
"name": "cantina.wav",
"name": "audio_sample.wav",
"mime_type": "audio/wav",
"alt_text": None,
"data": None,
"is_file": True,
},
"cool audio",
),
(
],
[
{
"name": "lion.jpg",
"mime_type": "image/jpeg",
"alt_text": "A lion",
"name": "bus.png",
"mime_type": "image/png",
"alt_text": "A bus",
"data": None,
"is_file": True,
},
"cool pic",
),
],
]
postprocessed_multimodal_msg = chatbot.postprocess(multimodal_msg)
postprocessed_multimodal_msg_base_names = []
for x, y in postprocessed_multimodal_msg:
if isinstance(x, dict):
x["name"] = os.path.basename(x["name"])
postprocessed_multimodal_msg_base_names.append([x, y])
assert postprocessed_multimodal_msg_base_names == processed_multimodal_msg
preprocessed_multimodal_msg = chatbot.preprocess(processed_multimodal_msg)
multimodal_msg_base_names = []
for x, y in multimodal_msg:
if isinstance(x, tuple):
if len(x) > 1:
new_x = (os.path.basename(x[0]), x[1])
else:
new_x = (os.path.basename(x[0]),)
multimodal_msg_base_names.append([new_x, y])
assert multimodal_msg_base_names == preprocessed_multimodal_msg
assert chatbot.postprocess(multimodal_msg) == processed_multimodal_msg
assert chatbot.get_config() == {
"value": [],
"label": None,
@ -2606,3 +2626,84 @@ class TestCode:
"interactive": None,
"root_url": None,
}
# Tests for the temp-file helper methods, exercised here through a gr.File() instance.
class TestTempFileManagement:
# hash_file: expects cheetah1.jpg and cheetah1-copy.jpg to hash equal, and cheetah2.jpg to differ.
def test_hash_file(self):
temp_file_manager = gr.File()
h1 = temp_file_manager.hash_file("gradio/test_data/cheetah1.jpg")
h2 = temp_file_manager.hash_file("gradio/test_data/cheetah1-copy.jpg")
h3 = temp_file_manager.hash_file("gradio/test_data/cheetah2.jpg")
assert h1 == h2
assert h1 != h3
# make_temp_copy_if_needed: shutil.copy2 is patched, so the test only checks that a copy
# is requested and how temp_files grows.
@patch("shutil.copy2")
def test_make_temp_copy_if_needed(self, mock_copy):
temp_file_manager = gr.File()
# First call is only used to learn the target path so a leftover file can be removed.
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
try: # Delete if already exists from before this test
os.remove(f)
except OSError:
pass
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
assert mock_copy.called
assert len(temp_file_manager.temp_files) == 1
assert Path(f).name == "cheetah1.jpg"
# Repeating the same source must not register a second temp path.
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
assert len(temp_file_manager.temp_files) == 1
# A differently named source registers a new path, even though test_hash_file
# expects its content hash to match cheetah1.jpg.
f = temp_file_manager.make_temp_copy_if_needed(
"gradio/test_data/cheetah1-copy.jpg"
)
assert len(temp_file_manager.temp_files) == 2
assert Path(f).name == "cheetah1-copy.jpg"
# base64_to_temp_file_if_needed: same base64 input maps to one tracked file, a different
# input adds a second one.
def test_base64_to_temp_file_if_needed(self):
temp_file_manager = gr.File()
base64_file_1 = media_data.BASE64_IMAGE
base64_file_2 = media_data.BASE64_AUDIO["data"]
# First call is only used to learn the target path so a leftover file can be removed.
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
try: # Delete if already exists from before this test
os.remove(f)
except OSError:
pass
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
assert len(temp_file_manager.temp_files) == 1
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
assert len(temp_file_manager.temp_files) == 1
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_2)
assert len(temp_file_manager.temp_files) == 2
# Clean up the files this test wrote.
for file in temp_file_manager.temp_files:
os.remove(file)
# download_temp_copy_if_needed: marked flaky; the URLs below are remote, and
# shutil.copyfileobj is patched so only the call and the temp_files bookkeeping are checked.
@pytest.mark.flaky
@patch("shutil.copyfileobj")
def test_download_temp_copy_if_needed(self, mock_copy):
temp_file_manager = gr.File()
url1 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
url2 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg"
# First call is only used to learn the target path so a leftover file can be removed.
f = temp_file_manager.download_temp_copy_if_needed(url1)
try: # Delete if already exists from before this test
os.remove(f)
except OSError:
pass
f = temp_file_manager.download_temp_copy_if_needed(url1)
assert mock_copy.called
assert len(temp_file_manager.temp_files) == 1
# Same URL again: no new tracked path.
f = temp_file_manager.download_temp_copy_if_needed(url1)
assert len(temp_file_manager.temp_files) == 1
# Different URL: a second tracked path.
f = temp_file_manager.download_temp_copy_if_needed(url2)
assert len(temp_file_manager.temp_files) == 2

View File

@ -134,87 +134,6 @@ class TestAudioPreprocessing:
assert audio_.dtype == "int16"
# Tests for processing_utils.TempFileManager, instantiated directly.
class TestTempFileManager:
# hash_file: expects cheetah1.jpg and cheetah1-copy.jpg to hash equal, and cheetah2.jpg to differ.
def test_hash_file(self):
temp_file_manager = processing_utils.TempFileManager()
h1 = temp_file_manager.hash_file("gradio/test_data/cheetah1.jpg")
h2 = temp_file_manager.hash_file("gradio/test_data/cheetah1-copy.jpg")
h3 = temp_file_manager.hash_file("gradio/test_data/cheetah2.jpg")
assert h1 == h2
assert h1 != h3
# make_temp_copy_if_needed: shutil.copy2 is patched, so the test only checks that a copy
# is requested and how temp_files grows.
@patch("shutil.copy2")
def test_make_temp_copy_if_needed(self, mock_copy):
temp_file_manager = processing_utils.TempFileManager()
# First call is only used to learn the target path so a leftover file can be removed.
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
try: # Delete if already exists from before this test
os.remove(f)
except OSError:
pass
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
assert mock_copy.called
assert len(temp_file_manager.temp_files) == 1
assert Path(f).name == "cheetah1.jpg"
# Repeating the same source must not register a second temp path.
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
assert len(temp_file_manager.temp_files) == 1
# A differently named source registers a new path, even though test_hash_file
# expects its content hash to match cheetah1.jpg.
f = temp_file_manager.make_temp_copy_if_needed(
"gradio/test_data/cheetah1-copy.jpg"
)
assert len(temp_file_manager.temp_files) == 2
assert Path(f).name == "cheetah1-copy.jpg"
# base64_to_temp_file_if_needed: same base64 input maps to one tracked file, a different
# input adds a second one.
def test_base64_to_temp_file_if_needed(self):
temp_file_manager = processing_utils.TempFileManager()
base64_file_1 = media_data.BASE64_IMAGE
base64_file_2 = media_data.BASE64_AUDIO["data"]
# First call is only used to learn the target path so a leftover file can be removed.
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
try: # Delete if already exists from before this test
os.remove(f)
except OSError:
pass
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
assert len(temp_file_manager.temp_files) == 1
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1)
assert len(temp_file_manager.temp_files) == 1
f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_2)
assert len(temp_file_manager.temp_files) == 2
# Clean up the files this test wrote.
for file in temp_file_manager.temp_files:
os.remove(file)
# download_temp_copy_if_needed: marked flaky; the URLs below are remote, and
# shutil.copyfileobj is patched so only the call and the temp_files bookkeeping are checked.
@pytest.mark.flaky
@patch("shutil.copyfileobj")
def test_download_temp_copy_if_needed(self, mock_copy):
temp_file_manager = processing_utils.TempFileManager()
url1 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
url2 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg"
# First call is only used to learn the target path so a leftover file can be removed.
f = temp_file_manager.download_temp_copy_if_needed(url1)
try: # Delete if already exists from before this test
os.remove(f)
except OSError:
pass
f = temp_file_manager.download_temp_copy_if_needed(url1)
assert mock_copy.called
assert len(temp_file_manager.temp_files) == 1
# Same URL again: no new tracked path.
f = temp_file_manager.download_temp_copy_if_needed(url1)
assert len(temp_file_manager.temp_files) == 1
# Different URL: a second tracked path.
f = temp_file_manager.download_temp_copy_if_needed(url2)
assert len(temp_file_manager.temp_files) == 2
class TestOutputPreprocessing:
def test_decode_base64_to_binary(self):
binary = processing_utils.decode_base64_to_binary(