Create fewer temp files and make them consistently-named (#2758)

* tmp files

* components

* changes

* temp_file_sets

* TempFileManager class

* added file manager

* internal functions

* tests

* formatting

* changes

* video tests

* added tests for File

* cheetah image

* formatting

* tests for upload button

* temp files

* formatting

* changelog

* fixed audio

* tmp files

* tmp files

* gallery

* deprecated type=file

* fixing tests

* patch os.path.exists

* fixed test_video_postprocess_converts_to_playable_format

* fixed tests

* changelog

* fix tests

* formatting

* added a download_if_needed

* formatting

* fixed download

* fixed gallery demo

* fix tests

* version

* fix for mac

* consolidate
This commit is contained in:
Abubakar Abid 2022-12-15 14:37:09 -06:00 committed by GitHub
parent 714ab2cc09
commit 20057aa946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 445 additions and 227 deletions

View File

@ -8,7 +8,9 @@ Adds a `gr.make_waveform()` function that creates a waveform video by combining
![waveform screenrecording](https://user-images.githubusercontent.com/7870876/206062396-164a5e71-451a-4fe0-94a7-cbe9269d57e6.gif)
## Bug Fixes:
No changes to highlight.
* Fixed issue where too many temporary files were created, all with randomly generated
filepaths. Now fewer temporary files are created and are assigned a path that is a
hash based on the file contents by [@abidlabs](https://github.com/abidlabs) in [PR 2758](https://github.com/gradio-app/gradio/pull/2758)
## Documentation Changes:
No changes to highlight.
@ -83,6 +85,11 @@ These links are a more secure and scalable way to create shareable demos!
No changes to highlight.
## Full Changelog:
* Fixed typo in parameter `visible` in classes in `templates.py` by [@abidlabs](https://github.com/abidlabs) in [PR 2805](https://github.com/gradio-app/gradio/pull/2805)
* Switched external service for getting IP address from `https://api.ipify.org` to `https://checkip.amazonaws.com/` by [@abidlabs](https://github.com/abidlabs) in [PR 2810](https://github.com/gradio-app/gradio/pull/2810)
## Contributors Shoutout:
No changes to highlight.
* Fixed typo in parameter `visible` in classes in `templates.py` by [@abidlabs](https://github.com/abidlabs) in [PR 2805](https://github.com/gradio-app/gradio/pull/2805)
* Switched external service for getting IP address from `https://api.ipify.org` to `https://checkip.amazonaws.com/` by [@abidlabs](https://github.com/abidlabs) in [PR 2810](https://github.com/gradio-app/gradio/pull/2810)

View File

@ -98,8 +98,8 @@ class Block:
Context.block.add(self)
if Context.root_block is not None:
Context.root_block.blocks[self._id] = self
if hasattr(self, "temp_dir"):
Context.root_block.temp_dirs.add(self.temp_dir)
if isinstance(self, components.TempFileManager):
Context.root_block.temp_file_sets.append(self.temp_files)
return self
def unrender(self):
@ -557,7 +557,7 @@ class Blocks(BlockContext):
self.auth = None
self.dev_mode = True
self.app_id = random.getrandbits(64)
self.temp_dirs = set()
self.temp_file_sets = []
self.title = title
self.show_api = True
@ -741,7 +741,7 @@ class Blocks(BlockContext):
)
Context.root_block.fns[dependency_offset + i] = new_fn
Context.root_block.dependencies.append(dependency)
Context.root_block.temp_dirs = Context.root_block.temp_dirs | self.temp_dirs
Context.root_block.temp_file_sets.extend(self.temp_file_sets)
if Context.block is not None:
Context.block.children.extend(self.children)

View File

@ -46,6 +46,7 @@ from gradio.events import (
Uploadable,
)
from gradio.layouts import Column, Form, Row
from gradio.processing_utils import TempFileManager
from gradio.serializing import (
FileSerializable,
ImgSerializable,
@ -1263,7 +1264,7 @@ class Image(
invert_colors: whether to invert the image as a preprocessing step.
source: Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
tool: Tools used for editing. "editor" allows a full screen editor (and is the default if source is "upload" or "webcam"), "select" provides a cropping and zoom tool, "sketch" allows you to create a binary sketch (and is the default if source="canvas"), and "color-sketch" allows you to created a sketch in different colors. "color-sketch" can be used with source="upload" or "webcam" to allow sketching on an image. "sketch" can also be used with "upload" or "webcam" to create a mask over an image and in that case both the image and mask are passed into the function as a dictionary with keys "image" and "mask" respectively.
type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "file" produces a temporary file object whose path can be retrieved by file_obj.name, "filepath" passes a str path to a temporary file containing the image.
type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image.
label: component name in interface.
show_label: if True, will display label.
interactive: if True, will allow users to upload and edit an image; if False, can only be used to display images. If not provided, this is inferred based on whether the component is used as an input or output.
@ -1273,7 +1274,7 @@ class Image(
mirror_webcam: If True webcam will be mirrored. Default is True.
"""
self.mirror_webcam = mirror_webcam
valid_types = ["numpy", "pil", "file", "filepath"]
valid_types = ["numpy", "pil", "filepath"]
if type not in valid_types:
raise ValueError(
f"Invalid value for parameter `type`: {type}. Please choose from one of: {valid_types}"
@ -1350,19 +1351,13 @@ class Image(
return im
elif self.type == "numpy":
return np.array(im)
elif self.type == "file" or self.type == "filepath":
elif self.type == "filepath":
file_obj = tempfile.NamedTemporaryFile(
delete=False,
suffix=("." + fmt.lower() if fmt is not None else ".png"),
)
im.save(file_obj.name)
if self.type == "file":
warnings.warn(
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.",
)
return file_obj
else:
return file_obj.name
return file_obj.name
else:
raise ValueError(
"Unknown type: "
@ -1577,7 +1572,15 @@ class Image(
@document("change", "clear", "play", "pause", "stop", "style")
class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerializable):
class Video(
Changeable,
Clearable,
Playable,
Uploadable,
IOComponent,
FileSerializable,
TempFileManager,
):
"""
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
For the video to be playable in the browser it must have a compatible container and codec combination. Allowed
@ -1616,7 +1619,6 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
mirror_webcam: If True webcma will be mirrored. Default is True.
"""
self.temp_dir = tempfile.mkdtemp()
self.format = format
valid_sources = ["upload", "webcam"]
if source not in valid_sources:
@ -1625,6 +1627,7 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
)
self.source = source
self.mirror_webcam = mirror_webcam
TempFileManager.__init__(self)
IOComponent.__init__(
self,
label=label,
@ -1680,15 +1683,15 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
x.get("is_file", False),
)
if is_file:
file = processing_utils.create_tmp_copy_of_file(file_name)
file = self.make_temp_copy_if_needed(file_name)
file_name = Path(file)
else:
file = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
file_name = Path(file.name)
file_name = Path(file.name)
uploaded_format = file_name.suffix.replace(".", "")
modify_format = self.format is not None and uploaded_format != self.format
flip = self.source == "webcam" and self.mirror_webcam
if modify_format or flip:
@ -1698,6 +1701,8 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
output_file_name = str(
file_name.with_name(f"{file_name.stem}{flip_suffix}{format}")
)
if os.path.exists(output_file_name):
return output_file_name
ff = FFmpeg(
inputs={str(file_name): None},
outputs={output_file_name: output_options},
@ -1724,10 +1729,20 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
if y is None:
return None
if utils.validate_url(y):
y = processing_utils.download_to_file(y, dir=self.temp_dir).name
returned_format = y.split(".")[-1].lower()
if self.format is None or returned_format == self.format:
conversion_needed = False
else:
conversion_needed = True
# For cases where the video is a URL and does not need to be converted to another format, we can just return the URL
if utils.validate_url(y) and not (conversion_needed):
return {"name": y, "data": None, "is_file": True}
# For cases where the video needs to be converted to another format
if utils.validate_url(y):
y = self.download_temp_copy_if_needed(y)
if (
processing_utils.ffmpeg_installed()
and not processing_utils.video_is_playable(y)
@ -1742,8 +1757,8 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
ff.run()
y = output_file_name
y = processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir)
return {"name": y.name, "data": None, "is_file": True}
y = self.make_temp_copy_if_needed(y)
return {"name": y, "data": None, "is_file": True}
def style(
self, *, height: Optional[int] = None, width: Optional[int] = None, **kwargs
@ -1771,6 +1786,7 @@ class Audio(
Uploadable,
IOComponent,
FileSerializable,
TempFileManager,
):
"""
Creates an audio component that can be used to upload/record audio (as an input) or display audio (as an output).
@ -1807,14 +1823,13 @@ class Audio(
streaming: If set to True when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'microphone'.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
"""
self.temp_dir = tempfile.mkdtemp()
valid_sources = ["upload", "microphone"]
if source not in valid_sources:
raise ValueError(
f"Invalid value for parameter `source`: {source}. Please choose from one of: {valid_sources}"
)
self.source = source
valid_types = ["numpy", "filepath", "file"]
valid_types = ["numpy", "filepath"]
if type not in valid_types:
raise ValueError(
f"Invalid value for parameter `type`: {type}. Please choose from one of: {valid_types}"
@ -1827,6 +1842,7 @@ class Audio(
raise ValueError(
"Audio streaming only available if source is 'microphone'."
)
TempFileManager.__init__(self)
IOComponent.__init__(
self,
label=label,
@ -1869,7 +1885,7 @@ class Audio(
def preprocess(self, x: Dict[str, str] | None) -> Tuple[int, np.array] | str | None:
"""
Parameters:
x: JSON object with filename as 'name' property and base64 data as 'data' property
x: dictionary with keys "name", "data", "is_file", "crop_min", "crop_max".
Returns:
audio in requested format
"""
@ -1882,25 +1898,25 @@ class Audio(
)
crop_min, crop_max = x.get("crop_min", 0), x.get("crop_max", 100)
if is_file:
file_obj = processing_utils.create_tmp_copy_of_file(file_name)
if utils.validate_url(file_name):
temp_file_path = self.download_temp_copy_if_needed(file_name)
else:
temp_file_path = self.make_temp_copy_if_needed(file_name)
else:
file_obj = processing_utils.decode_base64_to_file(
temp_file_obj = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
temp_file_path = temp_file_obj.name
sample_rate, data = processing_utils.audio_from_file(
file_obj.name, crop_min=crop_min, crop_max=crop_max
temp_file_path, crop_min=crop_min, crop_max=crop_max
)
if self.type == "numpy":
return sample_rate, data
elif self.type in ["file", "filepath"]:
processing_utils.audio_to_file(sample_rate, data, file_obj.name)
if self.type == "file":
warnings.warn(
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.",
)
return file_obj
else:
return file_obj.name
elif self.type == "filepath":
processing_utils.audio_to_file(sample_rate, data, temp_file_path)
return temp_file_path
else:
raise ValueError(
"Unknown type: "
@ -2009,17 +2025,18 @@ class Audio(
return None
if utils.validate_url(y):
file = processing_utils.download_to_file(y, dir=self.temp_dir)
elif isinstance(y, tuple):
sample_rate, data = y
file = tempfile.NamedTemporaryFile(
suffix=".wav", dir=self.temp_dir, delete=False
)
processing_utils.audio_to_file(sample_rate, data, file.name)
else:
file = processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir)
return {"name": y, "data": None, "is_file": True}
return {"name": file.name, "data": None, "is_file": True}
if isinstance(y, tuple):
sample_rate, data = y
file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
processing_utils.audio_to_file(sample_rate, data, file.name)
file_path = file.name
self.temp_files.add(file_path)
else:
file_path = self.make_temp_copy_if_needed(y)
return {"name": file_path, "data": None, "is_file": True}
def stream(
self,
@ -2072,7 +2089,9 @@ class Audio(
@document("change", "clear", "style")
class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
class File(
Changeable, Clearable, Uploadable, IOComponent, FileSerializable, TempFileManager
):
"""
Creates a file component that allows uploading generic file (when used as an input) and or displaying generic files (output).
Preprocessing: passes the uploaded file as a {file-object} or {List[file-object]} depending on `file_count` (or a {bytes}/{List{bytes}} depending on `type`)
@ -2107,7 +2126,6 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
visible: If False, component will be hidden.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
"""
self.temp_dir = tempfile.mkdtemp()
self.file_count = file_count
self.file_types = file_types
valid_types = [
@ -2125,6 +2143,7 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
)
self.type = type
self.test_input = None
TempFileManager.__init__(self)
IOComponent.__init__(
self,
label=label,
@ -2162,7 +2181,11 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
}
return IOComponent.add_interactive_to_config(updated_config, interactive)
def preprocess(self, x: List[Dict[str, str]] | None) -> str | List[str]:
def preprocess(
self, x: List[Dict[str, str]] | None
) -> tempfile._TemporaryFileWrapper | List[
tempfile._TemporaryFileWrapper
] | bytes | List[bytes]:
"""
Parameters:
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
@ -2180,7 +2203,9 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
)
if self.type == "file":
if is_file:
file = processing_utils.create_tmp_copy_of_file(file_name)
temp_file_path = self.make_temp_copy_if_needed(file_name)
file = tempfile.NamedTemporaryFile(delete=False)
file.name = temp_file_path
file.orig_name = file_name
else:
file = processing_utils.decode_base64_to_file(
@ -2216,7 +2241,9 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
def generate_sample(self):
return deepcopy(media_data.BASE64_FILE)
def postprocess(self, y: str) -> Dict:
def postprocess(
self, y: str | List[str]
) -> Dict[str | Any] | List[Dict[str | Any]]:
"""
Parameters:
y: file path
@ -2229,9 +2256,7 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
return [
{
"orig_name": os.path.basename(file),
"name": processing_utils.create_tmp_copy_of_file(
file, dir=self.temp_dir
).name,
"name": self.make_temp_copy_if_needed(file),
"size": os.path.getsize(file),
"data": None,
"is_file": True,
@ -2241,9 +2266,7 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
else:
return {
"orig_name": os.path.basename(y),
"name": processing_utils.create_tmp_copy_of_file(
y, dir=self.temp_dir
).name,
"name": self.make_temp_copy_if_needed(y),
"size": os.path.getsize(y),
"data": None,
"is_file": True,
@ -2780,7 +2803,9 @@ class Button(Clickable, IOComponent, SimpleSerializable):
@document("click", "upload", "style")
class UploadButton(Clickable, Uploadable, IOComponent, SimpleSerializable):
class UploadButton(
Clickable, Uploadable, IOComponent, SimpleSerializable, TempFileManager
):
"""
Used to create an upload button, when cicked allows a user to upload files that satisfy the specified file type or generic files (if file_type not set).
Preprocessing: passes the uploaded file as a {file-object} or {List[file-object]} depending on `file_count` (or a {bytes}/{List{bytes}} depending on `type`)
@ -2811,11 +2836,11 @@ class UploadButton(Clickable, Uploadable, IOComponent, SimpleSerializable):
visible: If False, component will be hidden.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
"""
self.temp_dir = tempfile.mkdtemp()
self.type = type
self.file_count = file_count
self.file_types = file_types
self.label = label
TempFileManager.__init__(self)
IOComponent.__init__(
self, label=label, visible=visible, elem_id=elem_id, value=value, **kwargs
)
@ -2843,7 +2868,11 @@ class UploadButton(Clickable, Uploadable, IOComponent, SimpleSerializable):
}
return IOComponent.add_interactive_to_config(updated_config, interactive)
def preprocess(self, x: List[Dict[str, str]] | None) -> str | List[str]:
def preprocess(
self, x: List[Dict[str, str]] | None
) -> tempfile._TemporaryFileWrapper | List[
tempfile._TemporaryFileWrapper
] | bytes | List[bytes]:
"""
Parameters:
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
@ -2861,13 +2890,13 @@ class UploadButton(Clickable, Uploadable, IOComponent, SimpleSerializable):
)
if self.type == "file":
if is_file:
file = processing_utils.create_tmp_copy_of_file(
file_name, dir=self.temp_dir
)
temp_file_path = self.make_temp_copy_if_needed(file_name)
file = tempfile.NamedTemporaryFile(delete=False)
file.name = temp_file_path
file.orig_name = file_name
else:
file = processing_utils.decode_base64_to_file(
data, file_path=file_name, dir=self.temp_dir
data, file_path=file_name
)
file.orig_name = file_name
return file
@ -3452,7 +3481,7 @@ class HTML(Changeable, IOComponent, SimpleSerializable):
@document("style")
class Gallery(IOComponent):
class Gallery(IOComponent, TempFileManager):
"""
Used to display a list of images as a gallery that can be scrolled through.
Preprocessing: this component does *not* accept input.
@ -3479,7 +3508,7 @@ class Gallery(IOComponent):
visible: If False, component will be hidden.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
"""
self.temp_dir = tempfile.mkdtemp()
TempFileManager.__init__(self)
super().__init__(
label=label,
show_label=show_label,
@ -3531,25 +3560,27 @@ class Gallery(IOComponent):
if isinstance(img, tuple) or isinstance(img, list):
img, caption = img
if isinstance(img, np.ndarray):
file = processing_utils.save_array_to_file(img, dir=self.temp_dir)
file = processing_utils.save_array_to_file(img)
file_path = os.path.abspath(file.name)
self.temp_files.add(file_path)
elif isinstance(img, PIL.Image.Image):
file = processing_utils.save_pil_to_file(img, dir=self.temp_dir)
file = processing_utils.save_pil_to_file(img)
file_path = os.path.abspath(file.name)
self.temp_files.add(file_path)
elif isinstance(img, str):
if utils.validate_url(img):
file = processing_utils.download_to_file(img, dir=self.temp_dir)
file_path = img
else:
file = processing_utils.create_tmp_copy_of_file(
img, dir=self.temp_dir
)
file_path = self.make_temp_copy_if_needed(img)
else:
raise ValueError(f"Cannot process type as image: {type(img)}")
if caption is not None:
output.append(
[{"name": file.name, "data": None, "is_file": True}, caption]
[{"name": file_path, "data": None, "is_file": True}, caption]
)
else:
output.append({"name": file.name, "data": None, "is_file": True})
output.append({"name": file_path, "data": None, "is_file": True})
return output
@ -3581,6 +3612,7 @@ class Gallery(IOComponent):
if x is None:
return None
gallery_path = os.path.join(save_dir, str(uuid.uuid4()))
os.makedirs(gallery_path)
captions = {}
for img_data in x:
if isinstance(img_data, list) or isinstance(img_data, tuple):
@ -3588,9 +3620,7 @@ class Gallery(IOComponent):
else:
caption = None
name = FileSerializable.deserialize(self, img_data, gallery_path)
if caption is not None:
captions[name] = caption
if len(captions):
captions[name] = caption
captions_file = os.path.join(gallery_path, "captions.json")
with open(captions_file, "w") as captions_json:
json.dump(captions, captions_json)
@ -3599,21 +3629,11 @@ class Gallery(IOComponent):
def serialize(self, x: Any, load_dir: str = "", called_directly: bool = False):
files = []
captions_file = os.path.join(x, "captions.json")
for file in os.listdir(x):
file_path = os.path.join(x, file)
if file_path == captions_file:
continue
if os.path.exists(captions_file):
with open(captions_file) as captions_json:
captions = json.load(captions_json)
caption = captions.get(file_path)
else:
caption = None
img = FileSerializable.serialize(self, file_path)
if caption:
files.append([img, caption])
else:
files.append(img)
with open(captions_file) as captions_json:
captions = json.load(captions_json)
for file_name, caption in captions.items():
img = FileSerializable.serialize(self, file_name)
files.append([img, caption])
return files
@ -3735,7 +3755,9 @@ class Chatbot(Changeable, IOComponent, JSONSerializable):
@document("change", "edit", "clear", "style")
class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
class Model3D(
Changeable, Editable, Clearable, IOComponent, FileSerializable, TempFileManager
):
"""
Component allows users to upload or view 3D Model files (.obj, .glb, or .gltf).
Preprocessing: This component passes the uploaded file as a {str} filepath.
@ -3765,8 +3787,8 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
visible: If False, component will be hidden.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
"""
self.temp_dir = tempfile.mkdtemp()
self.clear_color = clear_color or [0.2, 0.2, 0.2, 1.0]
TempFileManager.__init__(self)
IOComponent.__init__(
self,
label=label,
@ -3805,7 +3827,7 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
Parameters:
x: JSON object with filename as 'name' property and base64 data as 'data' property
Returns:
file path to 3D image model
string file path to temporary file with the 3D image model
"""
if x is None:
return x
@ -3815,13 +3837,14 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
x.get("is_file", False),
)
if is_file:
file = processing_utils.create_tmp_copy_of_file(file_name)
temp_file_path = self.make_temp_copy_if_needed(file_name)
else:
file = processing_utils.decode_base64_to_file(
temp_file = processing_utils.decode_base64_to_file(
file_data, file_path=file_name
)
file_name = file.name
return file_name
temp_file_path = temp_file.name
return temp_file_path
def generate_sample(self):
return media_data.BASE64_MODEL3D
@ -3836,7 +3859,7 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
if y is None:
return y
data = {
"name": processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir).name,
"name": self.make_temp_copy_if_needed(y),
"data": None,
"is_file": True,
}

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import base64
import hashlib
import json
import mimetypes
import os
@ -8,9 +9,10 @@ import pathlib
import shutil
import subprocess
import tempfile
import urllib.request
import warnings
from io import BytesIO
from typing import Dict
from typing import Dict, Tuple
import numpy as np
import requests
@ -108,15 +110,6 @@ def encode_plot_to_base64(plt):
return "data:image/png;base64," + base64_str
def download_to_file(url, dir=None):
file_suffix = os.path.splitext(url)[1]
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=file_suffix, dir=dir)
with requests.get(url, stream=True) as r:
with open(file_obj.name, "wb") as f:
shutil.copyfileobj(r.raw, f)
return file_obj
def save_array_to_file(image_array, dir=None):
pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
@ -293,20 +286,6 @@ def decode_base64_to_file(
return file_obj
def create_tmp_copy_of_file_or_url(file_path_or_url: str, dir=None):
try:
response = requests.get(file_path_or_url, stream=True)
if file_path_or_url.find("/"):
new_file_path = file_path_or_url.rsplit("/", 1)[1]
else:
new_file_path = "file.txt"
with open(new_file_path, "wb") as out_file:
shutil.copyfileobj(response.raw, out_file)
del response
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
return create_tmp_copy_of_file(file_path_or_url, dir)
def dict_or_str_to_json_file(jsn, dir=None):
if dir is not None:
os.makedirs(dir, exist_ok=True)
@ -325,10 +304,95 @@ def file_to_json(file_path):
return json.load(open(file_path))
class TempFileManager:
"""
A class that should be inherited by any Component that needs to manage temporary files.
It should be instantiated in the __init__ method of the component.
"""
def __init__(self) -> None:
# Set stores all the temporary files created by this component.
self.temp_files = set()
def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
sha1 = hashlib.sha1()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
sha1.update(chunk)
return sha1.hexdigest()
def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
sha1 = hashlib.sha1()
remote = urllib.request.urlopen(url)
max_file_size = 100 * 1024 * 1024 # 100MB
total_read = 0
while True:
data = remote.read(chunk_num_blocks * sha1.block_size)
total_read += chunk_num_blocks * sha1.block_size
if not data or total_read > max_file_size:
break
sha1.update(data)
return sha1.hexdigest()
def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
file_name = os.path.basename(file_path_or_url)
prefix, extension = file_name, None
if "." in file_name:
prefix = file_name[0 : file_name.index(".")]
extension = "." + file_name[file_name.index(".") + 1 :]
else:
extension = ""
prefix = utils.strip_invalid_filename_characters(prefix)
return prefix, extension
def get_temp_file_path(self, file_path: str) -> str:
prefix, extension = self.get_prefix_and_extension(file_path)
file_hash = self.hash_file(file_path)
return prefix + file_hash + extension
def get_temp_url_path(self, url: str) -> str:
prefix, extension = self.get_prefix_and_extension(url)
file_hash = self.hash_url(url)
return prefix + file_hash + extension
def make_temp_copy_if_needed(self, file_path: str) -> str:
"""Returns a temporary file path for a copy of the given file path if it does
not already exist. Otherwise returns the path to the existing temp file."""
f = tempfile.NamedTemporaryFile()
temp_dir, _ = os.path.split(f.name)
temp_file_path = self.get_temp_file_path(file_path)
f.name = os.path.join(temp_dir, temp_file_path)
full_temp_file_path = os.path.abspath(f.name)
if not os.path.exists(full_temp_file_path):
shutil.copy2(file_path, full_temp_file_path)
self.temp_files.add(full_temp_file_path)
return full_temp_file_path
def download_temp_copy_if_needed(self, url: str) -> str:
"""Downloads a file and makes a temporary file path for a copy if does not already
exist. Otherwise returns the path to the existing temp file."""
f = tempfile.NamedTemporaryFile()
temp_dir, _ = os.path.split(f.name)
temp_file_path = self.get_temp_url_path(url)
f.name = os.path.join(temp_dir, temp_file_path)
full_temp_file_path = os.path.abspath(f.name)
if not os.path.exists(full_temp_file_path):
with requests.get(url, stream=True) as r:
with open(full_temp_file_path, "wb") as f:
shutil.copyfileobj(r.raw, f)
self.temp_files.add(full_temp_file_path)
return full_temp_file_path
def create_tmp_copy_of_file(file_path, dir=None):
if dir is not None:
os.makedirs(dir, exist_ok=True)
file_name = os.path.basename(file_path)
prefix, extension = file_name, None
if "." in file_name:

View File

@ -260,17 +260,17 @@ class App(FastAPI):
return FileResponse(
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
)
elif Path(app.cwd).resolve() in Path(path).resolve().parents or any(
Path(temp_dir).resolve() in Path(path).resolve().parents
for temp_dir in app.blocks.temp_dirs
):
if Path(app.cwd).resolve() in Path(
path
).resolve().parents or os.path.abspath(path) in set().union(
*app.blocks.temp_file_sets
): # Need to use os.path.abspath in the second condition to be consistent with usage in TempFileManager
return FileResponse(
Path(path).resolve(), headers={"Accept-Ranges": "bytes"}
)
else:
raise ValueError(
f"File cannot be fetched: {path}, perhaps because "
f"it is not in any of {app.blocks.temp_dirs}"
f"File cannot be fetched: {path}. All files must contained within the Gradio python app working directory, or be a temp file created by the Gradio python app."
)
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict
from gradio import processing_utils
from gradio import processing_utils, utils
class Serializable(ABC):
@ -93,7 +93,7 @@ class ImgSerializable(Serializable):
class FileSerializable(Serializable):
def serialize(
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None
) -> Any:
"""
Convert from human-friendly version of a file (string filepath) to a
@ -116,7 +116,10 @@ class FileSerializable(Serializable):
}
def deserialize(
self, x: Dict, save_dir: str | None = None, encryption_key: bytes | None = None
self,
x: str | Dict | None,
save_dir: str | None = None,
encryption_key: bytes | None = None,
):
"""
Convert from serialized representation of a file (base64) to a human-friendly
@ -129,21 +132,26 @@ class FileSerializable(Serializable):
if x is None:
return None
if isinstance(x, str):
file = processing_utils.decode_base64_to_file(
file_name = processing_utils.decode_base64_to_file(
x, dir=save_dir, encryption_key=encryption_key
)
).name
elif isinstance(x, dict):
if x.get("is_file", False):
file = processing_utils.create_tmp_copy_of_file(x["name"], dir=save_dir)
if utils.validate_url(x["name"]):
file_name = x["name"]
else:
file_name = processing_utils.create_tmp_copy_of_file(
x["name"], dir=save_dir
).name
else:
file = processing_utils.decode_base64_to_file(
file_name = processing_utils.decode_base64_to_file(
x["data"], dir=save_dir, encryption_key=encryption_key
)
).name
else:
raise ValueError(
f"A FileSerializable component cannot only deserialize a string or a dict, not a: {type(x)}"
)
return file.name
return file_name
class JSONSerializable(Serializable):

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

View File

@ -1 +1 @@
3.13.2
3.13.2

View File

@ -3,6 +3,7 @@ import copy
import io
import json
import os
import pathlib
import random
import sys
import time
@ -1170,3 +1171,19 @@ async def test_queue_when_using_auth():
*[run_ws(loop, tm + sleep_time * (i + 1) - 0.3, i) for i in range(3)]
)
await group
def test_temp_file_sets_get_extended():
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
with gr.Blocks() as demo1:
gr.Video(str(test_file_dir / "video_sample.mp4"))
with gr.Blocks() as demo2:
gr.Audio(str(test_file_dir / "audio_sample.wav"))
with gr.Blocks() as demo3:
demo1.render()
demo2.render()
assert demo3.temp_file_sets == demo1.temp_file_sets + demo2.temp_file_sets

View File

@ -12,7 +12,7 @@ import shutil
import tempfile
from copy import deepcopy
from difflib import SequenceMatcher
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import matplotlib
import matplotlib.pyplot as plt
@ -591,9 +591,6 @@ class TestImage:
image_input = gr.Image(invert_colors=True)
assert image_input.preprocess(img) is not None
image_input.preprocess(img)
with pytest.warns(Warning):
file_image = gr.Image(type="file")
file_image.preprocess(deepcopy(media_data.BASE64_IMAGE))
file_image = gr.Image(type="filepath")
assert isinstance(file_image.preprocess(img), str)
with pytest.raises(ValueError):
@ -627,7 +624,7 @@ class TestImage:
image_input = gr.Image()
iface = gr.Interface(
lambda x: PIL.Image.open(x).rotate(90, expand=True),
gr.Image(shape=(30, 10), type="file"),
gr.Image(shape=(30, 10), type="filepath"),
"image",
)
output = iface(img)
@ -751,7 +748,7 @@ class TestAudio:
y_audio = gr.processing_utils.decode_base64_to_file(
deepcopy(media_data.BASE64_AUDIO)["data"]
)
audio_output = gr.Audio(type="file")
audio_output = gr.Audio(type="filepath")
assert filecmp.cmp(y_audio.name, audio_output.postprocess(y_audio.name)["name"])
assert audio_output.get_config() == {
"name": "audio",
@ -774,6 +771,10 @@ class TestAudio:
}
).endswith(".wav")
output1 = audio_output.postprocess(y_audio.name)
output2 = audio_output.postprocess(y_audio.name)
assert output1 == output2
def test_tokenize(self):
"""
Tokenize, get_masked_inputs
@ -837,6 +838,11 @@ class TestFile:
assert serialized["orig_name"] == "sample_file.pdf"
assert output.orig_name == "test/test_files/sample_file.pdf"
x_file["is_file"] = True
input1 = file_input.preprocess(x_file)
input2 = file_input.preprocess(x_file)
assert input1.name == input2.name
assert isinstance(file_input.generate_sample(), dict)
file_input = gr.File(label="Upload Your File")
assert file_input.get_config() == {
@ -860,6 +866,10 @@ class TestFile:
output = file_input.preprocess(x_file)
assert type(output) == bytes
output1 = file_input.postprocess("test/test_files/sample_file.pdf")
output2 = file_input.postprocess("test/test_files/sample_file.pdf")
assert output1 == output2
def test_in_interface_as_input(self):
"""
Interface, process
@ -886,6 +896,22 @@ class TestFile:
assert iface("hello world").endswith(".txt")
class TestUploadButton:
def test_component_functions(self):
"""
preprocess
"""
x_file = deepcopy(media_data.BASE64_FILE)
upload_input = gr.UploadButton()
input = upload_input.preprocess(x_file)
assert isinstance(input, tempfile._TemporaryFileWrapper)
x_file["is_file"] = True
input1 = upload_input.preprocess(x_file)
input2 = upload_input.preprocess(x_file)
assert input1.name == input2.name
class TestDataframe:
def test_component_functions(self):
"""
@ -1120,8 +1146,10 @@ class TestVideo:
"""
x_video = deepcopy(media_data.BASE64_VIDEO)
video_input = gr.Video()
output = video_input.preprocess(x_video)
assert isinstance(output, str)
output1 = video_input.preprocess(x_video)
assert isinstance(output1, str)
output2 = video_input.preprocess(x_video)
assert output1 == output2
assert isinstance(video_input.generate_sample(), dict)
video_input = gr.Video(label="Upload Your Video")
@ -1153,7 +1181,11 @@ class TestVideo:
# Output functionalities
y_vid_path = "test/test_files/video_sample.mp4"
video_output = gr.Video()
assert video_output.postprocess(y_vid_path)["name"].endswith("mp4")
output1 = video_output.postprocess(y_vid_path)["name"]
assert output1.endswith("mp4")
output2 = video_output.postprocess(y_vid_path)["name"]
assert output1 == output2
assert video_output.deserialize(
{
"name": None,
@ -1182,7 +1214,7 @@ class TestVideo:
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
# This file has a playable container but not playable codec
with tempfile.NamedTemporaryFile(
suffix="bad_video.mp4"
suffix="bad_video.mp4", delete=False
) as tmp_not_playable_vid:
bad_vid = str(test_file_dir / "bad_video_sample.mp4")
assert not processing_utils.video_is_playable(bad_vid)
@ -1196,7 +1228,7 @@ class TestVideo:
# This file has a playable codec but not a playable container
with tempfile.NamedTemporaryFile(
suffix="playable_but_bad_container.mkv"
suffix="playable_but_bad_container.mkv", delete=False
) as tmp_not_playable_vid:
bad_vid = str(test_file_dir / "playable_but_bad_container.mkv")
assert not processing_utils.video_is_playable(bad_vid)
@ -1207,8 +1239,10 @@ class TestVideo:
)
assert processing_utils.video_is_playable(str(full_path_to_output))
@patch("os.path.exists", MagicMock(return_value=False))
@patch("gradio.components.FFmpeg")
def test_video_preprocessing_flips_video_for_webcam(self, mock_ffmpeg):
# Ensures that the cached temp video file is not used so that ffmpeg is called for each test
x_video = deepcopy(media_data.BASE64_VIDEO)
video_input = gr.Video(source="webcam")
_ = video_input.preprocess(x_video)
@ -1359,8 +1393,8 @@ class TestLabel:
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
path = str(pathlib.Path(test_file_dir, "test_label_json.json"))
label = label_output.postprocess(path)
assert label["label"] == "web site"
label_dict = label_output.postprocess(path)
assert label_dict["label"] == "web site"
assert label_output.get_config() == {
"name": "label",
@ -1692,6 +1726,11 @@ class TestModel3D:
"style": {},
} == component.get_config()
file = "test/test_files/Box.gltf"
output1 = component.postprocess(file)
output2 = component.postprocess(file)
assert output1 == output2
def test_in_interface(self):
"""
Interface, process
@ -1787,7 +1826,8 @@ class TestGallery:
path = gallery.deserialize(data, tmpdir)
assert path.endswith("my-uuid")
data_restored = gallery.serialize(path)
assert sorted(data) == sorted([d["data"] for d in data_restored])
data_restored = [d[0]["data"] for d in data_restored]
assert sorted(data) == sorted(data_restored)
class TestState:

View File

@ -82,7 +82,7 @@ class TestExamples:
)
prediction = await examples.load_from_cache(0)
assert prediction[0][0]["data"] == gr.media_data.BASE64_IMAGE
assert prediction[0][0][0]["data"] == gr.media_data.BASE64_IMAGE
@patch("gradio.examples.CACHED_FOLDER", tempfile.mkdtemp())

View File

@ -3,7 +3,7 @@ import pathlib
import shutil
import tempfile
from copy import deepcopy
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import ffmpy
import matplotlib.pyplot as plt
@ -11,40 +11,40 @@ import numpy as np
import pytest
from PIL import Image
import gradio as gr
from gradio import media_data
from gradio import media_data, processing_utils
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestImagePreprocessing:
def test_decode_base64_to_image(self):
output_image = gr.processing_utils.decode_base64_to_image(
output_image = processing_utils.decode_base64_to_image(
deepcopy(media_data.BASE64_IMAGE)
)
assert isinstance(output_image, Image.Image)
def test_encode_url_or_file_to_base64(self):
output_base64 = gr.processing_utils.encode_url_or_file_to_base64(
output_base64 = processing_utils.encode_url_or_file_to_base64(
"gradio/test_data/test_image.png"
)
assert output_base64 == deepcopy(media_data.BASE64_IMAGE)
def test_encode_file_to_base64(self):
output_base64 = gr.processing_utils.encode_file_to_base64(
output_base64 = processing_utils.encode_file_to_base64(
"gradio/test_data/test_image.png"
)
assert output_base64 == deepcopy(media_data.BASE64_IMAGE)
@pytest.mark.flaky
def test_encode_url_to_base64(self):
output_base64 = gr.processing_utils.encode_url_to_base64(
output_base64 = processing_utils.encode_url_to_base64(
"https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
)
assert output_base64 == deepcopy(media_data.BASE64_IMAGE)
def test_encode_plot_to_base64(self):
plt.plot([1, 2, 3, 4])
output_base64 = gr.processing_utils.encode_plot_to_base64(plt)
output_base64 = processing_utils.encode_plot_to_base64(plt)
assert output_base64.startswith(
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAo"
)
@ -53,14 +53,14 @@ class TestImagePreprocessing:
img = Image.open("gradio/test_data/test_image.png")
img = img.convert("RGB")
numpy_data = np.asarray(img, dtype=np.uint8)
output_base64 = gr.processing_utils.encode_array_to_base64(numpy_data)
output_base64 = processing_utils.encode_array_to_base64(numpy_data)
assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
def test_encode_pil_to_base64(self):
img = Image.open("gradio/test_data/test_image.png")
img = img.convert("RGB")
img.info = {} # Strip metadata
output_base64 = gr.processing_utils.encode_pil_to_base64(img)
output_base64 = processing_utils.encode_pil_to_base64(img)
assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
def test_encode_pil_to_base64_keeps_pnginfo(self):
@ -68,30 +68,30 @@ class TestImagePreprocessing:
input_img = input_img.convert("RGB")
input_img.info = {"key1": "value1", "key2": "value2"}
encoded_image = gr.processing_utils.encode_pil_to_base64(input_img)
decoded_image = gr.processing_utils.decode_base64_to_image(encoded_image)
encoded_image = processing_utils.encode_pil_to_base64(input_img)
decoded_image = processing_utils.decode_base64_to_image(encoded_image)
assert decoded_image.info == input_img.info
def test_resize_and_crop(self):
img = Image.open("gradio/test_data/test_image.png")
new_img = gr.processing_utils.resize_and_crop(img, (20, 20))
new_img = processing_utils.resize_and_crop(img, (20, 20))
assert new_img.size == (20, 20)
with pytest.raises(ValueError):
gr.processing_utils.resize_and_crop(
processing_utils.resize_and_crop(
**{"img": img, "size": (20, 20), "crop_type": "test"}
)
class TestAudioPreprocessing:
def test_audio_from_file(self):
audio = gr.processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
audio = processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
assert audio[0] == 22050
assert isinstance(audio[1], np.ndarray)
def test_audio_to_file(self):
audio = gr.processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
gr.processing_utils.audio_to_file(audio[0], audio[1], "test_audio_to_file")
audio = processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
processing_utils.audio_to_file(audio[0], audio[1], "test_audio_to_file")
assert os.path.exists("test_audio_to_file")
os.remove("test_audio_to_file")
@ -102,65 +102,117 @@ class TestAudioPreprocessing:
audio[1] = 32766
audio_ = audio.astype("float64")
audio_ = gr.processing_utils.convert_to_16_bit_wav(audio_)
audio_ = processing_utils.convert_to_16_bit_wav(audio_)
assert np.allclose(audio, audio_)
assert audio_.dtype == "int16"
audio_ = audio.astype("float32")
audio_ = gr.processing_utils.convert_to_16_bit_wav(audio_)
audio_ = processing_utils.convert_to_16_bit_wav(audio_)
assert np.allclose(audio, audio_)
assert audio_.dtype == "int16"
audio_ = gr.processing_utils.convert_to_16_bit_wav(audio)
audio_ = processing_utils.convert_to_16_bit_wav(audio)
assert np.allclose(audio, audio_)
assert audio_.dtype == "int16"
class TestTempFileManager:
def test_get_temp_file_path(self):
temp_file_manager = processing_utils.TempFileManager()
temp_file_manager.hash_file = MagicMock(return_value="")
filepath = "C:/gradio/test_image.png"
temp_filepath = temp_file_manager.get_temp_file_path(filepath)
assert "test_image" in temp_filepath
assert temp_filepath.endswith(".png")
filepath = "ABCabc123.csv"
temp_filepath = temp_file_manager.get_temp_file_path(filepath)
assert "ABCabc123" in temp_filepath
assert temp_filepath.endswith(".csv")
filepath = "lion#1.jpeg"
temp_filepath = temp_file_manager.get_temp_file_path(filepath)
assert "lion1" in temp_filepath
assert temp_filepath.endswith(".jpeg")
filepath = "%%lio|n#1.jpeg"
temp_filepath = temp_file_manager.get_temp_file_path(filepath)
assert "lion1" in temp_filepath
assert temp_filepath.endswith(".jpeg")
filepath = "/home/lion--_1.txt"
temp_filepath = temp_file_manager.get_temp_file_path(filepath)
assert "lion--_1" in temp_filepath
assert temp_filepath.endswith(".txt")
def test_hash_file(self):
temp_file_manager = processing_utils.TempFileManager()
h1 = temp_file_manager.hash_file("gradio/test_data/cheetah1.jpg")
h2 = temp_file_manager.hash_file("gradio/test_data/cheetah1-copy.jpg")
h3 = temp_file_manager.hash_file("gradio/test_data/cheetah2.jpg")
assert h1 == h2
assert h1 != h3
@patch("shutil.copy2")
def test_make_temp_copy_if_needed(self, mock_copy):
temp_file_manager = processing_utils.TempFileManager()
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
try: # Delete if already exists from before this test
os.remove(f)
except OSError:
pass
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
assert mock_copy.called
assert len(temp_file_manager.temp_files) == 1
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
assert len(temp_file_manager.temp_files) == 1
f = temp_file_manager.make_temp_copy_if_needed(
"gradio/test_data/cheetah1-copy.jpg"
)
assert len(temp_file_manager.temp_files) == 2
@pytest.mark.flaky
@patch("shutil.copyfileobj")
def test_download_temp_copy_if_needed(self, mock_copy):
temp_file_manager = processing_utils.TempFileManager()
url1 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
url2 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg"
f = temp_file_manager.download_temp_copy_if_needed(url1)
try: # Delete if already exists from before this test
os.remove(f)
except OSError:
pass
f = temp_file_manager.download_temp_copy_if_needed(url1)
assert mock_copy.called
assert len(temp_file_manager.temp_files) == 1
f = temp_file_manager.download_temp_copy_if_needed(url1)
assert len(temp_file_manager.temp_files) == 1
f = temp_file_manager.download_temp_copy_if_needed(url2)
assert len(temp_file_manager.temp_files) == 2
class TestOutputPreprocessing:
def test_decode_base64_to_binary(self):
binary = gr.processing_utils.decode_base64_to_binary(
binary = processing_utils.decode_base64_to_binary(
deepcopy(media_data.BASE64_IMAGE)
)
assert deepcopy(media_data.BINARY_IMAGE) == binary
def test_decode_base64_to_file(self):
temp_file = gr.processing_utils.decode_base64_to_file(
temp_file = processing_utils.decode_base64_to_file(
deepcopy(media_data.BASE64_IMAGE)
)
assert isinstance(temp_file, tempfile._TemporaryFileWrapper)
def test_create_tmp_copy_of_file(self):
f = tempfile.NamedTemporaryFile(delete=False)
temp_file = gr.processing_utils.create_tmp_copy_of_file(f.name)
assert isinstance(temp_file, tempfile._TemporaryFileWrapper)
@patch("shutil.copy2")
def test_create_tmp_filenames(self, mock_copy2):
filepath = "C:/gradio/test_image.png"
file_obj = gr.processing_utils.create_tmp_copy_of_file(filepath)
assert "test_image" in file_obj.name
assert file_obj.name.endswith(".png")
filepath = "ABCabc123.csv"
file_obj = gr.processing_utils.create_tmp_copy_of_file(filepath)
assert "ABCabc123" in file_obj.name
assert file_obj.name.endswith(".csv")
filepath = "lion#1.jpeg"
file_obj = gr.processing_utils.create_tmp_copy_of_file(filepath)
assert "lion1" in file_obj.name
assert file_obj.name.endswith(".jpeg")
filepath = "%%lio|n#1.jpeg"
file_obj = gr.processing_utils.create_tmp_copy_of_file(filepath)
assert "lion1" in file_obj.name
assert file_obj.name.endswith(".jpeg")
filepath = "/home/lion--_1.txt"
file_obj = gr.processing_utils.create_tmp_copy_of_file(filepath)
assert "lion--_1" in file_obj.name
assert file_obj.name.endswith(".txt")
float_dtype_list = [
float,
float,
@ -186,7 +238,7 @@ class TestOutputPreprocessing:
for dtype_in, dtype_out in dtype_combin:
x = x.astype(dtype_in)
y = gr.processing_utils._convert(x, dtype_out)
y = processing_utils._convert(x, dtype_out)
assert y.dtype == np.dtype(dtype_out)
def test_subclass_conversion(self):
@ -194,22 +246,22 @@ class TestOutputPreprocessing:
x = np.array([-1, 1])
for dtype in TestOutputPreprocessing.float_dtype_list:
x = x.astype(dtype)
y = gr.processing_utils._convert(x, np.floating)
y = processing_utils._convert(x, np.floating)
assert y.dtype == x.dtype
class TestVideoProcessing:
def test_video_has_playable_codecs(self, test_file_dir):
assert gr.processing_utils.video_is_playable(
assert processing_utils.video_is_playable(
str(test_file_dir / "video_sample.mp4")
)
assert gr.processing_utils.video_is_playable(
assert processing_utils.video_is_playable(
str(test_file_dir / "video_sample.ogg")
)
assert gr.processing_utils.video_is_playable(
assert processing_utils.video_is_playable(
str(test_file_dir / "video_sample.webm")
)
assert not gr.processing_utils.video_is_playable(
assert not processing_utils.video_is_playable(
str(test_file_dir / "bad_video_sample.mp4")
)
@ -223,32 +275,38 @@ class TestVideoProcessing:
self, exception_to_raise, test_file_dir
):
with patch("ffmpy.FFprobe.run", side_effect=exception_to_raise):
with tempfile.NamedTemporaryFile(suffix="out.avi") as tmp_not_playable_vid:
with tempfile.NamedTemporaryFile(
suffix="out.avi", delete=False
) as tmp_not_playable_vid:
shutil.copy(
str(test_file_dir / "bad_video_sample.mp4"),
tmp_not_playable_vid.name,
)
assert gr.processing_utils.video_is_playable(tmp_not_playable_vid.name)
assert processing_utils.video_is_playable(tmp_not_playable_vid.name)
def test_convert_video_to_playable_mp4(self, test_file_dir):
with tempfile.NamedTemporaryFile(suffix="out.avi") as tmp_not_playable_vid:
with tempfile.NamedTemporaryFile(
suffix="out.avi", delete=False
) as tmp_not_playable_vid:
shutil.copy(
str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name
)
playable_vid = gr.processing_utils.convert_video_to_playable_mp4(
playable_vid = processing_utils.convert_video_to_playable_mp4(
tmp_not_playable_vid.name
)
assert gr.processing_utils.video_is_playable(playable_vid)
assert processing_utils.video_is_playable(playable_vid)
@patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception)
def test_video_conversion_returns_original_video_if_fails(
self, mock_run, test_file_dir
):
with tempfile.NamedTemporaryFile(suffix="out.avi") as tmp_not_playable_vid:
with tempfile.NamedTemporaryFile(
suffix="out.avi", delete=False
) as tmp_not_playable_vid:
shutil.copy(
str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name
)
playable_vid = gr.processing_utils.convert_video_to_playable_mp4(
playable_vid = processing_utils.convert_video_to_playable_mp4(
tmp_not_playable_vid.name
)
# If the conversion succeeded it'd be .mp4