mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-05 11:10:03 +08:00
Temp file fixes (#4256)
* Fix bug * Linting * CHANGELOG * Add tests * Update test * Fix remaining components + add tests * Fix tests * Fix tests * Address comments
This commit is contained in:
parent
1151c52535
commit
834afdd303
@ -4,6 +4,9 @@ No changes to highlight.
|
|||||||
|
|
||||||
## Bug Fixes:
|
## Bug Fixes:
|
||||||
|
|
||||||
|
- Fixed Gallery/AnnotatedImage components not respecting GRADIO_DEFAULT_DIR variable by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
|
||||||
|
- Fixed Gallery/AnnotatedImage components resaving identical images by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
|
||||||
|
- Fixed Audio/Video/File components creating empty tempfiles on each run by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4256](https://github.com/gradio-app/gradio/pull/4256)
|
||||||
- Fixed the behavior of the `run_on_click` parameter in `gr.Examples` by [@abidlabs](https://github.com/abidlabs) in [PR 4258](https://github.com/gradio-app/gradio/pull/4258).
|
- Fixed the behavior of the `run_on_click` parameter in `gr.Examples` by [@abidlabs](https://github.com/abidlabs) in [PR 4258](https://github.com/gradio-app/gradio/pull/4258).
|
||||||
- Ensure js client respcts the full root when making requests to the server by [@pngwn](https://github.com/pngwn) in [PR 4271](https://github.com/gradio-app/gradio/pull/4271)
|
- Ensure js client respcts the full root when making requests to the server by [@pngwn](https://github.com/pngwn) in [PR 4271](https://github.com/gradio-app/gradio/pull/4271)
|
||||||
|
|
||||||
|
@ -785,8 +785,8 @@ class Endpoint:
|
|||||||
if t in ["file", "uploadbutton"]
|
if t in ["file", "uploadbutton"]
|
||||||
]
|
]
|
||||||
uploaded_files = self._upload(files)
|
uploaded_files = self._upload(files)
|
||||||
self._add_uploaded_files_to_data(uploaded_files, list(data))
|
data = list(data)
|
||||||
|
self._add_uploaded_files_to_data(uploaded_files, data)
|
||||||
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
|
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
@ -252,13 +252,19 @@ class TestClientPredictions:
|
|||||||
with patch.object(
|
with patch.object(
|
||||||
client.endpoints[0], "_upload", wraps=client.endpoints[0]._upload
|
client.endpoints[0], "_upload", wraps=client.endpoints[0]._upload
|
||||||
) as upload:
|
) as upload:
|
||||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
|
with patch.object(
|
||||||
f.write("Hello from private space!")
|
client.endpoints[0], "serialize", wraps=client.endpoints[0].serialize
|
||||||
|
) as serialize:
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
|
||||||
|
f.write("Hello from private space!")
|
||||||
|
|
||||||
output = client.submit(1, "foo", f.name, api_name="/file_upload").result()
|
output = client.submit(
|
||||||
|
1, "foo", f.name, api_name="/file_upload"
|
||||||
|
).result()
|
||||||
with open(output) as f:
|
with open(output) as f:
|
||||||
assert f.read() == "Hello from private space!"
|
assert f.read() == "Hello from private space!"
|
||||||
upload.assert_called_once()
|
upload.assert_called_once()
|
||||||
|
assert all(f["is_file"] for f in serialize.return_value())
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
client.endpoints[1], "_upload", wraps=client.endpoints[0]._upload
|
client.endpoints[1], "_upload", wraps=client.endpoints[0]._upload
|
||||||
|
@ -20,7 +20,7 @@ from copy import deepcopy
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, cast
|
from typing import TYPE_CHECKING, Any, Callable, Dict
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import altair as alt
|
import altair as alt
|
||||||
@ -217,14 +217,16 @@ class IOComponent(Component):
|
|||||||
if callable(load_fn):
|
if callable(load_fn):
|
||||||
self.attach_load_event(load_fn, every)
|
self.attach_load_event(load_fn, every)
|
||||||
|
|
||||||
def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
|
@staticmethod
|
||||||
|
def hash_file(file_path: str, chunk_num_blocks: int = 128) -> str:
|
||||||
sha1 = hashlib.sha1()
|
sha1 = hashlib.sha1()
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
|
for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
|
||||||
sha1.update(chunk)
|
sha1.update(chunk)
|
||||||
return sha1.hexdigest()
|
return sha1.hexdigest()
|
||||||
|
|
||||||
def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
|
@staticmethod
|
||||||
|
def hash_url(url: str, chunk_num_blocks: int = 128) -> str:
|
||||||
sha1 = hashlib.sha1()
|
sha1 = hashlib.sha1()
|
||||||
remote = urllib.request.urlopen(url)
|
remote = urllib.request.urlopen(url)
|
||||||
max_file_size = 100 * 1024 * 1024 # 100MB
|
max_file_size = 100 * 1024 * 1024 # 100MB
|
||||||
@ -237,7 +239,14 @@ class IOComponent(Component):
|
|||||||
sha1.update(data)
|
sha1.update(data)
|
||||||
return sha1.hexdigest()
|
return sha1.hexdigest()
|
||||||
|
|
||||||
def hash_base64(self, base64_encoding: str, chunk_num_blocks: int = 128) -> str:
|
@staticmethod
|
||||||
|
def hash_bytes(bytes: bytes):
|
||||||
|
sha1 = hashlib.sha1()
|
||||||
|
sha1.update(bytes)
|
||||||
|
return sha1.hexdigest()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str:
|
||||||
sha1 = hashlib.sha1()
|
sha1 = hashlib.sha1()
|
||||||
for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
|
for i in range(0, len(base64_encoding), chunk_num_blocks * sha1.block_size):
|
||||||
data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
|
data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size]
|
||||||
@ -251,9 +260,8 @@ class IOComponent(Component):
|
|||||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
||||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
name = client_utils.strip_invalid_filename_characters(Path(file_path).name)
|
||||||
f.name = client_utils.strip_invalid_filename_characters(Path(file_path).name)
|
full_temp_file_path = str(utils.abspath(temp_dir / name))
|
||||||
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
|
|
||||||
|
|
||||||
if not Path(full_temp_file_path).exists():
|
if not Path(full_temp_file_path).exists():
|
||||||
shutil.copy2(file_path, full_temp_file_path)
|
shutil.copy2(file_path, full_temp_file_path)
|
||||||
@ -267,15 +275,14 @@ class IOComponent(Component):
|
|||||||
) # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
|
) # Since the full file is being uploaded anyways, there is no benefit to hashing the file.
|
||||||
temp_dir = Path(upload_dir) / temp_dir
|
temp_dir = Path(upload_dir) / temp_dir
|
||||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||||
output_file_obj = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
|
||||||
|
|
||||||
if file.filename:
|
if file.filename:
|
||||||
file_name = Path(file.filename).name
|
file_name = Path(file.filename).name
|
||||||
output_file_obj.name = client_utils.strip_invalid_filename_characters(
|
name = client_utils.strip_invalid_filename_characters(file_name)
|
||||||
file_name
|
else:
|
||||||
)
|
name = f"tmp{secrets.token_hex(5)}"
|
||||||
|
|
||||||
full_temp_file_path = str(utils.abspath(temp_dir / output_file_obj.name))
|
full_temp_file_path = str(utils.abspath(temp_dir / name))
|
||||||
|
|
||||||
async with aiofiles.open(full_temp_file_path, "wb") as output_file:
|
async with aiofiles.open(full_temp_file_path, "wb") as output_file:
|
||||||
while True:
|
while True:
|
||||||
@ -292,10 +299,9 @@ class IOComponent(Component):
|
|||||||
temp_dir = self.hash_url(url)
|
temp_dir = self.hash_url(url)
|
||||||
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
temp_dir = Path(self.DEFAULT_TEMP_DIR) / temp_dir
|
||||||
temp_dir.mkdir(exist_ok=True, parents=True)
|
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
|
||||||
|
|
||||||
f.name = client_utils.strip_invalid_filename_characters(Path(url).name)
|
name = client_utils.strip_invalid_filename_characters(Path(url).name)
|
||||||
full_temp_file_path = str(utils.abspath(temp_dir / f.name))
|
full_temp_file_path = str(utils.abspath(temp_dir / name))
|
||||||
|
|
||||||
if not Path(full_temp_file_path).exists():
|
if not Path(full_temp_file_path).exists():
|
||||||
with requests.get(url, stream=True) as r, open(
|
with requests.get(url, stream=True) as r, open(
|
||||||
@ -323,8 +329,7 @@ class IOComponent(Component):
|
|||||||
file_name = f"file.{guess_extension}"
|
file_name = f"file.{guess_extension}"
|
||||||
else:
|
else:
|
||||||
file_name = "file"
|
file_name = "file"
|
||||||
f = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
|
|
||||||
f.name = file_name # type: ignore
|
|
||||||
full_temp_file_path = str(utils.abspath(temp_dir / file_name)) # type: ignore
|
full_temp_file_path = str(utils.abspath(temp_dir / file_name)) # type: ignore
|
||||||
|
|
||||||
if not Path(full_temp_file_path).exists():
|
if not Path(full_temp_file_path).exists():
|
||||||
@ -335,6 +340,36 @@ class IOComponent(Component):
|
|||||||
self.temp_files.add(full_temp_file_path)
|
self.temp_files.add(full_temp_file_path)
|
||||||
return full_temp_file_path
|
return full_temp_file_path
|
||||||
|
|
||||||
|
def pil_to_temp_file(self, img: _Image.Image, dir: str, format="png") -> str:
|
||||||
|
bytes_data = processing_utils.encode_pil_to_bytes(img, format)
|
||||||
|
temp_dir = Path(dir) / self.hash_bytes(bytes_data)
|
||||||
|
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
filename = str(temp_dir / f"image.{format}")
|
||||||
|
img.save(filename, pnginfo=processing_utils.get_pil_metadata(img))
|
||||||
|
return filename
|
||||||
|
|
||||||
|
def img_array_to_temp_file(self, arr: np.ndarray, dir: str) -> str:
|
||||||
|
pil_image = _Image.fromarray(
|
||||||
|
processing_utils._convert(arr, np.uint8, force_copy=False)
|
||||||
|
)
|
||||||
|
return self.pil_to_temp_file(pil_image, dir, format="png")
|
||||||
|
|
||||||
|
def audio_to_temp_file(
|
||||||
|
self, data: np.ndarray, sample_rate: int, dir: str, format: str
|
||||||
|
):
|
||||||
|
temp_dir = Path(dir) / self.hash_bytes(data.tobytes())
|
||||||
|
temp_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
filename = str(temp_dir / f"audio.{format}")
|
||||||
|
processing_utils.audio_to_file(sample_rate, data, filename, format=format)
|
||||||
|
return filename
|
||||||
|
|
||||||
|
def file_bytes_to_file(self, data: bytes, dir: str, file_name: str):
|
||||||
|
path = Path(dir) / self.hash_bytes(data)
|
||||||
|
path.mkdir(exist_ok=True, parents=True)
|
||||||
|
path = path / Path(file_name).name
|
||||||
|
path.write_bytes(data)
|
||||||
|
return path
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
config = {
|
config = {
|
||||||
"label": self.label,
|
"label": self.label,
|
||||||
@ -1758,12 +1793,11 @@ class Image(
|
|||||||
elif self.type == "numpy":
|
elif self.type == "numpy":
|
||||||
return np.array(im)
|
return np.array(im)
|
||||||
elif self.type == "filepath":
|
elif self.type == "filepath":
|
||||||
file_obj = tempfile.NamedTemporaryFile(
|
path = self.pil_to_temp_file(
|
||||||
delete=False,
|
im, dir=self.DEFAULT_TEMP_DIR, format=fmt or "png"
|
||||||
suffix=(f".{fmt.lower()}" if fmt is not None else ".png"),
|
|
||||||
)
|
)
|
||||||
im.save(file_obj.name)
|
self.temp_files.add(path)
|
||||||
return self.make_temp_copy_if_needed(file_obj.name)
|
return path
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unknown type: "
|
"Unknown type: "
|
||||||
@ -2259,8 +2293,7 @@ class Video(
|
|||||||
# HTML5 only support vtt format
|
# HTML5 only support vtt format
|
||||||
if Path(subtitle).suffix == ".srt":
|
if Path(subtitle).suffix == ".srt":
|
||||||
temp_file = tempfile.NamedTemporaryFile(
|
temp_file = tempfile.NamedTemporaryFile(
|
||||||
delete=False,
|
delete=False, suffix=".vtt", dir=self.DEFAULT_TEMP_DIR
|
||||||
suffix=".vtt",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
srt_to_vtt(subtitle, temp_file.name)
|
srt_to_vtt(subtitle, temp_file.name)
|
||||||
@ -2483,7 +2516,9 @@ class Audio(
|
|||||||
# Handle the leave one outs
|
# Handle the leave one outs
|
||||||
leave_one_out_data = np.copy(data)
|
leave_one_out_data = np.copy(data)
|
||||||
leave_one_out_data[start:stop] = 0
|
leave_one_out_data[start:stop] = 0
|
||||||
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
file = tempfile.NamedTemporaryFile(
|
||||||
|
delete=False, suffix=".wav", dir=self.DEFAULT_TEMP_DIR
|
||||||
|
)
|
||||||
processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name)
|
processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name)
|
||||||
out_data = client_utils.encode_file_to_base64(file.name)
|
out_data = client_utils.encode_file_to_base64(file.name)
|
||||||
leave_one_out_sets.append(out_data)
|
leave_one_out_sets.append(out_data)
|
||||||
@ -2494,7 +2529,9 @@ class Audio(
|
|||||||
token = np.copy(data)
|
token = np.copy(data)
|
||||||
token[0:start] = 0
|
token[0:start] = 0
|
||||||
token[stop:] = 0
|
token[stop:] = 0
|
||||||
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
file = tempfile.NamedTemporaryFile(
|
||||||
|
delete=False, suffix=".wav", dir=self.DEFAULT_TEMP_DIR
|
||||||
|
)
|
||||||
processing_utils.audio_to_file(sample_rate, token, file.name)
|
processing_utils.audio_to_file(sample_rate, token, file.name)
|
||||||
token_data = client_utils.encode_file_to_base64(file.name)
|
token_data = client_utils.encode_file_to_base64(file.name)
|
||||||
file.close()
|
file.close()
|
||||||
@ -2525,7 +2562,7 @@ class Audio(
|
|||||||
masked_input = np.copy(zero_input)
|
masked_input = np.copy(zero_input)
|
||||||
for t, b in zip(token_data, binary_mask_vector):
|
for t, b in zip(token_data, binary_mask_vector):
|
||||||
masked_input = masked_input + t * int(b)
|
masked_input = masked_input + t * int(b)
|
||||||
file = tempfile.NamedTemporaryFile(delete=False)
|
file = tempfile.NamedTemporaryFile(delete=False, dir=self.DEFAULT_TEMP_DIR)
|
||||||
processing_utils.audio_to_file(sample_rate, masked_input, file.name)
|
processing_utils.audio_to_file(sample_rate, masked_input, file.name)
|
||||||
masked_data = client_utils.encode_file_to_base64(file.name)
|
masked_data = client_utils.encode_file_to_base64(file.name)
|
||||||
file.close()
|
file.close()
|
||||||
@ -2546,11 +2583,9 @@ class Audio(
|
|||||||
return {"name": y, "data": None, "is_file": True}
|
return {"name": y, "data": None, "is_file": True}
|
||||||
if isinstance(y, tuple):
|
if isinstance(y, tuple):
|
||||||
sample_rate, data = y
|
sample_rate, data = y
|
||||||
file = tempfile.NamedTemporaryFile(suffix=f".{self.format}", delete=False)
|
file_path = self.audio_to_temp_file(
|
||||||
processing_utils.audio_to_file(
|
data, sample_rate, dir=self.DEFAULT_TEMP_DIR, format=self.format
|
||||||
sample_rate, data, file.name, format=self.format
|
|
||||||
)
|
)
|
||||||
file_path = str(utils.abspath(file.name))
|
|
||||||
self.temp_files.add(file_path)
|
self.temp_files.add(file_path)
|
||||||
else:
|
else:
|
||||||
file_path = self.make_temp_copy_if_needed(y)
|
file_path = self.make_temp_copy_if_needed(y)
|
||||||
@ -2720,14 +2755,21 @@ class File(
|
|||||||
)
|
)
|
||||||
if self.type == "file":
|
if self.type == "file":
|
||||||
if is_file:
|
if is_file:
|
||||||
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
path = self.make_temp_copy_if_needed(file_name)
|
||||||
file = tempfile.NamedTemporaryFile(delete=False)
|
|
||||||
file.name = temp_file_path
|
|
||||||
file.orig_name = file_name # type: ignore
|
|
||||||
else:
|
else:
|
||||||
file = client_utils.decode_base64_to_file(data, file_path=file_name)
|
data, _ = client_utils.decode_base64_to_binary(data)
|
||||||
file.orig_name = file_name # type: ignore
|
path = self.file_bytes_to_file(
|
||||||
self.temp_files.add(str(utils.abspath(file.name)))
|
data, dir=self.DEFAULT_TEMP_DIR, file_name=file_name
|
||||||
|
)
|
||||||
|
path = str(utils.abspath(path))
|
||||||
|
self.temp_files.add(path)
|
||||||
|
|
||||||
|
# Creation of tempfiles here
|
||||||
|
file = tempfile.NamedTemporaryFile(
|
||||||
|
delete=False, dir=self.DEFAULT_TEMP_DIR
|
||||||
|
)
|
||||||
|
file.name = path
|
||||||
|
file.orig_name = file_name # type: ignore
|
||||||
return file
|
return file
|
||||||
elif (
|
elif (
|
||||||
self.type == "binary" or self.type == "bytes"
|
self.type == "binary" or self.type == "bytes"
|
||||||
@ -2777,13 +2819,14 @@ class File(
|
|||||||
for file in y
|
for file in y
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
return {
|
d = {
|
||||||
"orig_name": Path(y).name,
|
"orig_name": Path(y).name,
|
||||||
"name": self.make_temp_copy_if_needed(y),
|
"name": self.make_temp_copy_if_needed(y),
|
||||||
"size": Path(y).stat().st_size,
|
"size": Path(y).stat().st_size,
|
||||||
"data": None,
|
"data": None,
|
||||||
"is_file": True,
|
"is_file": True,
|
||||||
}
|
}
|
||||||
|
return d
|
||||||
|
|
||||||
def style(
|
def style(
|
||||||
self,
|
self,
|
||||||
@ -3472,14 +3515,19 @@ class UploadButton(Clickable, Uploadable, IOComponent, FileSerializable):
|
|||||||
)
|
)
|
||||||
if self.type == "file":
|
if self.type == "file":
|
||||||
if is_file:
|
if is_file:
|
||||||
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
path = self.make_temp_copy_if_needed(file_name)
|
||||||
file = tempfile.NamedTemporaryFile(delete=False)
|
|
||||||
file.name = temp_file_path
|
|
||||||
file.orig_name = file_name # type: ignore
|
|
||||||
else:
|
else:
|
||||||
file = client_utils.decode_base64_to_file(data, file_path=file_name)
|
data, _ = client_utils.decode_base64_to_binary(data)
|
||||||
file.orig_name = file_name # type: ignore
|
path = self.file_bytes_to_file(
|
||||||
self.temp_files.add(str(utils.abspath(file.name)))
|
data, dir=self.DEFAULT_TEMP_DIR, file_name=file_name
|
||||||
|
)
|
||||||
|
path = str(utils.abspath(path))
|
||||||
|
self.temp_files.add(path)
|
||||||
|
file = tempfile.NamedTemporaryFile(
|
||||||
|
delete=False, dir=self.DEFAULT_TEMP_DIR
|
||||||
|
)
|
||||||
|
file.name = path
|
||||||
|
file.orig_name = file_name # type: ignore
|
||||||
return file
|
return file
|
||||||
elif self.type == "bytes":
|
elif self.type == "bytes":
|
||||||
if is_file:
|
if is_file:
|
||||||
@ -4068,11 +4116,11 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable):
|
|||||||
base_img_path = base_img
|
base_img_path = base_img
|
||||||
base_img = np.array(_Image.open(base_img))
|
base_img = np.array(_Image.open(base_img))
|
||||||
elif isinstance(base_img, np.ndarray):
|
elif isinstance(base_img, np.ndarray):
|
||||||
base_file = processing_utils.save_array_to_file(base_img)
|
base_file = self.img_array_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR)
|
||||||
base_img_path = str(utils.abspath(base_file.name))
|
base_img_path = str(utils.abspath(base_file))
|
||||||
elif isinstance(base_img, _Image.Image):
|
elif isinstance(base_img, _Image.Image):
|
||||||
base_file = processing_utils.save_pil_to_file(base_img)
|
base_file = self.pil_to_temp_file(base_img, dir=self.DEFAULT_TEMP_DIR)
|
||||||
base_img_path = str(utils.abspath(base_file.name))
|
base_img_path = str(utils.abspath(base_file))
|
||||||
base_img = np.array(base_img)
|
base_img = np.array(base_img)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -4116,8 +4164,10 @@ class AnnotatedImage(Selectable, IOComponent, JSONSerializable):
|
|||||||
|
|
||||||
colored_mask_img = _Image.fromarray((colored_mask).astype(np.uint8))
|
colored_mask_img = _Image.fromarray((colored_mask).astype(np.uint8))
|
||||||
|
|
||||||
mask_file = processing_utils.save_pil_to_file(colored_mask_img)
|
mask_file = self.pil_to_temp_file(
|
||||||
mask_file_path = str(utils.abspath(mask_file.name))
|
colored_mask_img, dir=self.DEFAULT_TEMP_DIR
|
||||||
|
)
|
||||||
|
mask_file_path = str(utils.abspath(mask_file))
|
||||||
self.temp_files.add(mask_file_path)
|
self.temp_files.add(mask_file_path)
|
||||||
|
|
||||||
sections.append(
|
sections.append(
|
||||||
@ -4404,12 +4454,12 @@ class Gallery(IOComponent, GallerySerializable, Selectable):
|
|||||||
if isinstance(img, (tuple, list)):
|
if isinstance(img, (tuple, list)):
|
||||||
img, caption = img
|
img, caption = img
|
||||||
if isinstance(img, np.ndarray):
|
if isinstance(img, np.ndarray):
|
||||||
file = processing_utils.save_array_to_file(img)
|
file = self.img_array_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR)
|
||||||
file_path = str(utils.abspath(file.name))
|
file_path = str(utils.abspath(file))
|
||||||
self.temp_files.add(file_path)
|
self.temp_files.add(file_path)
|
||||||
elif isinstance(img, _Image.Image):
|
elif isinstance(img, _Image.Image):
|
||||||
file = processing_utils.save_pil_to_file(img)
|
file = self.pil_to_temp_file(img, dir=self.DEFAULT_TEMP_DIR)
|
||||||
file_path = str(utils.abspath(file.name))
|
file_path = str(utils.abspath(file))
|
||||||
self.temp_files.add(file_path)
|
self.temp_files.add(file_path)
|
||||||
elif isinstance(img, str):
|
elif isinstance(img, str):
|
||||||
if utils.validate_url(img):
|
if utils.validate_url(img):
|
||||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -64,13 +65,6 @@ def encode_plot_to_base64(plt):
|
|||||||
return "data:image/png;base64," + base64_str
|
return "data:image/png;base64," + base64_str
|
||||||
|
|
||||||
|
|
||||||
def save_array_to_file(image_array, dir=None):
|
|
||||||
pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
|
|
||||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
|
||||||
pil_image.save(file_obj)
|
|
||||||
return file_obj
|
|
||||||
|
|
||||||
|
|
||||||
def get_pil_metadata(pil_image):
|
def get_pil_metadata(pil_image):
|
||||||
# Copy any text-only metadata
|
# Copy any text-only metadata
|
||||||
metadata = PngImagePlugin.PngInfo()
|
metadata = PngImagePlugin.PngInfo()
|
||||||
@ -81,16 +75,14 @@ def get_pil_metadata(pil_image):
|
|||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
def save_pil_to_file(pil_image, dir=None):
|
def encode_pil_to_bytes(pil_image, format="png"):
|
||||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
with BytesIO() as output_bytes:
|
||||||
pil_image.save(file_obj, pnginfo=get_pil_metadata(pil_image))
|
pil_image.save(output_bytes, format, pnginfo=get_pil_metadata(pil_image))
|
||||||
return file_obj
|
return output_bytes.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def encode_pil_to_base64(pil_image):
|
def encode_pil_to_base64(pil_image):
|
||||||
with BytesIO() as output_bytes:
|
bytes_data = encode_pil_to_bytes(pil_image)
|
||||||
pil_image.save(output_bytes, "PNG", pnginfo=get_pil_metadata(pil_image))
|
|
||||||
bytes_data = output_bytes.getvalue()
|
|
||||||
base64_str = str(base64.b64encode(bytes_data), "utf-8")
|
base64_str = str(base64.b64encode(bytes_data), "utf-8")
|
||||||
return "data:image/png;base64," + base64_str
|
return "data:image/png;base64," + base64_str
|
||||||
|
|
||||||
@ -519,8 +511,8 @@ def video_is_playable(video_filepath: str) -> bool:
|
|||||||
def convert_video_to_playable_mp4(video_path: str) -> str:
|
def convert_video_to_playable_mp4(video_path: str) -> str:
|
||||||
"""Convert the video to mp4. If something goes wrong return the original video."""
|
"""Convert the video to mp4. If something goes wrong return the original video."""
|
||||||
try:
|
try:
|
||||||
output_path = Path(video_path).with_suffix(".mp4")
|
|
||||||
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
|
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
|
||||||
|
output_path = Path(video_path).with_suffix(".mp4")
|
||||||
shutil.copy2(video_path, tmp_file.name)
|
shutil.copy2(video_path, tmp_file.name)
|
||||||
# ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4
|
# ffmpeg will automatically use h264 codec (playable in browser) when converting to mp4
|
||||||
ff = FFmpeg(
|
ff = FFmpeg(
|
||||||
@ -532,4 +524,7 @@ def convert_video_to_playable_mp4(video_path: str) -> str:
|
|||||||
except FFRuntimeError as e:
|
except FFRuntimeError as e:
|
||||||
print(f"Error converting video to browser-playable format {str(e)}")
|
print(f"Error converting video to browser-playable format {str(e)}")
|
||||||
output_path = video_path
|
output_path = video_path
|
||||||
|
finally:
|
||||||
|
# Remove temp file
|
||||||
|
os.remove(tmp_file.name) # type: ignore
|
||||||
return str(output_path)
|
return str(output_path)
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import pathlib
|
import pathlib
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from gradio_client import Client
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
@ -32,3 +34,24 @@ def io_components():
|
|||||||
subclasses.append(subclass)
|
subclasses.append(subclass)
|
||||||
|
|
||||||
return subclasses
|
return subclasses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def connect():
|
||||||
|
@contextmanager
|
||||||
|
def _connect(demo: gr.Blocks, serialize=True):
|
||||||
|
_, local_url, _ = demo.launch(prevent_thread_lock=True)
|
||||||
|
try:
|
||||||
|
yield Client(local_url, serialize=serialize)
|
||||||
|
finally:
|
||||||
|
# A more verbose version of .close()
|
||||||
|
# because we should set a timeout
|
||||||
|
# the tests that call .cancel() can get stuck
|
||||||
|
# waiting for the thread to join
|
||||||
|
if demo.enable_queue:
|
||||||
|
demo._queue.close()
|
||||||
|
demo.is_running = False
|
||||||
|
demo.server.should_exit = True
|
||||||
|
demo.server.thread.join(timeout=1)
|
||||||
|
|
||||||
|
return _connect
|
||||||
|
@ -15,11 +15,14 @@ from functools import partial
|
|||||||
from string import capwords
|
from string import capwords
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import gradio_client as grc
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import websockets
|
import websockets
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from gradio_client import media_data
|
from gradio_client import media_data
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from gradio.events import SelectData
|
from gradio.events import SelectData
|
||||||
@ -463,6 +466,106 @@ class TestBlocksMethods:
|
|||||||
demo.close()
|
demo.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestTempFile:
|
||||||
|
def test_pil_images_hashed(self, tmp_path, connect, monkeypatch):
|
||||||
|
images = [
|
||||||
|
Image.new("RGB", (512, 512), color) for color in ("red", "green", "blue")
|
||||||
|
]
|
||||||
|
|
||||||
|
def create_images(n_images):
|
||||||
|
return random.sample(images, n_images)
|
||||||
|
|
||||||
|
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||||
|
demo = gr.Interface(
|
||||||
|
create_images,
|
||||||
|
inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)],
|
||||||
|
outputs=[gr.Gallery().style(grid=2, preview=True)],
|
||||||
|
)
|
||||||
|
with connect(demo) as client:
|
||||||
|
_ = client.predict(3)
|
||||||
|
_ = client.predict(3)
|
||||||
|
# only three files created
|
||||||
|
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
|
||||||
|
|
||||||
|
def test_no_empty_image_files(self, tmp_path, connect, monkeypatch):
|
||||||
|
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||||
|
image = str(file_dir / "bus.png")
|
||||||
|
|
||||||
|
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||||
|
demo = gr.Interface(
|
||||||
|
lambda x: x,
|
||||||
|
inputs=gr.Image(type="filepath"),
|
||||||
|
outputs=gr.Image(),
|
||||||
|
)
|
||||||
|
with connect(demo) as client:
|
||||||
|
_ = client.predict(image)
|
||||||
|
_ = client.predict(image)
|
||||||
|
_ = client.predict(image)
|
||||||
|
# only three files created
|
||||||
|
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 1
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("component", [gr.UploadButton, gr.File])
|
||||||
|
def test_file_component_uploads(self, component, tmp_path, connect, monkeypatch):
|
||||||
|
code_file = str(pathlib.Path(__file__))
|
||||||
|
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||||
|
demo = gr.Interface(lambda x: x.name, component(), gr.File())
|
||||||
|
with connect(demo) as client:
|
||||||
|
_ = client.predict(code_file)
|
||||||
|
_ = client.predict(code_file)
|
||||||
|
# the upload route does not hash the file so 2 files from there
|
||||||
|
# We create two tempfiles (empty) because API says we return
|
||||||
|
# preprocess/postprocess will only create one file since we hash
|
||||||
|
# so 2 + 2 + 1 = 5
|
||||||
|
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 5
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("component", [gr.UploadButton, gr.File])
|
||||||
|
def test_file_component_uploads_no_serialize(
|
||||||
|
self, component, tmp_path, connect, monkeypatch
|
||||||
|
):
|
||||||
|
code_file = str(pathlib.Path(__file__))
|
||||||
|
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||||
|
demo = gr.Interface(lambda x: x.name, component(), gr.File())
|
||||||
|
with connect(demo, serialize=False) as client:
|
||||||
|
_ = client.predict(gr.File().serialize(code_file))
|
||||||
|
_ = client.predict(gr.File().serialize(code_file))
|
||||||
|
# We skip the upload route in this case
|
||||||
|
# We create two tempfiles (empty) because API says we return
|
||||||
|
# preprocess/postprocess will only create one file since we hash
|
||||||
|
# so 2 + 1 = 3
|
||||||
|
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
|
||||||
|
|
||||||
|
def test_no_empty_video_files(self, tmp_path, monkeypatch, connect):
|
||||||
|
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||||
|
video = str(file_dir / "video_sample.mp4")
|
||||||
|
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||||
|
demo = gr.Interface(lambda x: x, gr.Video(type="file"), gr.Video())
|
||||||
|
with connect(demo) as client:
|
||||||
|
_, url, _ = demo.launch(prevent_thread_lock=True)
|
||||||
|
client = grc.Client(url)
|
||||||
|
_ = client.predict(video)
|
||||||
|
_ = client.predict(video)
|
||||||
|
# During preprocessing we compute the hash based on base64
|
||||||
|
# In postprocessing we compute it based on the file
|
||||||
|
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2
|
||||||
|
|
||||||
|
def test_no_empty_audio_files(self, tmp_path, monkeypatch, connect):
|
||||||
|
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||||
|
audio = str(file_dir / "audio_sample.wav")
|
||||||
|
|
||||||
|
def reverse_audio(audio):
|
||||||
|
sr, data = audio
|
||||||
|
return (sr, np.flipud(data))
|
||||||
|
|
||||||
|
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||||
|
demo = gr.Interface(fn=reverse_audio, inputs=gr.Audio(), outputs=gr.Audio())
|
||||||
|
with connect(demo) as client:
|
||||||
|
_ = client.predict(audio)
|
||||||
|
_ = client.predict(audio)
|
||||||
|
# During preprocessing we compute the hash based on base64
|
||||||
|
# In postprocessing we compute it based on the file
|
||||||
|
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 2
|
||||||
|
|
||||||
|
|
||||||
class TestComponentsInBlocks:
|
class TestComponentsInBlocks:
|
||||||
def test_slider_random_value_config(self):
|
def test_slider_random_value_config(self):
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
|
@ -9,9 +9,9 @@ import ffmpy
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from gradio_client import media_data
|
from gradio_client import media_data
|
||||||
from PIL import Image
|
from PIL import Image, ImageCms
|
||||||
|
|
||||||
from gradio import processing_utils, utils
|
from gradio import components, processing_utils, utils
|
||||||
|
|
||||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||||
|
|
||||||
@ -54,16 +54,49 @@ class TestImagePreprocessing:
|
|||||||
output_base64 = processing_utils.encode_pil_to_base64(img)
|
output_base64 = processing_utils.encode_pil_to_base64(img)
|
||||||
assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
|
assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
|
||||||
|
|
||||||
def test_save_pil_to_file_keeps_pnginfo(self):
|
def test_save_pil_to_file_keeps_pnginfo(self, tmp_path):
|
||||||
input_img = Image.open("gradio/test_data/test_image.png")
|
input_img = Image.open("gradio/test_data/test_image.png")
|
||||||
input_img = input_img.convert("RGB")
|
input_img = input_img.convert("RGB")
|
||||||
input_img.info = {"key1": "value1", "key2": "value2"}
|
input_img.info = {"key1": "value1", "key2": "value2"}
|
||||||
|
|
||||||
file_obj = processing_utils.save_pil_to_file(input_img)
|
file_obj = components.Image().pil_to_temp_file(input_img, dir=tmp_path)
|
||||||
output_img = Image.open(file_obj)
|
output_img = Image.open(file_obj)
|
||||||
|
|
||||||
assert output_img.info == input_img.info
|
assert output_img.info == input_img.info
|
||||||
|
|
||||||
|
def test_np_pil_encode_to_the_same(self, tmp_path):
|
||||||
|
arr = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.uint8)
|
||||||
|
pil = Image.fromarray(arr)
|
||||||
|
comp = components.Image()
|
||||||
|
assert comp.pil_to_temp_file(pil, dir=tmp_path) == comp.img_array_to_temp_file(
|
||||||
|
arr, dir=tmp_path
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_encode_pil_to_temp_file_metadata_color_profile(self, tmp_path):
|
||||||
|
# Read image
|
||||||
|
img = Image.open("gradio/test_data/test_image.png")
|
||||||
|
img_metadata = Image.open("gradio/test_data/test_image.png")
|
||||||
|
img_metadata.info = {"key1": "value1", "key2": "value2"}
|
||||||
|
|
||||||
|
# Creating sRGB profile
|
||||||
|
profile = ImageCms.createProfile("sRGB")
|
||||||
|
profile2 = ImageCms.ImageCmsProfile(profile)
|
||||||
|
img.save(tmp_path / "img_color_profile.png", icc_profile=profile2.tobytes())
|
||||||
|
img_cp1 = Image.open(str(tmp_path / "img_color_profile.png"))
|
||||||
|
|
||||||
|
# Creating XYZ profile
|
||||||
|
profile = ImageCms.createProfile("XYZ")
|
||||||
|
profile2 = ImageCms.ImageCmsProfile(profile)
|
||||||
|
img.save(tmp_path / "img_color_profile_2.png", icc_profile=profile2.tobytes())
|
||||||
|
img_cp2 = Image.open(str(tmp_path / "img_color_profile_2.png"))
|
||||||
|
|
||||||
|
comp = components.Image()
|
||||||
|
img_path = comp.pil_to_temp_file(img, dir=tmp_path)
|
||||||
|
img_metadata_path = comp.pil_to_temp_file(img_metadata, dir=tmp_path)
|
||||||
|
img_cp1_path = comp.pil_to_temp_file(img_cp1, dir=tmp_path)
|
||||||
|
img_cp2_path = comp.pil_to_temp_file(img_cp2, dir=tmp_path)
|
||||||
|
assert len({img_path, img_metadata_path, img_cp1_path, img_cp2_path}) == 4
|
||||||
|
|
||||||
def test_encode_pil_to_base64_keeps_pnginfo(self):
|
def test_encode_pil_to_base64_keeps_pnginfo(self):
|
||||||
input_img = Image.open("gradio/test_data/test_image.png")
|
input_img = Image.open("gradio/test_data/test_image.png")
|
||||||
input_img = input_img.convert("RGB")
|
input_img = input_img.convert("RGB")
|
||||||
@ -205,9 +238,12 @@ class TestVideoProcessing:
|
|||||||
shutil.copy(
|
shutil.copy(
|
||||||
str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name
|
str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name
|
||||||
)
|
)
|
||||||
playable_vid = processing_utils.convert_video_to_playable_mp4(
|
with patch("os.remove", wraps=os.remove) as mock_remove:
|
||||||
tmp_not_playable_vid.name
|
playable_vid = processing_utils.convert_video_to_playable_mp4(
|
||||||
)
|
tmp_not_playable_vid.name
|
||||||
|
)
|
||||||
|
# check tempfile got deleted
|
||||||
|
assert not Path(mock_remove.call_args[0][0]).exists()
|
||||||
assert processing_utils.video_is_playable(playable_vid)
|
assert processing_utils.video_is_playable(playable_vid)
|
||||||
|
|
||||||
@patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception)
|
@patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception)
|
||||||
|
Loading…
Reference in New Issue
Block a user