mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Temp file fixes (#4256)
* Fix bug * Linting * CHANGELOG * Add tests * Update test * Fix remaining components + add tests * Fix tests * Fix tests * Address comments
This commit is contained in:
parent
1151c52535
commit
834afdd303
@ -4,6 +4,9 @@ No changes to highlight.
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
- Fixed Gallery/AnnotatedImage components not respecting GRADIO_DEFAULT_DIR variable by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
|
||||
- Fixed Gallery/AnnotatedImage components resaving identical images by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
|
||||
- Fixed Audio/Video/File components creating empty tempfiles on each run by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
|
||||
- Fixed the behavior of the `run_on_click` parameter in `gr.Examples` by [@abidlabs](https://github.com/abidlabs) in [PR 4258](https://github.com/gradio-app/gradio/pull/4258).
|
||||
- Ensure js client respcts the full root when making requests to the server by [@pngwn](https://github.com/pngwn) in [PR 4271](https://github.com/gradio-app/gradio/pull/4271)
|
||||
|
||||
|
@ -785,8 +785,8 @@ class Endpoint:
|
||||
if t in ["file", "uploadbutton"]
|
||||
]
|
||||
uploaded_files = self._upload(files)
|
||||
self._add_uploaded_files_to_data(uploaded_files, list(data))
|
||||
|
||||
data = list(data)
|
||||
self._add_uploaded_files_to_data(uploaded_files, data)
|
||||
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
|
||||
return o
|
||||
|
||||
|
@ -252,13 +252,19 @@ class TestClientPredictions:
|
||||
with patch.object(
|
||||
client.endpoints[0], "_upload", wraps=client.endpoints[0]._upload
|
||||
) as upload:
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
|
||||
f.write("Hello from private space!")
|
||||
with patch.object(
|
||||
client.endpoints[0], "serialize", wraps=client.endpoints[0].serialize
|
||||
) as serialize:
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
|
||||
f.write("Hello from private space!")
|
||||
|
||||
output = client.submit(1, "foo", f.name, api_name="/file_upload").result()
|
||||
output = client.submit(
|
||||
1, "foo", f.name, api_name="/file_upload"
|
||||
).result()
|
||||
with open(output) as f:
|
||||
assert f.read() == "Hello from private space!"
|
||||
upload.assert_called_once()
|
||||
assert all(f["is_file"] for f in serialize.return_value())
|
||||
|
||||
with patch.object(
|
||||
client.endpoints[1], "_upload", wraps=client.endpoints[0]._upload
|
||||
|
@ -20,7 +20,7 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict
|
||||
|
||||
import aiofiles
|
||||
import altair as alt
|
||||
@ -217,14 +217,16 @@ class IOComponent(Component):
|
||||
if callable(load_fn):
|
||||
self.attach_load_event(load_fn, every)
|
||||
|
||||
def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
|
||||
@staticmethod
|
||||
def hash_file(file_path: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
|
||||
sha1.update(chunk)
|
||||
return sha1.hexdigest()
|
||||
|
||||
def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
|
||||
@staticmethod
|
||||
def hash_url(url: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
remote = urllib.request.urlopen(url)
|
||||
max_file_size = 100 * 1024 * 1024 # 100MB
|
||||
@ -237,7 +239,14 @@ class IOComponent(Component):
|
||||
sha1.update(data)
|
||||
return sha1.hexdigest()
|
||||
|
||||
def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
|
||||
@staticmethod
|
||||
def hash_bytes(bytes: bytes):
|
||||
sha1 = hashlib.sha1()
|
||||
sha1.update(bytes)
|
||||
return sha1.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
|
||||
data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
|
||||
@ -251,9 +260,8 @@ class IOComponent(Component):
|
||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
f.name = client_utils.strip_invalid_filename_characters(Path(file_path).name)
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
|
||||
name = client_utils.strip_invalid_filename_characters(Path(file_path).name)
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / name))
|
||||
|
||||
if not Path(full_temp_file_path).exists():
|
||||
shutil.copy2(file_path, full_temp_file_path)
|
||||
@ -267,15 +275,14 @@ class IOComponent(Component):
|
||||
) # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
|
||||
temp_dir = Path(upload_dir) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
output_file_obj = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
|
||||
if file.filename:
|
||||
file_name = Path(file.filename).name
|
||||
output_file_obj.name = client_utils.strip_invalid_filename_characters(
|
||||
file_name
|
||||
)
|
||||
name = client_utils.strip_invalid_filename_characters(file_name)
|
||||
else:
|
||||
name = f"tmp{secrets.token_hex(5)}"
|
||||
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / output_file_obj.name))
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / name))
|
||||
|
||||
async with aiofiles.open(full_temp_file_path, "wb") as output_file:
|
||||
while True:
|
||||
@ -292,10 +299,9 @@ class IOComponent(Component):
|
||||
temp_dir = self.hash_url(url)
|
||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
|
||||
f.name = client_utils.strip_invalid_filename_characters(Path(url).name)
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
|
||||
name = client_utils.strip_invalid_filename_characters(Path(url).name)
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / name))
|
||||
|
||||
if not Path(full_temp_file_path).exists():
|
||||
with requests.get(url, stream=True) as r, open(
|
||||
@ -323,8 +329,7 @@ class IOComponent(Component):
|
||||
file_name = f"file.{guess_extension}"
|
||||
else:
|
||||
file_name = "file"
|
||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
||||
f.name = file_name # type: ignore
|
||||
|
||||
full_temp_file_path = str(utils.abspath(temp_dir / file_name)) # type: ignore
|
||||
|
||||
if not Path(full_temp_file_path).exists():
|
||||
@ -335,6 +340,36 @@ class IOComponent(Component):
|
||||
self.temp_files.add(full_temp_file_path)
|
||||
return full_temp_file_path
|
||||
|
||||
def pil_to_temp_file(self, img: _Image.Image, dir: str, format="png") -> str:
|
||||
bytes_data = processing_utils.encode_pil_to_bytes(img, format)
|
||||
temp_dir = Path(dir) / self.hash_bytes(bytes_data)
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
filename = str(temp_dir / f"image.{format}")
|
||||
img.save(filename, pnginfo=processing_utils.get_pil_metadata(img))
|
||||
return filename
|
||||
|
||||
def img_array_to_temp_file(self, arr: np.ndarray, dir: str) -> str:
|
||||
pil_image = _Image.fromarray(
|
||||
processing_utils._convert(arr, np.uint8, force_copy=False)
|
||||
)
|
||||
return self.pil_to_temp_file(pil_image, dir, format="png")
|
||||
|
||||
def audio_to_temp_file(
|
||||
self, data: np.ndarray, sample_rate: int, dir: str, format: str
|
||||
):
|
||||
temp_dir = Path(dir) / self.hash_bytes(data.tobytes())
|
||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||
filename = str(temp_dir / f"audio.{format}")
|
||||
processing_utils.audio_to_file(sample_rate, data, filename, format=format)
|
||||
return filename
|
||||
|
||||
def file_bytes_to_file(self, data: bytes, dir: str, file_name: str):
|
||||
path = Path(dir) / self.hash_bytes(data)
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
path = path / Path(file_name).name
|
||||
path.write_bytes(data)
|
||||
return path
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"label": self.label,
|
||||
@ -1758,12 +1793,11 @@ class Image(
|
||||
elif self.type == "numpy":
|
||||
return np.array(im)
|
||||
elif self.type == "filepath":
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False,
|
||||
suffix=(f".{fmt.lower()}" if fmt is not None else ".png"),
|
||||
path = self.pil_to_temp_file(
|
||||
im, dir=self.DEFAULT_TEMP_DIR, format=fmt or "png"
|
||||
)
|
||||
im.save(file_obj.name)
|
||||
return self.make_temp_copy_if_needed(file_obj.name)
|
||||
self.temp_files.add(path)
|
||||
return path
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type: "
|
||||
@ -2259,8 +2293,7 @@ class Video(
|
||||
# HTML5 only support vtt format
|
||||
if Path(subtitle).suffix == ".srt":
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
delete=False,
|
||||
suffix=".vtt",
|
||||
delete=False, suffix=".vtt", dir=self.DEFAULT_TEMP_DIR
|
||||
)
|
||||
|
||||
srt_to_vtt(subtitle, temp_file.name)
|
||||
@ -2483,7 +2516,9 @@ class Audio(
|
||||
# Handle the leave one outs
|
||||
leave_one_out_data = np.copy(data)
|
||||
leave_one_out_data[start:stop] = 0
|
||||
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||
file = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".wav", dir=self.DEFAULT_TEMP_DIR
|
||||
)
|
||||
processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name)
|
||||
out_data = client_utils.encode_file_to_base64(file.name)
|
||||
leave_one_out_sets.append(out_data)
|
||||
@ -2494,7 +2529,9 @@ class Audio(
|
||||
token = np.copy(data)
|
||||
token[0:start] = 0
|
||||
token[stop:] = 0
|
||||
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||
file = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".wav", dir=self.DEFAULT_TEMP_DIR
|
||||
)
|
||||
processing_utils.audio_to_file(sample_rate, token, file.name)
|
||||
token_data = client_utils.encode_file_to_base64(file.name)
|
||||
file.close()
|
||||
@ -2525,7 +2562,7 @@ class Audio(
|
||||
masked_input = np.copy(zero_input)
|
||||
for t, b in zip(token_data, binary_mask_vector):
|
||||
masked_input = masked_input + t * int(b)
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
file = tempfile.NamedTemporaryFile(delete=False, dir=self.DEFAULT_TEMP_DIR)
|
||||
processing_utils.audio_to_file(sample_rate, masked_input, file.name)
|
||||
masked_data = client_utils.encode_file_to_base64(file.name)
|
||||
file.close()
|
||||
@ -2546,11 +2583,9 @@ class Audio(
|
||||
return {"name": y, "data": None, "is_file": True}
|
||||
if isinstance(y, tuple):
|
||||
sample_rate, data = y
|
||||
file = tempfile.NamedTemporaryFile(suffix=f".{self.format}", delete=False)
|
||||
processing_utils.audio_to_file(
|
||||
sample_rate, data, file.name, format=self.format
|
||||
file_path = self.audio_to_temp_file(
|
||||
data, sample_rate, dir=self.DEFAULT_TEMP_DIR, format=self.format
|
||||
)
|
||||
file_path = str(utils.abspath(file.name))
|
||||
self.temp_files.add(file_path)
|
||||
else:
|
||||
file_path = self.make_temp_copy_if_needed(y)
|
||||
@ -2720,14 +2755,21 @@ class File(
|
||||
)
|
||||
if self.type == "file":
|
||||
if is_file:
|
||||
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
file.name = temp_file_path
|
||||
file.orig_name = file_name # type: ignore
|
||||
path = self.make_temp_copy_if_needed(file_name)
|
||||
else:
|
||||
file = client_utils.decode_base64_to_file(data, file_path=file_name)
|
||||
file.orig_name = file_name # type: ignore
|
||||
self.temp_files.add(str(utils.abspath(file.name)))
|
||||
data, _ = client_utils.decode_base64_to_binary(data)
|
||||
path = self.file_bytes_to_file(
|
||||
data, dir=self.DEFAULT_TEMP_DIR, file_name=file_name
|
||||
)
|
||||
path = str(utils.abspath(path))
|
||||
self.temp_files.add(path)
|
||||
|
||||
# Creation of tempfiles here
|
||||
file = tempfile.NamedTemporaryFile(
|
||||
delete=False, dir=self.DEFAULT_TEMP_DIR
|
||||
)
|
||||
file.name = path
|
||||
file.orig_name = file_name # type: ignore
|
||||
return file
|
||||
elif (
|
||||
self.type == "binary" or self.type == "bytes"
|
||||
@ -2777,13 +2819,14 @@ class File(
|
||||
for file in y
|
||||
]
|
||||
else:
|
||||
return {
|
||||
d = {
|
||||
"orig_name": Path(y).name,
|
||||
"name": self.make_temp_copy_if_needed(y),
|
||||
"size": Path(y).stat().st_size,
|
||||
"data": None,
|
||||
"is_file": True,
|
||||
}
|
||||
return d
|
||||
|
||||
def style(
|
||||
self,
|
||||
@ -3472,14 +3515,19 @@ class UploadButton(Clickable, Uploadable, IOComponent, FileSerializable):
|
||||
)
|
||||
if self.type == "file":
|
||||
if is_file:
|
||||
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
file.name = temp_file_path
|
||||
file.orig_name = file_name # type: ignore
|
||||
path = self.make_temp_copy_if_needed(file_name)
|
||||
else:
|
||||
file = client_utils.decode_base64_to_file(data, file_path=file_name)
|
||||
file.orig_name = file_name # type: ignore
|
||||
self.temp_files.add(str(utils.abspath(file.name)))
|
||||
data, _ = client_utils.decode_base64_to_binary(data)
|
||||
path = self.file_bytes_to_file(
|
||||
data, dir=self.DEFAULT_TEMP_DIR, file_name=file_name
|
||||
)
|
||||
path = str(utils.abspath(path))
|
||||
self.temp_files.add(path)
|
||||
file = tempfile.NamedTemporaryFile(
|
||||
delete=False, dir=self.DEFAULT_TEMP_DIR
|
||||
)
|
||||
file.name = path
|
||||
file.orig_name = file_name # type: ignore
|
||||
return file
|
||||
elif self.type == "bytes":
|
||||
if is_file:
|
||||
@ -4068,11 +4116,11 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable):
|
||||
base_img_path = base_img
|
||||
base_img = np.array(_Image.open(base_img))
|
||||
elif isinstance(base_img, np.ndarray):
|
||||
base_file = processing_utils.save_array_to_file(base_img)
|
||||
base_img_path = str(utils.abspath(base_file.name))
|
||||
base_file = self.img_array_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR)
|
||||
base_img_path = str(utils.abspath(base_file))
|
||||
elif isinstance(base_img, _Image.Image):
|
||||
base_file = processing_utils.save_pil_to_file(base_img)
|
||||
base_img_path = str(utils.abspath(base_file.name))
|
||||
base_file = self.pil_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR)
|
||||
base_img_path = str(utils.abspath(base_file))
|
||||
base_img = np.array(base_img)
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -4116,8 +4164,10 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable):
|
||||
|
||||
colored_mask_img = _Image.fromarray((colored_mask).astype(np.uint8))
|
||||
|
||||
mask_file = processing_utils.save_pil_to_file(colored_mask_img)
|
||||
mask_file_path = str(utils.abspath(mask_file.name))
|
||||
mask_file = self.pil_to_temp_file(
|
||||
colored_mask_img, dir=self.DEFAULT_TEMP_DIR
|
||||
)
|
||||
mask_file_path = str(utils.abspath(mask_file))
|
||||
self.temp_files.add(mask_file_path)
|
||||
|
||||
sections.append(
|
||||
@ -4404,12 +4454,12 @@ class Gallery(IOComponent, GallerySerializable, Selectable):
|
||||
if isinstance(img, (tuple, list)):
|
||||
img, caption = img
|
||||
if isinstance(img, np.ndarray):
|
||||
file = processing_utils.save_array_to_file(img)
|
||||
file_path = str(utils.abspath(file.name))
|
||||
file = self.img_array_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR)
|
||||
file_path = str(utils.abspath(file))
|
||||
self.temp_files.add(file_path)
|
||||
elif isinstance(img, _Image.Image):
|
||||
file = processing_utils.save_pil_to_file(img)
|
||||
file_path = str(utils.abspath(file.name))
|
||||
file = self.pil_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR)
|
||||
file_path = str(utils.abspath(file))
|
||||
self.temp_files.add(file_path)
|
||||
elif isinstance(img, str):
|
||||
if utils.validate_url(img):
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
@ -64,13 +65,6 @@ def encode_plot_to_base64(plt):
|
||||
return "data:image/png;base64," + base64_str
|
||||
|
||||
|
||||
def save_array_to_file(image_array, dir=None):
|
||||
pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
||||
pil_image.save(file_obj)
|
||||
return file_obj
|
||||
|
||||
|
||||
def get_pil_metadata(pil_image):
|
||||
# Copy any text-only metadata
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
@ -81,16 +75,14 @@ def get_pil_metadata(pil_image):
|
||||
return metadata
|
||||
|
||||
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
||||
pil_image.save(file_obj, pnginfo=get_pil_metadata(pil_image))
|
||||
return file_obj
|
||||
def encode_pil_to_bytes(pil_image, format="png"):
|
||||
with BytesIO() as output_bytes:
|
||||
pil_image.save(output_bytes, format, pnginfo=get_pil_metadata(pil_image))
|
||||
return output_bytes.getvalue()
|
||||
|
||||
|
||||
def encode_pil_to_base64(pil_image):
|
||||
with BytesIO() as output_bytes:
|
||||
pil_image.save(output_bytes, "PNG", pnginfo=get_pil_metadata(pil_image))
|
||||
bytes_data = output_bytes.getvalue()
|
||||
bytes_data = encode_pil_to_bytes(pil_image)
|
||||
base64_str = str(base64.b64encode(bytes_data), "utf-8")
|
||||
return "data:image/png;base64," + base64_str
|
||||
|
||||
@ -519,8 +511,8 @@ def video_is_playable(video_filepath: str) -> bool:
|
||||
def convert_video_to_playable_mp4(video_path: str) -> str:
|
||||
"""Convert the video to mp4. If something goes wrong return the original video."""
|
||||
try:
|
||||
output_path = Path(video_path).with_suffix(".mp4")
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
|
||||
output_path = Path(video_path).with_suffix(".mp4")
|
||||
shutil.copy2(video_path, tmp_file.name)
|
||||
# ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4
|
||||
ff = FFmpeg(
|
||||
@ -532,4 +524,7 @@ def convert_video_to_playable_mp4(video_path: str) -> str:
|
||||
except FFRuntimeError as e:
|
||||
print(f"Error converting video to browser-playable format {str(e)}")
|
||||
output_path = video_path
|
||||
finally:
|
||||
# Remove temp file
|
||||
os.remove(tmp_file.name) # type: ignore
|
||||
return str(output_path)
|
||||
|
@ -1,7 +1,9 @@
|
||||
import inspect
|
||||
import pathlib
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
from gradio_client import Client
|
||||
|
||||
import gradio as gr
|
||||
|
||||
@ -32,3 +34,24 @@ def io_components():
|
||||
subclasses.append(subclass)
|
||||
|
||||
return subclasses
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connect():
|
||||
@contextmanager
|
||||
def _connect(demo: gr.Blocks, serialize=True):
|
||||
_, local_url, _ = demo.launch(prevent_thread_lock=True)
|
||||
try:
|
||||
yield Client(local_url, serialize=serialize)
|
||||
finally:
|
||||
# A more verbose version of .close()
|
||||
# because we should set a timeout
|
||||
# the tests that call .cancel() can get stuck
|
||||
# waiting for the thread to join
|
||||
if demo.enable_queue:
|
||||
demo._queue.close()
|
||||
demo.is_running = False
|
||||
demo.server.should_exit = True
|
||||
demo.server.thread.join(timeout=1)
|
||||
|
||||
return _connect
|
||||
|
@ -15,11 +15,14 @@ from functools import partial
|
||||
from string import capwords
|
||||
from unittest.mock import patch
|
||||
|
||||
import gradio_client as grc
|
||||
import numpy as np
|
||||
import pytest
|
||||
import uvicorn
|
||||
import websockets
|
||||
from fastapi.testclient import TestClient
|
||||
from gradio_client import media_data
|
||||
from PIL import Image
|
||||
|
||||
import gradio as gr
|
||||
from gradio.events import SelectData
|
||||
@ -463,6 +466,106 @@ class TestBlocksMethods:
|
||||
demo.close()
|
||||
|
||||
|
||||
class TestTempFile:
|
||||
def test_pil_images_hashed(self, tmp_path, connect, monkeypatch):
|
||||
images = [
|
||||
Image.new("RGB", (512, 512), color) for color in ("red", "green", "blue")
|
||||
]
|
||||
|
||||
def create_images(n_images):
|
||||
return random.sample(images, n_images)
|
||||
|
||||
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||
demo = gr.Interface(
|
||||
create_images,
|
||||
inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)],
|
||||
outputs=[gr.Gallery().style(grid=2, preview=True)],
|
||||
)
|
||||
with connect(demo) as client:
|
||||
_ = client.predict(3)
|
||||
_ = client.predict(3)
|
||||
# only three files created
|
||||
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
|
||||
|
||||
def test_no_empty_image_files(self, tmp_path, connect, monkeypatch):
|
||||
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
image = str(file_dir / "bus.png")
|
||||
|
||||
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||
demo = gr.Interface(
|
||||
lambda x: x,
|
||||
inputs=gr.Image(type="filepath"),
|
||||
outputs=gr.Image(),
|
||||
)
|
||||
with connect(demo) as client:
|
||||
_ = client.predict(image)
|
||||
_ = client.predict(image)
|
||||
_ = client.predict(image)
|
||||
# only three files created
|
||||
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 1
|
||||
|
||||
@pytest.mark.parametrize("component", [gr.UploadButton, gr.File])
|
||||
def test_file_component_uploads(self, component, tmp_path, connect, monkeypatch):
|
||||
code_file = str(pathlib.Path(__file__))
|
||||
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||
demo = gr.Interface(lambda x: x.name, component(), gr.File())
|
||||
with connect(demo) as client:
|
||||
_ = client.predict(code_file)
|
||||
_ = client.predict(code_file)
|
||||
# the upload route does not hash the file so 2 files from there
|
||||
# We create two tempfiles (empty) because API says we return
|
||||
# preprocess/postprocess will only create one file since we hash
|
||||
# so 2 + 2 + 1 = 5
|
||||
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 5
|
||||
|
||||
@pytest.mark.parametrize("component", [gr.UploadButton, gr.File])
|
||||
def test_file_component_uploads_no_serialize(
|
||||
self, component, tmp_path, connect, monkeypatch
|
||||
):
|
||||
code_file = str(pathlib.Path(__file__))
|
||||
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||
demo = gr.Interface(lambda x: x.name, component(), gr.File())
|
||||
with connect(demo, serialize=False) as client:
|
||||
_ = client.predict(gr.File().serialize(code_file))
|
||||
_ = client.predict(gr.File().serialize(code_file))
|
||||
# We skip the upload route in this case
|
||||
# We create two tempfiles (empty) because API says we return
|
||||
# preprocess/postprocess will only create one file since we hash
|
||||
# so 2 + 1 = 3
|
||||
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
|
||||
|
||||
def test_no_empty_video_files(self, tmp_path, monkeypatch, connect):
|
||||
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
video = str(file_dir / "video_sample.mp4")
|
||||
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||
demo = gr.Interface(lambda x: x, gr.Video(type="file"), gr.Video())
|
||||
with connect(demo) as client:
|
||||
_, url, _ = demo.launch(prevent_thread_lock=True)
|
||||
client = grc.Client(url)
|
||||
_ = client.predict(video)
|
||||
_ = client.predict(video)
|
||||
# During preprocessing we compute the hash based on base64
|
||||
# In postprocessing we compute it based on the file
|
||||
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2
|
||||
|
||||
def test_no_empty_audio_files(self, tmp_path, monkeypatch, connect):
|
||||
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
audio = str(file_dir / "audio_sample.wav")
|
||||
|
||||
def reverse_audio(audio):
|
||||
sr, data = audio
|
||||
return (sr, np.flipud(data))
|
||||
|
||||
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||
demo = gr.Interface(fn=reverse_audio, inputs=gr.Audio(), outputs=gr.Audio())
|
||||
with connect(demo) as client:
|
||||
_ = client.predict(audio)
|
||||
_ = client.predict(audio)
|
||||
# During preprocessing we compute the hash based on base64
|
||||
# In postprocessing we compute it based on the file
|
||||
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2
|
||||
|
||||
|
||||
class TestComponentsInBlocks:
|
||||
def test_slider_random_value_config(self):
|
||||
with gr.Blocks() as demo:
|
||||
|
@ -9,9 +9,9 @@ import ffmpy
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gradio_client import media_data
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageCms
|
||||
|
||||
from gradio import processing_utils, utils
|
||||
from gradio import components, processing_utils, utils
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -54,16 +54,49 @@ class TestImagePreprocessing:
|
||||
output_base64 = processing_utils.encode_pil_to_base64(img)
|
||||
assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
|
||||
|
||||
def test_save_pil_to_file_keeps_pnginfo(self):
|
||||
def test_save_pil_to_file_keeps_pnginfo(self, tmp_path):
|
||||
input_img = Image.open("gradio/test_data/test_image.png")
|
||||
input_img = input_img.convert("RGB")
|
||||
input_img.info = {"key1": "value1", "key2": "value2"}
|
||||
|
||||
file_obj = processing_utils.save_pil_to_file(input_img)
|
||||
file_obj = components.Image().pil_to_temp_file(input_img, dir=tmp_path)
|
||||
output_img = Image.open(file_obj)
|
||||
|
||||
assert output_img.info == input_img.info
|
||||
|
||||
def test_np_pil_encode_to_the_same(self, tmp_path):
|
||||
arr = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.uint8)
|
||||
pil = Image.fromarray(arr)
|
||||
comp = components.Image()
|
||||
assert comp.pil_to_temp_file(pil, dir=tmp_path) == comp.img_array_to_temp_file(
|
||||
arr, dir=tmp_path
|
||||
)
|
||||
|
||||
def test_encode_pil_to_temp_file_metadata_color_profile(self, tmp_path):
|
||||
# Read image
|
||||
img = Image.open("gradio/test_data/test_image.png")
|
||||
img_metadata = Image.open("gradio/test_data/test_image.png")
|
||||
img_metadata.info = {"key1": "value1", "key2": "value2"}
|
||||
|
||||
# Creating sRGB profile
|
||||
profile = ImageCms.createProfile("sRGB")
|
||||
profile2 = ImageCms.ImageCmsProfile(profile)
|
||||
img.save(tmp_path / "img_color_profile.png", icc_profile=profile2.tobytes())
|
||||
img_cp1 = Image.open(str(tmp_path / "img_color_profile.png"))
|
||||
|
||||
# Creating XYZ profile
|
||||
profile = ImageCms.createProfile("XYZ")
|
||||
profile2 = ImageCms.ImageCmsProfile(profile)
|
||||
img.save(tmp_path / "img_color_profile_2.png", icc_profile=profile2.tobytes())
|
||||
img_cp2 = Image.open(str(tmp_path / "img_color_profile_2.png"))
|
||||
|
||||
comp = components.Image()
|
||||
img_path = comp.pil_to_temp_file(img, dir=tmp_path)
|
||||
img_metadata_path = comp.pil_to_temp_file(img_metadata, dir=tmp_path)
|
||||
img_cp1_path = comp.pil_to_temp_file(img_cp1, dir=tmp_path)
|
||||
img_cp2_path = comp.pil_to_temp_file(img_cp2, dir=tmp_path)
|
||||
assert len({img_path, img_metadata_path, img_cp1_path, img_cp2_path}) == 4
|
||||
|
||||
def test_encode_pil_to_base64_keeps_pnginfo(self):
|
||||
input_img = Image.open("gradio/test_data/test_image.png")
|
||||
input_img = input_img.convert("RGB")
|
||||
@ -205,9 +238,12 @@ class TestVideoProcessing:
|
||||
shutil.copy(
|
||||
str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name
|
||||
)
|
||||
playable_vid = processing_utils.convert_video_to_playable_mp4(
|
||||
tmp_not_playable_vid.name
|
||||
)
|
||||
with patch("os.remove", wraps=os.remove) as mock_remove:
|
||||
playable_vid = processing_utils.convert_video_to_playable_mp4(
|
||||
tmp_not_playable_vid.name
|
||||
)
|
||||
# check tempfile got deleted
|
||||
assert not Path(mock_remove.call_args[0][0]).exists()
|
||||
assert processing_utils.video_is_playable(playable_vid)
|
||||
|
||||
@patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception)
|
||||
|
Loading…
Reference in New Issue
Block a user