Temp file fixes (#4256)

* Fix bug

* Linting

* CHANGELOG

* Add tests

* Update test

* Fix remaining components + add tests

* Fix tests

* Fix tests

* Address comments
This commit is contained in:
Freddy Boulton 2023-05-19 17:22:12 -04:00 committed by GitHub
parent 1151c52535
commit 834afdd303
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 299 additions and 83 deletions

View File

@ -4,6 +4,9 @@ No changes to highlight.
## Bug Fixes: ## Bug Fixes:
- Fixed Gallery/AnnotatedImage components not respecting GRADIO_DEFAULT_DIR variable by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
- Fixed Gallery/AnnotatedImage components resaving identical images by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
- Fixed Audio/Video/File components creating empty tempfiles on each run by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
- Fixed the behavior of the `run_on_click` parameter in `gr.Examples` by [@abidlabs](https://github.com/abidlabs) in [PR 4258](https://github.com/gradio-app/gradio/pull/4258). - Fixed the behavior of the `run_on_click` parameter in `gr.Examples` by [@abidlabs](https://github.com/abidlabs) in [PR 4258](https://github.com/gradio-app/gradio/pull/4258).
- Ensure js client respcts the full root when making requests to the server by [@pngwn](https://github.com/pngwn) in [PR 4271](https://github.com/gradio-app/gradio/pull/4271) - Ensure js client respcts the full root when making requests to the server by [@pngwn](https://github.com/pngwn) in [PR 4271](https://github.com/gradio-app/gradio/pull/4271)

View File

@ -785,8 +785,8 @@ class Endpoint:
if t in ["file", "uploadbutton"] if t in ["file", "uploadbutton"]
] ]
uploaded_files = self._upload(files) uploaded_files = self._upload(files)
self._add_uploaded_files_to_data(uploaded_files, list(data)) data = list(data)
self._add_uploaded_files_to_data(uploaded_files, data)
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)]) o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
return o return o

View File

