From 88afd684bec2c40bcdb94a1cbe151ffbea1e69d1 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 21 Mar 2023 09:37:24 -0700 Subject: [PATCH] Makes some fixes related to TempFiles (#3523) * temp file fixes * changes * fixing tests * formatting * fix * fix chatbot processing * tests * get tests to pass * fix code demo * changelog * fix multimodal --- CHANGELOG.md | 2 +- demo/code/run.ipynb | 2 +- demo/code/run.py | 2 +- gradio/blocks.py | 2 +- gradio/components.py | 232 ++++++++++++++++++++++++++++------ gradio/processing_utils.py | 133 +------------------ gradio/routes.py | 3 +- test/test_components.py | 135 +++++++++++++++++--- test/test_processing_utils.py | 81 ------------ 9 files changed, 314 insertions(+), 278 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 761a876c1f..726438337e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ No changes to highlight. ## Bug Fixes: -No changes to highlight. +- Fixes `Chatbot` and `Image` components so that files passed during processing are added to a directory where they can be served from, by [@abidlabs](https://github.com/abidlabs) in [PR 3523](https://github.com/gradio-app/gradio/pull/3523) ## Documentation Changes: diff --git a/demo/code/run.ipynb b/demo/code/run.ipynb index 352c322437..e5bd0e340e 100644 --- a/demo/code/run.ipynb +++ b/demo/code/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: code"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/code/file.css"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "from time import sleep\n", "\n", "\n", "css_file = os.path.join(os.path.abspath(''), \"file.css\")\n", "\n", "\n", "def set_lang(language):\n", " print(language)\n", " return gr.Code.update(language=language)\n", "\n", "\n", "def set_lang_from_path():\n", " sleep(1)\n", " return gr.Code.update(css_file, language=\"css\")\n", "\n", "\n", "def code(language, code):\n", " return gr.Code.update(code, language=language)\n", "\n", "\n", "io = gr.Interface(lambda x: x, \"code\", \"code\")\n", "\n", "with gr.Blocks() as demo:\n", " lang = gr.Dropdown(value=\"python\", choices=gr.Code.languages)\n", " with gr.Row():\n", " code_in = gr.Code(language=\"python\", label=\"Input\")\n", " code_out = gr.Code(label=\"Ouput\")\n", " btn = gr.Button(\"Run\")\n", " btn_two = gr.Button(\"Load File\")\n", "\n", " lang.change(set_lang, inputs=lang, outputs=code_in)\n", " btn.click(code, inputs=[lang, code_in], outputs=code_out)\n", " btn_two.click(set_lang_from_path, inputs=None, outputs=code_out)\n", " io.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: code"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/code/file.css"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "from time import sleep\n", "\n", "\n", "css_file = os.path.join(os.path.abspath(''), \"file.css\")\n", "\n", "\n", "def set_lang(language):\n", " print(language)\n", " return gr.Code.update(language=language)\n", "\n", "\n", "def set_lang_from_path():\n", " sleep(1)\n", " return gr.Code.update((css_file, ), language=\"css\")\n", "\n", "\n", "def code(language, code):\n", " return gr.Code.update(code, language=language)\n", "\n", "\n", "io = gr.Interface(lambda x: x, \"code\", \"code\")\n", "\n", "with gr.Blocks() as demo:\n", " lang = gr.Dropdown(value=\"python\", choices=gr.Code.languages)\n", " with gr.Row():\n", " code_in = gr.Code(language=\"python\", label=\"Input\")\n", " code_out = gr.Code(label=\"Ouput\")\n", " btn = gr.Button(\"Run\")\n", " btn_two = gr.Button(\"Load File\")\n", "\n", " lang.change(set_lang, inputs=lang, outputs=code_in)\n", " btn.click(code, inputs=[lang, code_in], outputs=code_out)\n", " btn_two.click(set_lang_from_path, inputs=None, outputs=code_out)\n", " io.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/code/run.py b/demo/code/run.py index 77663918fe..cbbe334204 100644 --- a/demo/code/run.py +++ b/demo/code/run.py @@ -13,7 +13,7 @@ def set_lang(language): def set_lang_from_path(): sleep(1) - return gr.Code.update(css_file, language="css") + return gr.Code.update((css_file, ), language="css") def code(language, code): diff --git a/gradio/blocks.py b/gradio/blocks.py index ad009cf57f..612d3671b7 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -89,7 +89,7 @@ class Block: Context.block.add(self) if Context.root_block is not None: Context.root_block.blocks[self._id] = self - if isinstance(self, components.TempFileManager): + if isinstance(self, components.IOComponent): Context.root_block.temp_file_sets.append(self.temp_files) return self diff --git a/gradio/components.py b/gradio/components.py index c3264c46ca..88d2670465 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -4,26 +4,33 @@ each component. These demos are located in the `demo` directory.""" from __future__ import annotations +import hashlib import inspect import json import math import operator import random +import secrets +import shutil import tempfile +import urllib.request import uuid import warnings from copy import deepcopy from enum import Enum from pathlib import Path from types import ModuleType -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple, Type +import aiofiles import altair as alt import matplotlib.figure import numpy as np import pandas as pd import PIL import PIL.ImageOps +import requests +from fastapi import UploadFile from ffmpy import FFmpeg from pandas.api.types import is_numeric_dtype from PIL import Image as _Image # using _ to minimize namespace pollution @@ -49,7 +56,6 @@ from gradio.events import ( ) from gradio.interpretation import NeighborInterpretable, TokenInterpretable from gradio.layouts import Column, Form, Row -from gradio.processing_utils import TempFileManager from gradio.serializing import ( FileSerializable, ImgSerializable, @@ -177,6 +183,9 @@ class IOComponent(Component, Serializable): every: float | None = None, **kwargs, ): + self.temp_files: Set[str] = set() + self.DEFAULT_TEMP_DIR = tempfile.gettempdir() + Component.__init__( self, elem_id=elem_id, elem_classes=elem_classes, visible=visible, **kwargs ) @@ -198,6 +207,120 @@ class IOComponent(Component, Serializable): if callable(load_fn): self.attach_load_event(load_fn, every) + def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str: + sha1 = hashlib.sha1() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""): + sha1.update(chunk) + return sha1.hexdigest() + + def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str: + sha1 = hashlib.sha1() + remote = urllib.request.urlopen(url) + max_file_size = 100 * 1024 * 1024 # 100MB + total_read = 0 + while True: + data = remote.read(chunk_num_blocks * sha1.block_size) + total_read += chunk_num_blocks * sha1.block_size + if not data or total_read > max_file_size: + break + sha1.update(data) + return sha1.hexdigest() + + def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str: + sha1 = hashlib.sha1() + for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size): + data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size] + sha1.update(data.encode("utf-8")) + return sha1.hexdigest() + + def make_temp_copy_if_needed(self, file_path: str) -> str: + """Returns a temporary file path for a copy of the given file path if it does + not already exist. Otherwise returns the path to the existing temp file.""" + temp_dir = self.hash_file(file_path) + temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir + temp_dir.mkdir(exist_ok=True, parents=True) + + f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) + f.name = utils.strip_invalid_filename_characters(Path(file_path).name) + full_temp_file_path = str(utils.abspath(temp_dir / f.name)) + + if not Path(full_temp_file_path).exists(): + shutil.copy2(file_path, full_temp_file_path) + + self.temp_files.add(full_temp_file_path) + return full_temp_file_path + + async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str: + temp_dir = secrets.token_hex( + 20 + ) # Since the full file is being uploaded anyways, there is no benefit to hashing the file. + temp_dir = Path(upload_dir) / temp_dir + temp_dir.mkdir(exist_ok=True, parents=True) + output_file_obj = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) + + if file.filename: + file_name = Path(file.filename).name + output_file_obj.name = utils.strip_invalid_filename_characters(file_name) + + full_temp_file_path = str(utils.abspath(temp_dir / output_file_obj.name)) + + async with aiofiles.open(full_temp_file_path, "wb") as output_file: + while True: + content = await file.read(100 * 1024 * 1024) + if not content: + break + await output_file.write(content) + + return full_temp_file_path + + def download_temp_copy_if_needed(self, url: str) -> str: + """Downloads a file and makes a temporary file path for a copy if does not already + exist. Otherwise returns the path to the existing temp file.""" + temp_dir = self.hash_url(url) + temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir + temp_dir.mkdir(exist_ok=True, parents=True) + f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) + + f.name = utils.strip_invalid_filename_characters(Path(url).name) + full_temp_file_path = str(utils.abspath(temp_dir / f.name)) + + if not Path(full_temp_file_path).exists(): + with requests.get(url, stream=True) as r: + with open(full_temp_file_path, "wb") as f: + shutil.copyfileobj(r.raw, f) + + self.temp_files.add(full_temp_file_path) + return full_temp_file_path + + def base64_to_temp_file_if_needed( + self, base64_encoding: str, file_name: str | None = None + ) -> str: + """Converts a base64 encoding to a file and returns the path to the file if + the file doesn't already exist. Otherwise returns the path to the existing file.""" + temp_dir = self.hash_base64(base64_encoding) + temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir + temp_dir.mkdir(exist_ok=True, parents=True) + + guess_extension = processing_utils.get_extension(base64_encoding) + if file_name: + file_name = utils.strip_invalid_filename_characters(file_name) + elif guess_extension: + file_name = "file." + guess_extension + else: + file_name = "file" + f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) + f.name = file_name + full_temp_file_path = str(utils.abspath(temp_dir / f.name)) + + if not Path(full_temp_file_path).exists(): + data, _ = processing_utils.decode_base64_to_binary(base64_encoding) + with open(full_temp_file_path, "wb") as fb: + fb.write(data) + + self.temp_files.add(full_temp_file_path) + return full_temp_file_path + def get_config(self): config = { "label": self.label, @@ -1565,7 +1688,7 @@ class Image( suffix=("." + fmt.lower() if fmt is not None else ".png"), ) im.save(file_obj.name) - return file_obj.name + return self.make_temp_copy_if_needed(file_obj.name) else: raise ValueError( "Unknown type: " @@ -1797,7 +1920,6 @@ class Video( Uploadable, IOComponent, FileSerializable, - TempFileManager, ): """ Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output). @@ -1854,7 +1976,6 @@ class Video( self.include_audio = ( include_audio if include_audio is not None else source == "upload" ) - TempFileManager.__init__(self) IOComponent.__init__( self, label=label, @@ -2016,7 +2137,6 @@ class Audio( Uploadable, IOComponent, FileSerializable, - TempFileManager, TokenInterpretable, ): """ @@ -2076,7 +2196,6 @@ class Audio( raise ValueError( "Audio streaming only available if source is 'microphone'." ) - TempFileManager.__init__(self) IOComponent.__init__( self, label=label, @@ -2326,7 +2445,6 @@ class File( Uploadable, IOComponent, FileSerializable, - TempFileManager, ): """ Creates a file component that allows uploading generic file (when used as an input) and or displaying generic files (output). @@ -2397,7 +2515,6 @@ class File( Uses event data gradio.SelectData to carry `value` referring to name of selected file, and `index` to refer to index. See EventData documentation on how to use this event data. """ - TempFileManager.__init__(self) IOComponent.__init__( self, label=label, @@ -3100,9 +3217,7 @@ class Button(Clickable, IOComponent, SimpleSerializable): @document("style") -class UploadButton( - Clickable, Uploadable, IOComponent, FileSerializable, TempFileManager -): +class UploadButton(Clickable, Uploadable, IOComponent, FileSerializable): """ Used to create an upload button, when cicked allows a user to upload files that satisfy the specified file type or generic files (if file_type not set). Preprocessing: passes the uploaded file as a {file-object} or {List[file-object]} depending on `file_count` (or a {bytes}/{List{bytes}} depending on `type`) @@ -3147,7 +3262,6 @@ class UploadButton( ) self.file_types = file_types self.label = label - TempFileManager.__init__(self) IOComponent.__init__( self, label=label, @@ -3855,7 +3969,7 @@ class HTML(Changeable, IOComponent, SimpleSerializable): @document("style") -class Gallery(IOComponent, TempFileManager, FileSerializable, Selectable): +class Gallery(IOComponent, FileSerializable, Selectable): """ Used to display a list of images as a gallery that can be scrolled through. Preprocessing: this component does *not* accept input. @@ -3892,7 +4006,6 @@ class Gallery(IOComponent, TempFileManager, FileSerializable, Selectable): Uses event data gradio.SelectData to carry `value` referring to caption of selected image, and `index` to refer to index. See EventData documentation on how to use this event data. """ - TempFileManager.__init__(self) IOComponent.__init__( self, label=label, @@ -4055,14 +4168,16 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): """ Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images. Preprocessing: this component does *not* accept input. - Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed. + Postprocessing: expects function to return a {List[List[str | None | Tuple]]}, a list of lists. The inner list should have 2 elements: the user message and the response message. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed. Demos: chatbot_simple, chatbot_multimodal """ def __init__( self, - value: List[Tuple[str | None, str | None]] | Callable | None = None, + value: List[List[str | Tuple[str] | Tuple[str, str] | None]] + | Callable + | None = None, color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style() *, label: str | None = None, @@ -4116,7 +4231,9 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): @staticmethod def update( - value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, + value: List[List[str | Tuple[str] | Tuple[str, str] | None]] + | Literal[_Keywords.NO_VALUE] + | None = _Keywords.NO_VALUE, label: str | None = None, show_label: bool | None = None, visible: bool | None = None, @@ -4130,24 +4247,57 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): } return updated_config - def _process_chat_messages( - self, chat_message: str | Tuple | List | Dict | None + def _preprocess_chat_messages( + self, chat_message: str | Dict | None + ) -> str | Tuple[str] | Tuple[str, str] | None: + if chat_message is None: + return None + elif isinstance(chat_message, dict): + if chat_message["alt_text"] is not None: + return (chat_message["name"], chat_message["alt_text"]) + else: + return (chat_message["name"],) + else: # string + return chat_message + + def preprocess( + self, + y: List[List[str | Dict | None] | Tuple[str | Dict | None, str | Dict | None]], + ) -> List[List[str | Tuple[str] | Tuple[str, str] | None]]: + if y is None: + return y + processed_messages = [] + for message_pair in y: + assert isinstance( + message_pair, (tuple, list) + ), f"Expected a list of lists or list of tuples. Received: {message_pair}" + assert ( + len(message_pair) == 2 + ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}" + processed_messages.append( + [ + self._preprocess_chat_messages(message_pair[0]), + self._preprocess_chat_messages(message_pair[1]), + ] + ) + return processed_messages + + def _postprocess_chat_messages( + self, chat_message: str | Tuple | List | None ) -> str | Dict | None: if chat_message is None: return None elif isinstance(chat_message, (tuple, list)): - mime_type = processing_utils.get_mimetype(chat_message[0]) + filepath = chat_message[0] + mime_type = processing_utils.get_mimetype(filepath) + filepath = self.make_temp_copy_if_needed(filepath) return { - "name": chat_message[0], + "name": filepath, "mime_type": mime_type, "alt_text": chat_message[1] if len(chat_message) > 1 else None, "data": None, # These last two fields are filled in by the frontend "is_file": True, } - elif isinstance( - chat_message, dict - ): # This happens for previously processed messages - return chat_message elif isinstance(chat_message, str): return self.md.renderInline(chat_message) else: @@ -4155,15 +4305,13 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): def postprocess( self, - y: List[ - Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None] - ], - ) -> List[Tuple[str | Dict | None, str | Dict | None]]: + y: List[List[str | Tuple[str] | Tuple[str, str] | None] | Tuple], + ) -> List[List[str | Dict | None]]: """ Parameters: - y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed. + y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed. Returns: - List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. + List of lists representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. Or None if the message is not to be displayed. """ if y is None: return [] @@ -4176,10 +4324,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): len(message_pair) == 2 ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}" processed_messages.append( - ( - self._process_chat_messages(message_pair[0]), - self._process_chat_messages(message_pair[1]), - ) + [ + self._postprocess_chat_messages(message_pair[0]), + self._postprocess_chat_messages(message_pair[1]), + ] ) return processed_messages @@ -4200,9 +4348,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): @document("style") -class Model3D( - Changeable, Editable, Clearable, IOComponent, FileSerializable, TempFileManager -): +class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable): """ Component allows users to upload or view 3D Model files (.obj, .glb, or .gltf). Preprocessing: This component passes the uploaded file as a {str} filepath. @@ -4237,7 +4383,6 @@ class Model3D( elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. """ self.clear_color = clear_color or [0, 0, 0, 0] - TempFileManager.__init__(self) IOComponent.__init__( self, label=label, @@ -5549,7 +5694,7 @@ class Code(Changeable, IOComponent, SimpleSerializable): def __init__( self, - value: str | None = None, + value: str | Tuple[str] | None = None, language: str | None = None, *, label: str | None = None, @@ -5604,7 +5749,10 @@ class Code(Changeable, IOComponent, SimpleSerializable): @staticmethod def update( - value: str | None | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE, + value: str + | Tuple[str] + | None + | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE, label: str | None = None, show_label: bool | None = None, visible: bool | None = None, diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index a7e27a2028..2ca9e04729 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -1,24 +1,19 @@ from __future__ import annotations import base64 -import hashlib import json import mimetypes import os -import secrets import shutil import subprocess import tempfile -import urllib.request import warnings from io import BytesIO from pathlib import Path -from typing import Dict, Set, Tuple +from typing import Dict, Tuple -import aiofiles import numpy as np import requests -from fastapi import UploadFile from ffmpy import FFmpeg, FFprobe, FFRuntimeError from PIL import Image, ImageOps, PngImagePlugin @@ -318,132 +313,6 @@ def file_to_json(file_path: str | Path) -> Dict: return json.load(f) -class TempFileManager: - """ - A class that should be inherited by any Component that needs to manage temporary files. - It should be instantiated in the __init__ method of the component. - """ - - def __init__(self) -> None: - # Set stores all the temporary files created by this component. - self.temp_files: Set[str] = set() - self.DEFAULT_TEMP_DIR = tempfile.gettempdir() - - def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str: - sha1 = hashlib.sha1() - with open(file_path, "rb") as f: - for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""): - sha1.update(chunk) - return sha1.hexdigest() - - def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str: - sha1 = hashlib.sha1() - remote = urllib.request.urlopen(url) - max_file_size = 100 * 1024 * 1024 # 100MB - total_read = 0 - while True: - data = remote.read(chunk_num_blocks * sha1.block_size) - total_read += chunk_num_blocks * sha1.block_size - if not data or total_read > max_file_size: - break - sha1.update(data) - return sha1.hexdigest() - - def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str: - sha1 = hashlib.sha1() - for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size): - data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size] - sha1.update(data.encode("utf-8")) - return sha1.hexdigest() - - def make_temp_copy_if_needed(self, file_path: str) -> str: - """Returns a temporary file path for a copy of the given file path if it does - not already exist. Otherwise returns the path to the existing temp file.""" - temp_dir = self.hash_file(file_path) - temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir - temp_dir.mkdir(exist_ok=True, parents=True) - - f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) - f.name = utils.strip_invalid_filename_characters(Path(file_path).name) - full_temp_file_path = str(utils.abspath(temp_dir / f.name)) - - if not Path(full_temp_file_path).exists(): - shutil.copy2(file_path, full_temp_file_path) - - self.temp_files.add(full_temp_file_path) - return full_temp_file_path - - async def save_uploaded_file(self, file: UploadFile, upload_dir: str) -> str: - temp_dir = secrets.token_hex( - 20 - ) # Since the full file is being uploaded anyways, there is no benefit to hashing the file. - temp_dir = Path(upload_dir) / temp_dir - temp_dir.mkdir(exist_ok=True, parents=True) - output_file_obj = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) - - if file.filename: - file_name = Path(file.filename).name - output_file_obj.name = utils.strip_invalid_filename_characters(file_name) - - full_temp_file_path = str(utils.abspath(temp_dir / output_file_obj.name)) - - async with aiofiles.open(full_temp_file_path, "wb") as output_file: - while True: - content = await file.read(100 * 1024 * 1024) - if not content: - break - await output_file.write(content) - - return full_temp_file_path - - def download_temp_copy_if_needed(self, url: str) -> str: - """Downloads a file and makes a temporary file path for a copy if does not already - exist. Otherwise returns the path to the existing temp file.""" - temp_dir = self.hash_url(url) - temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir - temp_dir.mkdir(exist_ok=True, parents=True) - f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) - - f.name = utils.strip_invalid_filename_characters(Path(url).name) - full_temp_file_path = str(utils.abspath(temp_dir / f.name)) - - if not Path(full_temp_file_path).exists(): - with requests.get(url, stream=True) as r: - with open(full_temp_file_path, "wb") as f: - shutil.copyfileobj(r.raw, f) - - self.temp_files.add(full_temp_file_path) - return full_temp_file_path - - def base64_to_temp_file_if_needed( - self, base64_encoding: str, file_name: str | None = None - ) -> str: - """Converts a base64 encoding to a file and returns the path to the file if - the file doesn't already exist. Otherwise returns the path to the existing file.""" - temp_dir = self.hash_base64(base64_encoding) - temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir - temp_dir.mkdir(exist_ok=True, parents=True) - - guess_extension = get_extension(base64_encoding) - if file_name: - file_name = utils.strip_invalid_filename_characters(file_name) - elif guess_extension: - file_name = "file." + guess_extension - else: - file_name = "file" - f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) - f.name = file_name - full_temp_file_path = str(utils.abspath(temp_dir / f.name)) - - if not Path(full_temp_file_path).exists(): - data, _ = decode_base64_to_binary(base64_encoding) - with open(full_temp_file_path, "wb") as fb: - fb.write(data) - - self.temp_files.add(full_temp_file_path) - return full_temp_file_path - - def download_tmp_copy_of_file( url_path: str, access_token: str | None = None, dir: str | None = None ) -> tempfile._TemporaryFileWrapper: diff --git a/gradio/routes.py b/gradio/routes.py index 044b62445e..b27cb8f901 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -46,7 +46,6 @@ from gradio.data_classes import PredictBody, ResetBody from gradio.documentation import document, set_documentation_group from gradio.exceptions import Error from gradio.helpers import EventData -from gradio.processing_utils import TempFileManager from gradio.queueing import Estimation, Event from gradio.utils import cancel_tasks, run_coro_in_background, set_task_name @@ -548,7 +547,7 @@ class App(FastAPI): files: List[UploadFile] = File(...), ): output_files = [] - file_manager = TempFileManager() + file_manager = gradio.File() for input_file in files: output_files.append( await file_manager.save_uploaded_file( diff --git a/test/test_components.py b/test/test_components.py index df31f9e952..95a9a52455 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -623,6 +623,9 @@ class TestImage: assert image_input.preprocess(img).size == (30, 10) assert image_input.postprocess("test/test_files/bus.png") == img assert image_input.serialize("test/test_files/bus.png") == img + image_input = gr.Image(type="filepath") + image_temp_filepath = image_input.preprocess(img) + assert image_temp_filepath in image_input.temp_files image_input = gr.Image( source="upload", tool="editor", type="pil", label="Upload Your Image" @@ -1692,49 +1695,66 @@ class TestChatbot: Postprocess, get_config """ chatbot = gr.Chatbot() - assert chatbot.postprocess([("You are **cool**", "so are *you*")]) == [ - ("You are cool", "so are you") + assert chatbot.postprocess([["You are **cool**", "so are *you*"]]) == [ + ["You are cool", "so are you"] ] multimodal_msg = [ - (("driving.mp4",), "cool video"), - (("cantina.wav",), "cool audio"), - (("lion.jpg", "A lion"), "cool pic"), + [("test/test_files/video_sample.mp4",), "cool video"], + [("test/test_files/audio_sample.wav",), "cool audio"], + [("test/test_files/bus.png", "A bus"), "cool pic"], ] processed_multimodal_msg = [ - ( + [ { - "name": "driving.mp4", + "name": "video_sample.mp4", "mime_type": "video/mp4", "alt_text": None, "data": None, "is_file": True, }, "cool video", - ), - ( + ], + [ { - "name": "cantina.wav", + "name": "audio_sample.wav", "mime_type": "audio/wav", "alt_text": None, "data": None, "is_file": True, }, "cool audio", - ), - ( + ], + [ { - "name": "lion.jpg", - "mime_type": "image/jpeg", - "alt_text": "A lion", + "name": "bus.png", + "mime_type": "image/png", + "alt_text": "A bus", "data": None, "is_file": True, }, "cool pic", - ), + ], ] + postprocessed_multimodal_msg = chatbot.postprocess(multimodal_msg) + postprocessed_multimodal_msg_base_names = [] + for x, y in postprocessed_multimodal_msg: + if isinstance(x, dict): + x["name"] = os.path.basename(x["name"]) + postprocessed_multimodal_msg_base_names.append([x, y]) + assert postprocessed_multimodal_msg_base_names == processed_multimodal_msg + + preprocessed_multimodal_msg = chatbot.preprocess(processed_multimodal_msg) + multimodal_msg_base_names = [] + for x, y in multimodal_msg: + if isinstance(x, tuple): + if len(x) > 1: + new_x = (os.path.basename(x[0]), x[1]) + else: + new_x = (os.path.basename(x[0]),) + multimodal_msg_base_names.append([new_x, y]) + assert multimodal_msg_base_names == preprocessed_multimodal_msg - assert chatbot.postprocess(multimodal_msg) == processed_multimodal_msg assert chatbot.get_config() == { "value": [], "label": None, @@ -2606,3 +2626,84 @@ class TestCode: "interactive": None, "root_url": None, } + + +class TestTempFileManagement: + def test_hash_file(self): + temp_file_manager = gr.File() + h1 = temp_file_manager.hash_file("gradio/test_data/cheetah1.jpg") + h2 = temp_file_manager.hash_file("gradio/test_data/cheetah1-copy.jpg") + h3 = temp_file_manager.hash_file("gradio/test_data/cheetah2.jpg") + assert h1 == h2 + assert h1 != h3 + + @patch("shutil.copy2") + def test_make_temp_copy_if_needed(self, mock_copy): + temp_file_manager = gr.File() + + f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg") + try: # Delete if already exists from before this test + os.remove(f) + except OSError: + pass + + f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg") + assert mock_copy.called + assert len(temp_file_manager.temp_files) == 1 + assert Path(f).name == "cheetah1.jpg" + + f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg") + assert len(temp_file_manager.temp_files) == 1 + + f = temp_file_manager.make_temp_copy_if_needed( + "gradio/test_data/cheetah1-copy.jpg" + ) + assert len(temp_file_manager.temp_files) == 2 + assert Path(f).name == "cheetah1-copy.jpg" + + def test_base64_to_temp_file_if_needed(self): + temp_file_manager = gr.File() + + base64_file_1 = media_data.BASE64_IMAGE + base64_file_2 = media_data.BASE64_AUDIO["data"] + + f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1) + try: # Delete if already exists from before this test + os.remove(f) + except OSError: + pass + + f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1) + assert len(temp_file_manager.temp_files) == 1 + + f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1) + assert len(temp_file_manager.temp_files) == 1 + + f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_2) + assert len(temp_file_manager.temp_files) == 2 + + for file in temp_file_manager.temp_files: + os.remove(file) + + @pytest.mark.flaky + @patch("shutil.copyfileobj") + def test_download_temp_copy_if_needed(self, mock_copy): + temp_file_manager = gr.File() + url1 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png" + url2 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg" + + f = temp_file_manager.download_temp_copy_if_needed(url1) + try: # Delete if already exists from before this test + os.remove(f) + except OSError: + pass + + f = temp_file_manager.download_temp_copy_if_needed(url1) + assert mock_copy.called + assert len(temp_file_manager.temp_files) == 1 + + f = temp_file_manager.download_temp_copy_if_needed(url1) + assert len(temp_file_manager.temp_files) == 1 + + f = temp_file_manager.download_temp_copy_if_needed(url2) + assert len(temp_file_manager.temp_files) == 2 diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index 190b6b2286..6e1f6a8367 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -134,87 +134,6 @@ class TestAudioPreprocessing: assert audio_.dtype == "int16" -class TestTempFileManager: - def test_hash_file(self): - temp_file_manager = processing_utils.TempFileManager() - h1 = temp_file_manager.hash_file("gradio/test_data/cheetah1.jpg") - h2 = temp_file_manager.hash_file("gradio/test_data/cheetah1-copy.jpg") - h3 = temp_file_manager.hash_file("gradio/test_data/cheetah2.jpg") - assert h1 == h2 - assert h1 != h3 - - @patch("shutil.copy2") - def test_make_temp_copy_if_needed(self, mock_copy): - temp_file_manager = processing_utils.TempFileManager() - - f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg") - try: # Delete if already exists from before this test - os.remove(f) - except OSError: - pass - - f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg") - assert mock_copy.called - assert len(temp_file_manager.temp_files) == 1 - assert Path(f).name == "cheetah1.jpg" - - f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg") - assert len(temp_file_manager.temp_files) == 1 - - f = temp_file_manager.make_temp_copy_if_needed( - "gradio/test_data/cheetah1-copy.jpg" - ) - assert len(temp_file_manager.temp_files) == 2 - assert Path(f).name == "cheetah1-copy.jpg" - - def test_base64_to_temp_file_if_needed(self): - temp_file_manager = processing_utils.TempFileManager() - - base64_file_1 = media_data.BASE64_IMAGE - base64_file_2 = media_data.BASE64_AUDIO["data"] - - f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1) - try: # Delete if already exists from before this test - os.remove(f) - except OSError: - pass - - f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1) - assert len(temp_file_manager.temp_files) == 1 - - f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_1) - assert len(temp_file_manager.temp_files) == 1 - - f = temp_file_manager.base64_to_temp_file_if_needed(base64_file_2) - assert len(temp_file_manager.temp_files) == 2 - - for file in temp_file_manager.temp_files: - os.remove(file) - - @pytest.mark.flaky - @patch("shutil.copyfileobj") - def test_download_temp_copy_if_needed(self, mock_copy): - temp_file_manager = processing_utils.TempFileManager() - url1 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png" - url2 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg" - - f = temp_file_manager.download_temp_copy_if_needed(url1) - try: # Delete if already exists from before this test - os.remove(f) - except OSError: - pass - - f = temp_file_manager.download_temp_copy_if_needed(url1) - assert mock_copy.called - assert len(temp_file_manager.temp_files) == 1 - - f = temp_file_manager.download_temp_copy_if_needed(url1) - assert len(temp_file_manager.temp_files) == 1 - - f = temp_file_manager.download_temp_copy_if_needed(url2) - assert len(temp_file_manager.temp_files) == 2 - - class TestOutputPreprocessing: def test_decode_base64_to_binary(self): binary = processing_utils.decode_base64_to_binary(