@ -252,13 +252,19 @@ class TestClientPredictions:
with patch.object( with patch.object(
client.endpoints[0], "_upload", wraps=client.endpoints[0]._upload client.endpoints[0], "_upload", wraps=client.endpoints[0]._upload
) as upload: ) as upload:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: with patch.object(
f.write("Hello from private space!") client.endpoints[0], "serialize", wraps=client.endpoints[0].serialize
) as serialize:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write("Hello from private space!")
output = client.submit(1, "foo", f.name, api_name="/file_upload").result() output = client.submit(
1, "foo", f.name, api_name="/file_upload"
).result()
with open(output) as f: with open(output) as f:
assert f.read() == "Hello from private space!" assert f.read() == "Hello from private space!"
upload.assert_called_once() upload.assert_called_once()
assert all(f["is_file"] for f in serialize.return_value())
with patch.object( with patch.object(
client.endpoints[1], "_upload", wraps=client.endpoints[0]._upload client.endpoints[1], "_upload", wraps=client.endpoints[0]._upload

View File

@ -20,7 +20,7 @@ from copy import deepcopy
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Dict, cast from typing import TYPE_CHECKING, Any, Callable, Dict
import aiofiles import aiofiles
import altair as alt import altair as alt
@ -217,14 +217,16 @@ class IOComponent(Component):
if callable(load_fn): if callable(load_fn):
self.attach_load_event(load_fn, every) self.attach_load_event(load_fn, every)
def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str: @staticmethod
def hash_file(file_path: str, chunk_num_blocks: int = 128) -> str:
sha1 = hashlib.sha1() sha1 = hashlib.sha1()
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""): for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
sha1.update(chunk) sha1.update(chunk)
return sha1.hexdigest() return sha1.hexdigest()
def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str: @staticmethod
def hash_url(url: str, chunk_num_blocks: int = 128) -> str:
sha1 = hashlib.sha1() sha1 = hashlib.sha1()
remote = urllib.request.urlopen(url) remote = urllib.request.urlopen(url)
max_file_size = 100 * 1024 * 1024 # 100MB max_file_size = 100 * 1024 * 1024 # 100MB
@ -237,7 +239,14 @@ class IOComponent(Component):
sha1.update(data) sha1.update(data)
return sha1.hexdigest() return sha1.hexdigest()
def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str: @staticmethod
def hash_bytes(bytes: bytes):
sha1 = hashlib.sha1()
sha1.update(bytes)
return sha1.hexdigest()
@staticmethod
def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str:
sha1 = hashlib.sha1() sha1 = hashlib.sha1()
for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size): for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size] data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
@ -251,9 +260,8 @@ class IOComponent(Component):
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True) temp_dir.mkdir(exist_ok=True, parents=True)
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) name = client_utils.strip_invalid_filename_characters(Path(file_path).name)
f.name = client_utils.strip_invalid_filename_characters(Path(file_path).name) full_temp_file_path = str(utils.abspath(temp_dir / name))
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
if not Path(full_temp_file_path).exists(): if not Path(full_temp_file_path).exists():
shutil.copy2(file_path, full_temp_file_path) shutil.copy2(file_path, full_temp_file_path)
@ -267,15 +275,14 @@ class IOComponent(Component):
) # Since the full file is being uploaded anyways, there is no benefit to hashing the file. ) # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
temp_dir = Path(upload_dir) / temp_dir temp_dir = Path(upload_dir) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True) temp_dir.mkdir(exist_ok=True, parents=True)
output_file_obj = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
if file.filename: if file.filename:
file_name = Path(file.filename).name file_name = Path(file.filename).name
output_file_obj.name = client_utils.strip_invalid_filename_characters( name = client_utils.strip_invalid_filename_characters(file_name)
file_name else:
) name = f"tmp{secrets.token_hex(5)}"
full_temp_file_path = str(utils.abspath(temp_dir / output_file_obj.name)) full_temp_file_path = str(utils.abspath(temp_dir / name))
async with aiofiles.open(full_temp_file_path, "wb") as output_file: async with aiofiles.open(full_temp_file_path, "wb") as output_file:
while True: while True:
@ -292,10 +299,9 @@ class IOComponent(Component):
temp_dir = self.hash_url(url) temp_dir = self.hash_url(url)
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
temp_dir.mkdir(exist_ok=True, parents=True) temp_dir.mkdir(exist_ok=True, parents=True)
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
f.name = client_utils.strip_invalid_filename_characters(Path(url).name) name = client_utils.strip_invalid_filename_characters(Path(url).name)
full_temp_file_path = str(utils.abspath(temp_dir / f.name)) full_temp_file_path = str(utils.abspath(temp_dir / name))
if not Path(full_temp_file_path).exists(): if not Path(full_temp_file_path).exists():
with requests.get(url, stream=True) as r, open( with requests.get(url, stream=True) as r, open(
@ -323,8 +329,7 @@ class IOComponent(Component):
file_name = f"file.{guess_extension}" file_name = f"file.{guess_extension}"
else: else:
file_name = "file" file_name = "file"
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
f.name = file_name # type: ignore
full_temp_file_path = str(utils.abspath(temp_dir / file_name)) # type: ignore full_temp_file_path = str(utils.abspath(temp_dir / file_name)) # type: ignore
if not Path(full_temp_file_path).exists(): if not Path(full_temp_file_path).exists():
@ -335,6 +340,36 @@ class IOComponent(Component):
self.temp_files.add(full_temp_file_path) self.temp_files.add(full_temp_file_path)
return full_temp_file_path return full_temp_file_path
def pil_to_temp_file(self, img: _Image.Image, dir: str, format="png") -> str:
bytes_data = processing_utils.encode_pil_to_bytes(img, format)
temp_dir = Path(dir) / self.hash_bytes(bytes_data)
temp_dir.mkdir(exist_ok=True, parents=True)
filename = str(temp_dir / f"image.{format}")
img.save(filename, pnginfo=processing_utils.get_pil_metadata(img))
return filename
def img_array_to_temp_file(self, arr: np.ndarray, dir: str) -> str:
pil_image = _Image.fromarray(
processing_utils._convert(arr, np.uint8, force_copy=False)
)
return self.pil_to_temp_file(pil_image, dir, format="png")
def audio_to_temp_file(
self, data: np.ndarray, sample_rate: int, dir: str, format: str
):
temp_dir = Path(dir) / self.hash_bytes(data.tobytes())
temp_dir.mkdir(exist_ok=True, parents=True)
filename = str(temp_dir / f"audio.{format}")
processing_utils.audio_to_file(sample_rate, data, filename, format=format)
return filename
def file_bytes_to_file(self, data: bytes, dir: str, file_name: str):
path = Path(dir) / self.hash_bytes(data)
path.mkdir(exist_ok=True, parents=True)
path = path / Path(file_name).name
path.write_bytes(data)
return path
def get_config(self): def get_config(self):
config = { config = {
"label": self.label, "label": self.label,
@ -1758,12 +1793,11 @@ class Image(
elif self.type == "numpy": elif self.type == "numpy":
return np.array(im) return np.array(im)
elif self.type == "filepath": elif self.type == "filepath":
file_obj = tempfile.NamedTemporaryFile( path = self.pil_to_temp_file(
delete=False, im, dir=self.DEFAULT_TEMP_DIR, format=fmt or "png"
suffix=(f".{fmt.lower()}" if fmt is not None else ".png"),
) )
im.save(file_obj.name) self.temp_files.add(path)
return self.make_temp_copy_if_needed(file_obj.name) return path
else: else:
raise ValueError( raise ValueError(
"Unknown type: " "Unknown type: "
@ -2259,8 +2293,7 @@ class Video(
# HTML5 only support vtt format # HTML5 only support vtt format
if Path(subtitle).suffix == ".srt": if Path(subtitle).suffix == ".srt":
temp_file = tempfile.NamedTemporaryFile( temp_file = tempfile.NamedTemporaryFile(
delete=False, delete=False, suffix=".vtt", dir=self.DEFAULT_TEMP_DIR
suffix=".vtt",
) )
srt_to_vtt(subtitle, temp_file.name) srt_to_vtt(subtitle, temp_file.name)
@ -2483,7 +2516,9 @@ class Audio(
# Handle the leave one outs # Handle the leave one outs
leave_one_out_data = np.copy(data) leave_one_out_data = np.copy(data)
leave_one_out_data[start:stop] = 0 leave_one_out_data[start:stop] = 0
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") file = tempfile.NamedTemporaryFile(
delete=False, suffix=".wav", dir=self.DEFAULT_TEMP_DIR
)
processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name) processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name)
out_data = client_utils.encode_file_to_base64(file.name) out_data = client_utils.encode_file_to_base64(file.name)
leave_one_out_sets.append(out_data) leave_one_out_sets.append(out_data)
@ -2494,7 +2529,9 @@ class Audio(
token = np.copy(data) token = np.copy(data)
token[0:start] = 0 token[0:start] = 0
token[stop:] = 0 token[stop:] = 0
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") file = tempfile.NamedTemporaryFile(
delete=False, suffix=".wav", dir=self.DEFAULT_TEMP_DIR
)
processing_utils.audio_to_file(sample_rate, token, file.name) processing_utils.audio_to_file(sample_rate, token, file.name)
token_data = client_utils.encode_file_to_base64(file.name) token_data = client_utils.encode_file_to_base64(file.name)
file.close() file.close()
@ -2525,7 +2562,7 @@ class Audio(
masked_input = np.copy(zero_input) masked_input = np.copy(zero_input)
for t, b in zip(token_data, binary_mask_vector): for t, b in zip(token_data, binary_mask_vector):
masked_input = masked_input + t * int(b) masked_input = masked_input + t * int(b)
file = tempfile.NamedTemporaryFile(delete=False) file = tempfile.NamedTemporaryFile(delete=False, dir=self.DEFAULT_TEMP_DIR)
processing_utils.audio_to_file(sample_rate, masked_input, file.name) processing_utils.audio_to_file(sample_rate, masked_input, file.name)
masked_data = client_utils.encode_file_to_base64(file.name) masked_data = client_utils.encode_file_to_base64(file.name)
file.close() file.close()
@ -2546,11 +2583,9 @@ class Audio(
return {"name": y, "data": None, "is_file": True} return {"name": y, "data": None, "is_file": True}
if isinstance(y, tuple): if isinstance(y, tuple):
sample_rate, data = y sample_rate, data = y
file = tempfile.NamedTemporaryFile(suffix=f".{self.format}", delete=False) file_path = self.audio_to_temp_file(
processing_utils.audio_to_file( data, sample_rate, dir=self.DEFAULT_TEMP_DIR, format=self.format
sample_rate, data, file.name, format=self.format
) )
file_path = str(utils.abspath(file.name))
self.temp_files.add(file_path) self.temp_files.add(file_path)
else: else:
file_path = self.make_temp_copy_if_needed(y) file_path = self.make_temp_copy_if_needed(y)
@ -2720,14 +2755,21 @@ class File(
) )
if self.type == "file": if self.type == "file":
if is_file: if is_file:
temp_file_path = self.make_temp_copy_if_needed(file_name) path = self.make_temp_copy_if_needed(file_name)
file = tempfile.NamedTemporaryFile(delete=False)
file.name = temp_file_path
file.orig_name = file_name # type: ignore
else: else:
file = client_utils.decode_base64_to_file(data, file_path=file_name) data, _ = client_utils.decode_base64_to_binary(data)
file.orig_name = file_name # type: ignore path = self.file_bytes_to_file(
self.temp_files.add(str(utils.abspath(file.name))) data, dir=self.DEFAULT_TEMP_DIR, file_name=file_name
)
path = str(utils.abspath(path))
self.temp_files.add(path)
# Creation of tempfiles here
file = tempfile.NamedTemporaryFile(
delete=False, dir=self.DEFAULT_TEMP_DIR
)
file.name = path
file.orig_name = file_name # type: ignore
return file return file
elif ( elif (
self.type == "binary" or self.type == "bytes" self.type == "binary" or self.type == "bytes"
@ -2777,13 +2819,14 @@ class File(
for file in y for file in y
] ]
else: else:
return { d = {
"orig_name": Path(y).name, "orig_name": Path(y).name,
"name": self.make_temp_copy_if_needed(y), "name": self.make_temp_copy_if_needed(y),
"size": Path(y).stat().st_size, "size": Path(y).stat().st_size,
"data": None, "data": None,
"is_file": True, "is_file": True,
} }
return d
def style( def style(
self, self,
@ -3472,14 +3515,19 @@ class UploadButton(Clickable, Uploadable, IOComponent, FileSerializable):
) )
if self.type == "file": if self.type == "file":
if is_file: if is_file:
temp_file_path = self.make_temp_copy_if_needed(file_name) path = self.make_temp_copy_if_needed(file_name)
file = tempfile.NamedTemporaryFile(delete=False)
file.name = temp_file_path
file.orig_name = file_name # type: ignore
else: else:
file = client_utils.decode_base64_to_file(data, file_path=file_name) data, _ = client_utils.decode_base64_to_binary(data)
file.orig_name = file_name # type: ignore path = self.file_bytes_to_file(
self.temp_files.add(str(utils.abspath(file.name))) data, dir=self.DEFAULT_TEMP_DIR, file_name=file_name
)
path = str(utils.abspath(path))
self.temp_files.add(path)
file = tempfile.NamedTemporaryFile(
delete=False, dir=self.DEFAULT_TEMP_DIR
)
file.name = path
file.orig_name = file_name # type: ignore
return file return file
elif self.type == "bytes": elif self.type == "bytes":
if is_file: if is_file:
@ -4068,11 +4116,11 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable):
base_img_path = base_img base_img_path = base_img
base_img = np.array(_Image.open(base_img)) base_img = np.array(_Image.open(base_img))
elif isinstance(base_img, np.ndarray): elif isinstance(base_img, np.ndarray):
base_file = processing_utils.save_array_to_file(base_img) base_file = self.img_array_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR)
base_img_path = str(utils.abspath(base_file.name)) base_img_path = str(utils.abspath(base_file))
elif isinstance(base_img, _Image.Image): elif isinstance(base_img, _Image.Image):
base_file = processing_utils.save_pil_to_file(base_img) base_file = self.pil_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR)
base_img_path = str(utils.abspath(base_file.name)) base_img_path = str(utils.abspath(base_file))
base_img = np.array(base_img) base_img = np.array(base_img)
else: else:
raise ValueError( raise ValueError(
@ -4116,8 +4164,10 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable):
colored_mask_img = _Image.fromarray((colored_mask).astype(np.uint8)) colored_mask_img = _Image.fromarray((colored_mask).astype(np.uint8))
mask_file = processing_utils.save_pil_to_file(colored_mask_img) mask_file = self.pil_to_temp_file(
mask_file_path = str(utils.abspath(mask_file.name)) colored_mask_img, dir=self.DEFAULT_TEMP_DIR
)
mask_file_path = str(utils.abspath(mask_file))
self.temp_files.add(mask_file_path) self.temp_files.add(mask_file_path)
sections.append( sections.append(
@ -4404,12 +4454,12 @@ class Gallery(IOComponent, GallerySerializable, Selectable):
if isinstance(img, (tuple, list)): if isinstance(img, (tuple, list)):
img, caption = img img, caption = img
if isinstance(img, np.ndarray): if isinstance(img, np.ndarray):
file = processing_utils.save_array_to_file(img) file = self.img_array_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR)
file_path = str(utils.abspath(file.name)) file_path = str(utils.abspath(file))
self.temp_files.add(file_path) self.temp_files.add(file_path)
elif isinstance(img, _Image.Image): elif isinstance(img, _Image.Image):
file = processing_utils.save_pil_to_file(img) file = self.pil_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR)
file_path = str(utils.abspath(file.name)) file_path = str(utils.abspath(file))
self.temp_files.add(file_path) self.temp_files.add(file_path)
elif isinstance(img, str): elif isinstance(img, str):
if utils.validate_url(img): if utils.validate_url(img):

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import base64 import base64
import json import json
import os
import shutil import shutil
import subprocess import subprocess
import tempfile import tempfile
@ -64,13 +65,6 @@ def encode_plot_to_base64(plt):
return "data:image/png;base64," + base64_str return "data:image/png;base64," + base64_str
def save_array_to_file(image_array, dir=None):
pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj)
return file_obj
def get_pil_metadata(pil_image): def get_pil_metadata(pil_image):
# Copy any text-only metadata # Copy any text-only metadata
metadata = PngImagePlugin.PngInfo() metadata = PngImagePlugin.PngInfo()
@ -81,16 +75,14 @@ def get_pil_metadata(pil_image):
return metadata return metadata
def save_pil_to_file(pil_image, dir=None): def encode_pil_to_bytes(pil_image, format="png"):
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) with BytesIO() as output_bytes:
pil_image.save(file_obj, pnginfo=get_pil_metadata(pil_image)) pil_image.save(output_bytes, format, pnginfo=get_pil_metadata(pil_image))
return file_obj return output_bytes.getvalue()
def encode_pil_to_base64(pil_image): def encode_pil_to_base64(pil_image):
with BytesIO() as output_bytes: bytes_data = encode_pil_to_bytes(pil_image)
pil_image.save(output_bytes, "PNG", pnginfo=get_pil_metadata(pil_image))
bytes_data = output_bytes.getvalue()
base64_str = str(base64.b64encode(bytes_data), "utf-8") base64_str = str(base64.b64encode(bytes_data), "utf-8")
return "data:image/png;base64," + base64_str return "data:image/png;base64," + base64_str
@ -519,8 +511,8 @@ def video_is_playable(video_filepath: str) -> bool:
def convert_video_to_playable_mp4(video_path: str) -> str: def convert_video_to_playable_mp4(video_path: str) -> str:
"""Convert the video to mp4. If something goes wrong return the original video.""" """Convert the video to mp4. If something goes wrong return the original video."""
try: try:
output_path = Path(video_path).with_suffix(".mp4")
with tempfile.NamedTemporaryFile(delete=False) as tmp_file: with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
output_path = Path(video_path).with_suffix(".mp4")
shutil.copy2(video_path, tmp_file.name) shutil.copy2(video_path, tmp_file.name)
# ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4 # ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4
ff = FFmpeg( ff = FFmpeg(
@ -532,4 +524,7 @@ def convert_video_to_playable_mp4(video_path: str) -> str:
except FFRuntimeError as e: except FFRuntimeError as e:
print(f"Error converting video to browser-playable format {str(e)}") print(f"Error converting video to browser-playable format {str(e)}")
output_path = video_path output_path = video_path
finally:
# Remove temp file
os.remove(tmp_file.name) # type: ignore
return str(output_path) return str(output_path)

View File

@ -1,7 +1,9 @@
import inspect import inspect
import pathlib import pathlib
from contextlib import contextmanager
import pytest import pytest
from gradio_client import Client
import gradio as gr import gradio as gr
@ -32,3 +34,24 @@ def io_components():
subclasses.append(subclass) subclasses.append(subclass)
return subclasses return subclasses
@pytest.fixture
def connect():
@contextmanager
def _connect(demo: gr.Blocks, serialize=True):
_, local_url, _ = demo.launch(prevent_thread_lock=True)
try:
yield Client(local_url, serialize=serialize)
finally:
# A more verbose version of .close()
# because we should set a timeout
# the tests that call .cancel() can get stuck
# waiting for the thread to join
if demo.enable_queue:
demo._queue.close()
demo.is_running = False
demo.server.should_exit = True
demo.server.thread.join(timeout=1)
return _connect

View File

@ -15,11 +15,14 @@ from functools import partial
from string import capwords from string import capwords
from unittest.mock import patch from unittest.mock import patch
import gradio_client as grc
import numpy as np
import pytest import pytest
import uvicorn import uvicorn
import websockets import websockets
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from gradio_client import media_data from gradio_client import media_data
from PIL import Image
import gradio as gr import gradio as gr
from gradio.events import SelectData from gradio.events import SelectData
@ -463,6 +466,106 @@ class TestBlocksMethods:
demo.close() demo.close()
class TestTempFile:
def test_pil_images_hashed(self, tmp_path, connect, monkeypatch):
images = [
Image.new("RGB", (512, 512), color) for color in ("red", "green", "blue")
]
def create_images(n_images):
return random.sample(images, n_images)
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
demo = gr.Interface(
create_images,
inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)],
outputs=[gr.Gallery().style(grid=2, preview=True)],
)
with connect(demo) as client:
_ = client.predict(3)
_ = client.predict(3)
# only three files created
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
def test_no_empty_image_files(self, tmp_path, connect, monkeypatch):
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
image = str(file_dir / "bus.png")
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
demo = gr.Interface(
lambda x: x,
inputs=gr.Image(type="filepath"),
outputs=gr.Image(),
)
with connect(demo) as client:
_ = client.predict(image)
_ = client.predict(image)
_ = client.predict(image)
# only three files created
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 1
@pytest.mark.parametrize("component", [gr.UploadButton, gr.File])
def test_file_component_uploads(self, component, tmp_path, connect, monkeypatch):
code_file = str(pathlib.Path(__file__))
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
demo = gr.Interface(lambda x: x.name, component(), gr.File())
with connect(demo) as client:
_ = client.predict(code_file)
_ = client.predict(code_file)
# the upload route does not hash the file so 2 files from there
# We create two tempfiles (empty) because API says we return
# preprocess/postprocess will only create one file since we hash
# so 2 + 2 + 1 = 5
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 5
@pytest.mark.parametrize("component", [gr.UploadButton, gr.File])
def test_file_component_uploads_no_serialize(
self, component, tmp_path, connect, monkeypatch
):
code_file = str(pathlib.Path(__file__))
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
demo = gr.Interface(lambda x: x.name, component(), gr.File())
with connect(demo, serialize=False) as client:
_ = client.predict(gr.File().serialize(code_file))
_ = client.predict(gr.File().serialize(code_file))
# We skip the upload route in this case
# We create two tempfiles (empty) because API says we return
# preprocess/postprocess will only create one file since we hash
# so 2 + 1 = 3
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
def test_no_empty_video_files(self, tmp_path, monkeypatch, connect):
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
video = str(file_dir / "video_sample.mp4")
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
demo = gr.Interface(lambda x: x, gr.Video(type="file"), gr.Video())
with connect(demo) as client:
_, url, _ = demo.launch(prevent_thread_lock=True)
client = grc.Client(url)
_ = client.predict(video)
_ = client.predict(video)
# During preprocessing we compute the hash based on base64
# In postprocessing we compute it based on the file
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2
def test_no_empty_audio_files(self, tmp_path, monkeypatch, connect):
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
audio = str(file_dir / "audio_sample.wav")
def reverse_audio(audio):
sr, data = audio
return (sr, np.flipud(data))
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
demo = gr.Interface(fn=reverse_audio, inputs=gr.Audio(), outputs=gr.Audio())
with connect(demo) as client:
_ = client.predict(audio)
_ = client.predict(audio)
# During preprocessing we compute the hash based on base64
# In postprocessing we compute it based on the file
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2
class TestComponentsInBlocks: class TestComponentsInBlocks:
def test_slider_random_value_config(self): def test_slider_random_value_config(self):
with gr.Blocks() as demo: with gr.Blocks() as demo:

View File

@ -9,9 +9,9 @@ import ffmpy
import numpy as np import numpy as np
import pytest import pytest
from gradio_client import media_data from gradio_client import media_data
from PIL import Image from PIL import Image, ImageCms
from gradio import processing_utils, utils from gradio import components, processing_utils, utils
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -54,16 +54,49 @@ class TestImagePreprocessing:
output_base64 = processing_utils.encode_pil_to_base64(img) output_base64 = processing_utils.encode_pil_to_base64(img)
assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE) assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
def test_save_pil_to_file_keeps_pnginfo(self): def test_save_pil_to_file_keeps_pnginfo(self, tmp_path):
input_img = Image.open("gradio/test_data/test_image.png") input_img = Image.open("gradio/test_data/test_image.png")
input_img = input_img.convert("RGB") input_img = input_img.convert("RGB")
input_img.info = {"key1": "value1", "key2": "value2"} input_img.info = {"key1": "value1", "key2": "value2"}
file_obj = processing_utils.save_pil_to_file(input_img) file_obj = components.Image().pil_to_temp_file(input_img, dir=tmp_path)
output_img = Image.open(file_obj) output_img = Image.open(file_obj)
assert output_img.info == input_img.info assert output_img.info == input_img.info
def test_np_pil_encode_to_the_same(self, tmp_path):
arr = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.uint8)
pil = Image.fromarray(arr)
comp = components.Image()
assert comp.pil_to_temp_file(pil, dir=tmp_path) == comp.img_array_to_temp_file(
arr, dir=tmp_path
)
def test_encode_pil_to_temp_file_metadata_color_profile(self, tmp_path):
# Read image
img = Image.open("gradio/test_data/test_image.png")
img_metadata = Image.open("gradio/test_data/test_image.png")
img_metadata.info = {"key1": "value1", "key2": "value2"}
# Creating sRGB profile
profile = ImageCms.createProfile("sRGB")
profile2 = ImageCms.ImageCmsProfile(profile)
img.save(tmp_path / "img_color_profile.png", icc_profile=profile2.tobytes())
img_cp1 = Image.open(str(tmp_path / "img_color_profile.png"))
# Creating XYZ profile
profile = ImageCms.createProfile("XYZ")
profile2 = ImageCms.ImageCmsProfile(profile)
img.save(tmp_path / "img_color_profile_2.png", icc_profile=profile2.tobytes())
img_cp2 = Image.open(str(tmp_path / "img_color_profile_2.png"))
comp = components.Image()
img_path = comp.pil_to_temp_file(img, dir=tmp_path)
img_metadata_path = comp.pil_to_temp_file(img_metadata, dir=tmp_path)
img_cp1_path = comp.pil_to_temp_file(img_cp1, dir=tmp_path)
img_cp2_path = comp.pil_to_temp_file(img_cp2, dir=tmp_path)
assert len({img_path, img_metadata_path, img_cp1_path, img_cp2_path}) == 4
def test_encode_pil_to_base64_keeps_pnginfo(self): def test_encode_pil_to_base64_keeps_pnginfo(self):
input_img = Image.open("gradio/test_data/test_image.png") input_img = Image.open("gradio/test_data/test_image.png")
input_img = input_img.convert("RGB") input_img = input_img.convert("RGB")
@ -205,9 +238,12 @@ class TestVideoProcessing:
shutil.copy( shutil.copy(
str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name
) )
playable_vid = processing_utils.convert_video_to_playable_mp4( with patch("os.remove", wraps=os.remove) as mock_remove:
tmp_not_playable_vid.name playable_vid = processing_utils.convert_video_to_playable_mp4(
) tmp_not_playable_vid.name
)
# check tempfile got deleted
assert not Path(mock_remove.call_args[0][0]).exists()
assert processing_utils.video_is_playable(playable_vid) assert processing_utils.video_is_playable(playable_vid)
@patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception) @patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception)