mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Create fewer temp files and make them consistently-named (#2758)
* tmp files * components * changes * temp_file_sets * TempFileManager class * added file manager * internal functions * tests * formatting * changes * video tests * added tests for File * cheetah image * formatting * tests for upload button * temp files * formatting * changelog * fixed audio * tmp files * tmp files * gallery * deprecated type=file * fixing tests * patch os.path.exists * fixed test_video_postprocess_converts_to_playable_format * fixed tests * changelog * fix tests * formatting * added a download_if_needed * formatting * fixed download * fixed gallery demo * fix tests * version * fix for mac * consolidate
This commit is contained in:
parent
714ab2cc09
commit
20057aa946
@ -8,7 +8,9 @@ Adds a `gr.make_waveform()` function that creates a waveform video by combining
|
||||

|
||||
|
||||
## Bug Fixes:
|
||||
No changes to highlight.
|
||||
* Fixed issue where too many temporary files were created, all with randomly generated
|
||||
filepaths. Now fewer temporary files are created and are assigned a path that is a
|
||||
hash based on the file contents by [@abidlabs](https://github.com/abidlabs) in [PR 2758](https://github.com/gradio-app/gradio/pull/2758)
|
||||
|
||||
## Documentation Changes:
|
||||
No changes to highlight.
|
||||
@ -83,6 +85,11 @@ These links are a more secure and scalable way to create shareable demos!
|
||||
No changes to highlight.
|
||||
|
||||
## Full Changelog:
|
||||
* Fixed typo in parameter `visible` in classes in `templates.py` by [@abidlabs](https://github.com/abidlabs) in [PR 2805](https://github.com/gradio-app/gradio/pull/2805)
|
||||
* Switched external service for getting IP address from `https://api.ipify.org` to `https://checkip.amazonaws.com/` by [@abidlabs](https://github.com/abidlabs) in [PR 2810](https://github.com/gradio-app/gradio/pull/2810)
|
||||
|
||||
## Contributors Shoutout:
|
||||
No changes to highlight.
|
||||
|
||||
* Fixed typo in parameter `visible` in classes in `templates.py` by [@abidlabs](https://github.com/abidlabs) in [PR 2805](https://github.com/gradio-app/gradio/pull/2805)
|
||||
* Switched external service for getting IP address from `https://api.ipify.org` to `https://checkip.amazonaws.com/` by [@abidlabs](https://github.com/abidlabs) in [PR 2810](https://github.com/gradio-app/gradio/pull/2810)
|
||||
|
@ -98,8 +98,8 @@ class Block:
|
||||
Context.block.add(self)
|
||||
if Context.root_block is not None:
|
||||
Context.root_block.blocks[self._id] = self
|
||||
if hasattr(self, "temp_dir"):
|
||||
Context.root_block.temp_dirs.add(self.temp_dir)
|
||||
if isinstance(self, components.TempFileManager):
|
||||
Context.root_block.temp_file_sets.append(self.temp_files)
|
||||
return self
|
||||
|
||||
def unrender(self):
|
||||
@ -557,7 +557,7 @@ class Blocks(BlockContext):
|
||||
self.auth = None
|
||||
self.dev_mode = True
|
||||
self.app_id = random.getrandbits(64)
|
||||
self.temp_dirs = set()
|
||||
self.temp_file_sets = []
|
||||
self.title = title
|
||||
self.show_api = True
|
||||
|
||||
@ -741,7 +741,7 @@ class Blocks(BlockContext):
|
||||
)
|
||||
Context.root_block.fns[dependency_offset + i] = new_fn
|
||||
Context.root_block.dependencies.append(dependency)
|
||||
Context.root_block.temp_dirs = Context.root_block.temp_dirs | self.temp_dirs
|
||||
Context.root_block.temp_file_sets.extend(self.temp_file_sets)
|
||||
|
||||
if Context.block is not None:
|
||||
Context.block.children.extend(self.children)
|
||||
|
@ -46,6 +46,7 @@ from gradio.events import (
|
||||
Uploadable,
|
||||
)
|
||||
from gradio.layouts import Column, Form, Row
|
||||
from gradio.processing_utils import TempFileManager
|
||||
from gradio.serializing import (
|
||||
FileSerializable,
|
||||
ImgSerializable,
|
||||
@ -1263,7 +1264,7 @@ class Image(
|
||||
invert_colors: whether to invert the image as a preprocessing step.
|
||||
source: Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
|
||||
tool: Tools used for editing. "editor" allows a full screen editor (and is the default if source is "upload" or "webcam"), "select" provides a cropping and zoom tool, "sketch" allows you to create a binary sketch (and is the default if source="canvas"), and "color-sketch" allows you to created a sketch in different colors. "color-sketch" can be used with source="upload" or "webcam" to allow sketching on an image. "sketch" can also be used with "upload" or "webcam" to create a mask over an image and in that case both the image and mask are passed into the function as a dictionary with keys "image" and "mask" respectively.
|
||||
type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "file" produces a temporary file object whose path can be retrieved by file_obj.name, "filepath" passes a str path to a temporary file containing the image.
|
||||
type: The format the image is converted to before being passed into the prediction function. "numpy" converts the image to a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" converts the image to a PIL image object, "filepath" passes a str path to a temporary file containing the image.
|
||||
label: component name in interface.
|
||||
show_label: if True, will display label.
|
||||
interactive: if True, will allow users to upload and edit an image; if False, can only be used to display images. If not provided, this is inferred based on whether the component is used as an input or output.
|
||||
@ -1273,7 +1274,7 @@ class Image(
|
||||
mirror_webcam: If True webcam will be mirrored. Default is True.
|
||||
"""
|
||||
self.mirror_webcam = mirror_webcam
|
||||
valid_types = ["numpy", "pil", "file", "filepath"]
|
||||
valid_types = ["numpy", "pil", "filepath"]
|
||||
if type not in valid_types:
|
||||
raise ValueError(
|
||||
f"Invalid value for parameter `type`: {type}. Please choose from one of: {valid_types}"
|
||||
@ -1350,19 +1351,13 @@ class Image(
|
||||
return im
|
||||
elif self.type == "numpy":
|
||||
return np.array(im)
|
||||
elif self.type == "file" or self.type == "filepath":
|
||||
elif self.type == "filepath":
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False,
|
||||
suffix=("." + fmt.lower() if fmt is not None else ".png"),
|
||||
)
|
||||
im.save(file_obj.name)
|
||||
if self.type == "file":
|
||||
warnings.warn(
|
||||
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.",
|
||||
)
|
||||
return file_obj
|
||||
else:
|
||||
return file_obj.name
|
||||
return file_obj.name
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type: "
|
||||
@ -1577,7 +1572,15 @@ class Image(
|
||||
|
||||
|
||||
@document("change", "clear", "play", "pause", "stop", "style")
|
||||
class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerializable):
|
||||
class Video(
|
||||
Changeable,
|
||||
Clearable,
|
||||
Playable,
|
||||
Uploadable,
|
||||
IOComponent,
|
||||
FileSerializable,
|
||||
TempFileManager,
|
||||
):
|
||||
"""
|
||||
Creates a video component that can be used to upload/record videos (as an input) or display videos (as an output).
|
||||
For the video to be playable in the browser it must have a compatible container and codec combination. Allowed
|
||||
@ -1616,7 +1619,6 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
mirror_webcam: If True webcma will be mirrored. Default is True.
|
||||
"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.format = format
|
||||
valid_sources = ["upload", "webcam"]
|
||||
if source not in valid_sources:
|
||||
@ -1625,6 +1627,7 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
|
||||
)
|
||||
self.source = source
|
||||
self.mirror_webcam = mirror_webcam
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
@ -1680,15 +1683,15 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
|
||||
x.get("is_file", False),
|
||||
)
|
||||
if is_file:
|
||||
file = processing_utils.create_tmp_copy_of_file(file_name)
|
||||
file = self.make_temp_copy_if_needed(file_name)
|
||||
file_name = Path(file)
|
||||
else:
|
||||
file = processing_utils.decode_base64_to_file(
|
||||
file_data, file_path=file_name
|
||||
)
|
||||
file_name = Path(file.name)
|
||||
|
||||
file_name = Path(file.name)
|
||||
uploaded_format = file_name.suffix.replace(".", "")
|
||||
|
||||
modify_format = self.format is not None and uploaded_format != self.format
|
||||
flip = self.source == "webcam" and self.mirror_webcam
|
||||
if modify_format or flip:
|
||||
@ -1698,6 +1701,8 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
|
||||
output_file_name = str(
|
||||
file_name.with_name(f"{file_name.stem}{flip_suffix}{format}")
|
||||
)
|
||||
if os.path.exists(output_file_name):
|
||||
return output_file_name
|
||||
ff = FFmpeg(
|
||||
inputs={str(file_name): None},
|
||||
outputs={output_file_name: output_options},
|
||||
@ -1724,10 +1729,20 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
|
||||
if y is None:
|
||||
return None
|
||||
|
||||
if utils.validate_url(y):
|
||||
y = processing_utils.download_to_file(y, dir=self.temp_dir).name
|
||||
|
||||
returned_format = y.split(".")[-1].lower()
|
||||
|
||||
if self.format is None or returned_format == self.format:
|
||||
conversion_needed = False
|
||||
else:
|
||||
conversion_needed = True
|
||||
|
||||
# For cases where the video is a URL and does not need to be converted to another format, we can just return the URL
|
||||
if utils.validate_url(y) and not (conversion_needed):
|
||||
return {"name": y, "data": None, "is_file": True}
|
||||
|
||||
# For cases where the video needs to be converted to another format
|
||||
if utils.validate_url(y):
|
||||
y = self.download_temp_copy_if_needed(y)
|
||||
if (
|
||||
processing_utils.ffmpeg_installed()
|
||||
and not processing_utils.video_is_playable(y)
|
||||
@ -1742,8 +1757,8 @@ class Video(Changeable, Clearable, Playable, Uploadable, IOComponent, FileSerial
|
||||
ff.run()
|
||||
y = output_file_name
|
||||
|
||||
y = processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir)
|
||||
return {"name": y.name, "data": None, "is_file": True}
|
||||
y = self.make_temp_copy_if_needed(y)
|
||||
return {"name": y, "data": None, "is_file": True}
|
||||
|
||||
def style(
|
||||
self, *, height: Optional[int] = None, width: Optional[int] = None, **kwargs
|
||||
@ -1771,6 +1786,7 @@ class Audio(
|
||||
Uploadable,
|
||||
IOComponent,
|
||||
FileSerializable,
|
||||
TempFileManager,
|
||||
):
|
||||
"""
|
||||
Creates an audio component that can be used to upload/record audio (as an input) or display audio (as an output).
|
||||
@ -1807,14 +1823,13 @@ class Audio(
|
||||
streaming: If set to True when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'microphone'.
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
valid_sources = ["upload", "microphone"]
|
||||
if source not in valid_sources:
|
||||
raise ValueError(
|
||||
f"Invalid value for parameter `source`: {source}. Please choose from one of: {valid_sources}"
|
||||
)
|
||||
self.source = source
|
||||
valid_types = ["numpy", "filepath", "file"]
|
||||
valid_types = ["numpy", "filepath"]
|
||||
if type not in valid_types:
|
||||
raise ValueError(
|
||||
f"Invalid value for parameter `type`: {type}. Please choose from one of: {valid_types}"
|
||||
@ -1827,6 +1842,7 @@ class Audio(
|
||||
raise ValueError(
|
||||
"Audio streaming only available if source is 'microphone'."
|
||||
)
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
@ -1869,7 +1885,7 @@ class Audio(
|
||||
def preprocess(self, x: Dict[str, str] | None) -> Tuple[int, np.array] | str | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: JSON object with filename as 'name' property and base64 data as 'data' property
|
||||
x: dictionary with keys "name", "data", "is_file", "crop_min", "crop_max".
|
||||
Returns:
|
||||
audio in requested format
|
||||
"""
|
||||
@ -1882,25 +1898,25 @@ class Audio(
|
||||
)
|
||||
crop_min, crop_max = x.get("crop_min", 0), x.get("crop_max", 100)
|
||||
if is_file:
|
||||
file_obj = processing_utils.create_tmp_copy_of_file(file_name)
|
||||
if utils.validate_url(file_name):
|
||||
temp_file_path = self.download_temp_copy_if_needed(file_name)
|
||||
else:
|
||||
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
||||
else:
|
||||
file_obj = processing_utils.decode_base64_to_file(
|
||||
temp_file_obj = processing_utils.decode_base64_to_file(
|
||||
file_data, file_path=file_name
|
||||
)
|
||||
temp_file_path = temp_file_obj.name
|
||||
|
||||
sample_rate, data = processing_utils.audio_from_file(
|
||||
file_obj.name, crop_min=crop_min, crop_max=crop_max
|
||||
temp_file_path, crop_min=crop_min, crop_max=crop_max
|
||||
)
|
||||
|
||||
if self.type == "numpy":
|
||||
return sample_rate, data
|
||||
elif self.type in ["file", "filepath"]:
|
||||
processing_utils.audio_to_file(sample_rate, data, file_obj.name)
|
||||
if self.type == "file":
|
||||
warnings.warn(
|
||||
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.",
|
||||
)
|
||||
return file_obj
|
||||
else:
|
||||
return file_obj.name
|
||||
elif self.type == "filepath":
|
||||
processing_utils.audio_to_file(sample_rate, data, temp_file_path)
|
||||
return temp_file_path
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type: "
|
||||
@ -2009,17 +2025,18 @@ class Audio(
|
||||
return None
|
||||
|
||||
if utils.validate_url(y):
|
||||
file = processing_utils.download_to_file(y, dir=self.temp_dir)
|
||||
elif isinstance(y, tuple):
|
||||
sample_rate, data = y
|
||||
file = tempfile.NamedTemporaryFile(
|
||||
suffix=".wav", dir=self.temp_dir, delete=False
|
||||
)
|
||||
processing_utils.audio_to_file(sample_rate, data, file.name)
|
||||
else:
|
||||
file = processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir)
|
||||
return {"name": y, "data": None, "is_file": True}
|
||||
|
||||
return {"name": file.name, "data": None, "is_file": True}
|
||||
if isinstance(y, tuple):
|
||||
sample_rate, data = y
|
||||
file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
||||
processing_utils.audio_to_file(sample_rate, data, file.name)
|
||||
file_path = file.name
|
||||
self.temp_files.add(file_path)
|
||||
else:
|
||||
file_path = self.make_temp_copy_if_needed(y)
|
||||
|
||||
return {"name": file_path, "data": None, "is_file": True}
|
||||
|
||||
def stream(
|
||||
self,
|
||||
@ -2072,7 +2089,9 @@ class Audio(
|
||||
|
||||
|
||||
@document("change", "clear", "style")
|
||||
class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
|
||||
class File(
|
||||
Changeable, Clearable, Uploadable, IOComponent, FileSerializable, TempFileManager
|
||||
):
|
||||
"""
|
||||
Creates a file component that allows uploading generic file (when used as an input) and or displaying generic files (output).
|
||||
Preprocessing: passes the uploaded file as a {file-object} or {List[file-object]} depending on `file_count` (or a {bytes}/{List{bytes}} depending on `type`)
|
||||
@ -2107,7 +2126,6 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
|
||||
visible: If False, component will be hidden.
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.file_count = file_count
|
||||
self.file_types = file_types
|
||||
valid_types = [
|
||||
@ -2125,6 +2143,7 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
|
||||
)
|
||||
self.type = type
|
||||
self.test_input = None
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
@ -2162,7 +2181,11 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
|
||||
}
|
||||
return IOComponent.add_interactive_to_config(updated_config, interactive)
|
||||
|
||||
def preprocess(self, x: List[Dict[str, str]] | None) -> str | List[str]:
|
||||
def preprocess(
|
||||
self, x: List[Dict[str, str]] | None
|
||||
) -> tempfile._TemporaryFileWrapper | List[
|
||||
tempfile._TemporaryFileWrapper
|
||||
] | bytes | List[bytes]:
|
||||
"""
|
||||
Parameters:
|
||||
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
|
||||
@ -2180,7 +2203,9 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
|
||||
)
|
||||
if self.type == "file":
|
||||
if is_file:
|
||||
file = processing_utils.create_tmp_copy_of_file(file_name)
|
||||
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
file.name = temp_file_path
|
||||
file.orig_name = file_name
|
||||
else:
|
||||
file = processing_utils.decode_base64_to_file(
|
||||
@ -2216,7 +2241,9 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
|
||||
def generate_sample(self):
|
||||
return deepcopy(media_data.BASE64_FILE)
|
||||
|
||||
def postprocess(self, y: str) -> Dict:
|
||||
def postprocess(
|
||||
self, y: str | List[str]
|
||||
) -> Dict[str | Any] | List[Dict[str | Any]]:
|
||||
"""
|
||||
Parameters:
|
||||
y: file path
|
||||
@ -2229,9 +2256,7 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
|
||||
return [
|
||||
{
|
||||
"orig_name": os.path.basename(file),
|
||||
"name": processing_utils.create_tmp_copy_of_file(
|
||||
file, dir=self.temp_dir
|
||||
).name,
|
||||
"name": self.make_temp_copy_if_needed(file),
|
||||
"size": os.path.getsize(file),
|
||||
"data": None,
|
||||
"is_file": True,
|
||||
@ -2241,9 +2266,7 @@ class File(Changeable, Clearable, Uploadable, IOComponent, FileSerializable):
|
||||
else:
|
||||
return {
|
||||
"orig_name": os.path.basename(y),
|
||||
"name": processing_utils.create_tmp_copy_of_file(
|
||||
y, dir=self.temp_dir
|
||||
).name,
|
||||
"name": self.make_temp_copy_if_needed(y),
|
||||
"size": os.path.getsize(y),
|
||||
"data": None,
|
||||
"is_file": True,
|
||||
@ -2780,7 +2803,9 @@ class Button(Clickable, IOComponent, SimpleSerializable):
|
||||
|
||||
|
||||
@document("click", "upload", "style")
|
||||
class UploadButton(Clickable, Uploadable, IOComponent, SimpleSerializable):
|
||||
class UploadButton(
|
||||
Clickable, Uploadable, IOComponent, SimpleSerializable, TempFileManager
|
||||
):
|
||||
"""
|
||||
Used to create an upload button, when cicked allows a user to upload files that satisfy the specified file type or generic files (if file_type not set).
|
||||
Preprocessing: passes the uploaded file as a {file-object} or {List[file-object]} depending on `file_count` (or a {bytes}/{List{bytes}} depending on `type`)
|
||||
@ -2811,11 +2836,11 @@ class UploadButton(Clickable, Uploadable, IOComponent, SimpleSerializable):
|
||||
visible: If False, component will be hidden.
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.type = type
|
||||
self.file_count = file_count
|
||||
self.file_types = file_types
|
||||
self.label = label
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self, label=label, visible=visible, elem_id=elem_id, value=value, **kwargs
|
||||
)
|
||||
@ -2843,7 +2868,11 @@ class UploadButton(Clickable, Uploadable, IOComponent, SimpleSerializable):
|
||||
}
|
||||
return IOComponent.add_interactive_to_config(updated_config, interactive)
|
||||
|
||||
def preprocess(self, x: List[Dict[str, str]] | None) -> str | List[str]:
|
||||
def preprocess(
|
||||
self, x: List[Dict[str, str]] | None
|
||||
) -> tempfile._TemporaryFileWrapper | List[
|
||||
tempfile._TemporaryFileWrapper
|
||||
] | bytes | List[bytes]:
|
||||
"""
|
||||
Parameters:
|
||||
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
|
||||
@ -2861,13 +2890,13 @@ class UploadButton(Clickable, Uploadable, IOComponent, SimpleSerializable):
|
||||
)
|
||||
if self.type == "file":
|
||||
if is_file:
|
||||
file = processing_utils.create_tmp_copy_of_file(
|
||||
file_name, dir=self.temp_dir
|
||||
)
|
||||
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
file.name = temp_file_path
|
||||
file.orig_name = file_name
|
||||
else:
|
||||
file = processing_utils.decode_base64_to_file(
|
||||
data, file_path=file_name, dir=self.temp_dir
|
||||
data, file_path=file_name
|
||||
)
|
||||
file.orig_name = file_name
|
||||
return file
|
||||
@ -3452,7 +3481,7 @@ class HTML(Changeable, IOComponent, SimpleSerializable):
|
||||
|
||||
|
||||
@document("style")
|
||||
class Gallery(IOComponent):
|
||||
class Gallery(IOComponent, TempFileManager):
|
||||
"""
|
||||
Used to display a list of images as a gallery that can be scrolled through.
|
||||
Preprocessing: this component does *not* accept input.
|
||||
@ -3479,7 +3508,7 @@ class Gallery(IOComponent):
|
||||
visible: If False, component will be hidden.
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
TempFileManager.__init__(self)
|
||||
super().__init__(
|
||||
label=label,
|
||||
show_label=show_label,
|
||||
@ -3531,25 +3560,27 @@ class Gallery(IOComponent):
|
||||
if isinstance(img, tuple) or isinstance(img, list):
|
||||
img, caption = img
|
||||
if isinstance(img, np.ndarray):
|
||||
file = processing_utils.save_array_to_file(img, dir=self.temp_dir)
|
||||
file = processing_utils.save_array_to_file(img)
|
||||
file_path = os.path.abspath(file.name)
|
||||
self.temp_files.add(file_path)
|
||||
elif isinstance(img, PIL.Image.Image):
|
||||
file = processing_utils.save_pil_to_file(img, dir=self.temp_dir)
|
||||
file = processing_utils.save_pil_to_file(img)
|
||||
file_path = os.path.abspath(file.name)
|
||||
self.temp_files.add(file_path)
|
||||
elif isinstance(img, str):
|
||||
if utils.validate_url(img):
|
||||
file = processing_utils.download_to_file(img, dir=self.temp_dir)
|
||||
file_path = img
|
||||
else:
|
||||
file = processing_utils.create_tmp_copy_of_file(
|
||||
img, dir=self.temp_dir
|
||||
)
|
||||
file_path = self.make_temp_copy_if_needed(img)
|
||||
else:
|
||||
raise ValueError(f"Cannot process type as image: {type(img)}")
|
||||
|
||||
if caption is not None:
|
||||
output.append(
|
||||
[{"name": file.name, "data": None, "is_file": True}, caption]
|
||||
[{"name": file_path, "data": None, "is_file": True}, caption]
|
||||
)
|
||||
else:
|
||||
output.append({"name": file.name, "data": None, "is_file": True})
|
||||
output.append({"name": file_path, "data": None, "is_file": True})
|
||||
|
||||
return output
|
||||
|
||||
@ -3581,6 +3612,7 @@ class Gallery(IOComponent):
|
||||
if x is None:
|
||||
return None
|
||||
gallery_path = os.path.join(save_dir, str(uuid.uuid4()))
|
||||
os.makedirs(gallery_path)
|
||||
captions = {}
|
||||
for img_data in x:
|
||||
if isinstance(img_data, list) or isinstance(img_data, tuple):
|
||||
@ -3588,9 +3620,7 @@ class Gallery(IOComponent):
|
||||
else:
|
||||
caption = None
|
||||
name = FileSerializable.deserialize(self, img_data, gallery_path)
|
||||
if caption is not None:
|
||||
captions[name] = caption
|
||||
if len(captions):
|
||||
captions[name] = caption
|
||||
captions_file = os.path.join(gallery_path, "captions.json")
|
||||
with open(captions_file, "w") as captions_json:
|
||||
json.dump(captions, captions_json)
|
||||
@ -3599,21 +3629,11 @@ class Gallery(IOComponent):
|
||||
def serialize(self, x: Any, load_dir: str = "", called_directly: bool = False):
|
||||
files = []
|
||||
captions_file = os.path.join(x, "captions.json")
|
||||
for file in os.listdir(x):
|
||||
file_path = os.path.join(x, file)
|
||||
if file_path == captions_file:
|
||||
continue
|
||||
if os.path.exists(captions_file):
|
||||
with open(captions_file) as captions_json:
|
||||
captions = json.load(captions_json)
|
||||
caption = captions.get(file_path)
|
||||
else:
|
||||
caption = None
|
||||
img = FileSerializable.serialize(self, file_path)
|
||||
if caption:
|
||||
files.append([img, caption])
|
||||
else:
|
||||
files.append(img)
|
||||
with open(captions_file) as captions_json:
|
||||
captions = json.load(captions_json)
|
||||
for file_name, caption in captions.items():
|
||||
img = FileSerializable.serialize(self, file_name)
|
||||
files.append([img, caption])
|
||||
return files
|
||||
|
||||
|
||||
@ -3735,7 +3755,9 @@ class Chatbot(Changeable, IOComponent, JSONSerializable):
|
||||
|
||||
|
||||
@document("change", "edit", "clear", "style")
|
||||
class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
|
||||
class Model3D(
|
||||
Changeable, Editable, Clearable, IOComponent, FileSerializable, TempFileManager
|
||||
):
|
||||
"""
|
||||
Component allows users to upload or view 3D Model files (.obj, .glb, or .gltf).
|
||||
Preprocessing: This component passes the uploaded file as a {str} filepath.
|
||||
@ -3765,8 +3787,8 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
|
||||
visible: If False, component will be hidden.
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
"""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.clear_color = clear_color or [0.2, 0.2, 0.2, 1.0]
|
||||
TempFileManager.__init__(self)
|
||||
IOComponent.__init__(
|
||||
self,
|
||||
label=label,
|
||||
@ -3805,7 +3827,7 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
|
||||
Parameters:
|
||||
x: JSON object with filename as 'name' property and base64 data as 'data' property
|
||||
Returns:
|
||||
file path to 3D image model
|
||||
string file path to temporary file with the 3D image model
|
||||
"""
|
||||
if x is None:
|
||||
return x
|
||||
@ -3815,13 +3837,14 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
|
||||
x.get("is_file", False),
|
||||
)
|
||||
if is_file:
|
||||
file = processing_utils.create_tmp_copy_of_file(file_name)
|
||||
temp_file_path = self.make_temp_copy_if_needed(file_name)
|
||||
else:
|
||||
file = processing_utils.decode_base64_to_file(
|
||||
temp_file = processing_utils.decode_base64_to_file(
|
||||
file_data, file_path=file_name
|
||||
)
|
||||
file_name = file.name
|
||||
return file_name
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
return temp_file_path
|
||||
|
||||
def generate_sample(self):
|
||||
return media_data.BASE64_MODEL3D
|
||||
@ -3836,7 +3859,7 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
|
||||
if y is None:
|
||||
return y
|
||||
data = {
|
||||
"name": processing_utils.create_tmp_copy_of_file(y, dir=self.temp_dir).name,
|
||||
"name": self.make_temp_copy_if_needed(y),
|
||||
"data": None,
|
||||
"is_file": True,
|
||||
}
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
@ -8,9 +9,10 @@ import pathlib
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import urllib.request
|
||||
import warnings
|
||||
from io import BytesIO
|
||||
from typing import Dict
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@ -108,15 +110,6 @@ def encode_plot_to_base64(plt):
|
||||
return "data:image/png;base64," + base64_str
|
||||
|
||||
|
||||
def download_to_file(url, dir=None):
|
||||
file_suffix = os.path.splitext(url)[1]
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=file_suffix, dir=dir)
|
||||
with requests.get(url, stream=True) as r:
|
||||
with open(file_obj.name, "wb") as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
return file_obj
|
||||
|
||||
|
||||
def save_array_to_file(image_array, dir=None):
|
||||
pil_image = Image.fromarray(_convert(image_array, np.uint8, force_copy=False))
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
||||
@ -293,20 +286,6 @@ def decode_base64_to_file(
|
||||
return file_obj
|
||||
|
||||
|
||||
def create_tmp_copy_of_file_or_url(file_path_or_url: str, dir=None):
|
||||
try:
|
||||
response = requests.get(file_path_or_url, stream=True)
|
||||
if file_path_or_url.find("/"):
|
||||
new_file_path = file_path_or_url.rsplit("/", 1)[1]
|
||||
else:
|
||||
new_file_path = "file.txt"
|
||||
with open(new_file_path, "wb") as out_file:
|
||||
shutil.copyfileobj(response.raw, out_file)
|
||||
del response
|
||||
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
|
||||
return create_tmp_copy_of_file(file_path_or_url, dir)
|
||||
|
||||
|
||||
def dict_or_str_to_json_file(jsn, dir=None):
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
@ -325,10 +304,95 @@ def file_to_json(file_path):
|
||||
return json.load(open(file_path))
|
||||
|
||||
|
||||
class TempFileManager:
|
||||
"""
|
||||
A class that should be inherited by any Component that needs to manage temporary files.
|
||||
It should be instantiated in the __init__ method of the component.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Set stores all the temporary files created by this component.
|
||||
self.temp_files = set()
|
||||
|
||||
def hash_file(self, file_path: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(chunk_num_blocks * sha1.block_size), b""):
|
||||
sha1.update(chunk)
|
||||
return sha1.hexdigest()
|
||||
|
||||
def hash_url(self, url: str, chunk_num_blocks: int = 128) -> str:
|
||||
sha1 = hashlib.sha1()
|
||||
remote = urllib.request.urlopen(url)
|
||||
max_file_size = 100 * 1024 * 1024 # 100MB
|
||||
total_read = 0
|
||||
while True:
|
||||
data = remote.read(chunk_num_blocks * sha1.block_size)
|
||||
total_read += chunk_num_blocks * sha1.block_size
|
||||
if not data or total_read > max_file_size:
|
||||
break
|
||||
sha1.update(data)
|
||||
return sha1.hexdigest()
|
||||
|
||||
def get_prefix_and_extension(self, file_path_or_url: str) -> Tuple[str, str]:
|
||||
file_name = os.path.basename(file_path_or_url)
|
||||
prefix, extension = file_name, None
|
||||
if "." in file_name:
|
||||
prefix = file_name[0 : file_name.index(".")]
|
||||
extension = "." + file_name[file_name.index(".") + 1 :]
|
||||
else:
|
||||
extension = ""
|
||||
prefix = utils.strip_invalid_filename_characters(prefix)
|
||||
return prefix, extension
|
||||
|
||||
def get_temp_file_path(self, file_path: str) -> str:
|
||||
prefix, extension = self.get_prefix_and_extension(file_path)
|
||||
file_hash = self.hash_file(file_path)
|
||||
return prefix + file_hash + extension
|
||||
|
||||
def get_temp_url_path(self, url: str) -> str:
|
||||
prefix, extension = self.get_prefix_and_extension(url)
|
||||
file_hash = self.hash_url(url)
|
||||
return prefix + file_hash + extension
|
||||
|
||||
def make_temp_copy_if_needed(self, file_path: str) -> str:
|
||||
"""Returns a temporary file path for a copy of the given file path if it does
|
||||
not already exist. Otherwise returns the path to the existing temp file."""
|
||||
f = tempfile.NamedTemporaryFile()
|
||||
temp_dir, _ = os.path.split(f.name)
|
||||
|
||||
temp_file_path = self.get_temp_file_path(file_path)
|
||||
f.name = os.path.join(temp_dir, temp_file_path)
|
||||
full_temp_file_path = os.path.abspath(f.name)
|
||||
|
||||
if not os.path.exists(full_temp_file_path):
|
||||
shutil.copy2(file_path, full_temp_file_path)
|
||||
|
||||
self.temp_files.add(full_temp_file_path)
|
||||
return full_temp_file_path
|
||||
|
||||
def download_temp_copy_if_needed(self, url: str) -> str:
|
||||
"""Downloads a file and makes a temporary file path for a copy if does not already
|
||||
exist. Otherwise returns the path to the existing temp file."""
|
||||
f = tempfile.NamedTemporaryFile()
|
||||
temp_dir, _ = os.path.split(f.name)
|
||||
|
||||
temp_file_path = self.get_temp_url_path(url)
|
||||
f.name = os.path.join(temp_dir, temp_file_path)
|
||||
full_temp_file_path = os.path.abspath(f.name)
|
||||
|
||||
if not os.path.exists(full_temp_file_path):
|
||||
with requests.get(url, stream=True) as r:
|
||||
with open(full_temp_file_path, "wb") as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
|
||||
self.temp_files.add(full_temp_file_path)
|
||||
return full_temp_file_path
|
||||
|
||||
|
||||
def create_tmp_copy_of_file(file_path, dir=None):
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
|
||||
file_name = os.path.basename(file_path)
|
||||
prefix, extension = file_name, None
|
||||
if "." in file_name:
|
||||
|
@ -260,17 +260,17 @@ class App(FastAPI):
|
||||
return FileResponse(
|
||||
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
|
||||
)
|
||||
elif Path(app.cwd).resolve() in Path(path).resolve().parents or any(
|
||||
Path(temp_dir).resolve() in Path(path).resolve().parents
|
||||
for temp_dir in app.blocks.temp_dirs
|
||||
):
|
||||
if Path(app.cwd).resolve() in Path(
|
||||
path
|
||||
).resolve().parents or os.path.abspath(path) in set().union(
|
||||
*app.blocks.temp_file_sets
|
||||
): # Need to use os.path.abspath in the second condition to be consistent with usage in TempFileManager
|
||||
return FileResponse(
|
||||
Path(path).resolve(), headers={"Accept-Ranges": "bytes"}
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"File cannot be fetched: {path}, perhaps because "
|
||||
f"it is not in any of {app.blocks.temp_dirs}"
|
||||
f"File cannot be fetched: {path}. All files must contained within the Gradio python app working directory, or be a temp file created by the Gradio python app."
|
||||
)
|
||||
|
||||
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
|
||||
|
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from gradio import processing_utils
|
||||
from gradio import processing_utils, utils
|
||||
|
||||
|
||||
class Serializable(ABC):
|
||||
@ -93,7 +93,7 @@ class ImgSerializable(Serializable):
|
||||
|
||||
class FileSerializable(Serializable):
|
||||
def serialize(
|
||||
self, x: str, load_dir: str = "", encryption_key: bytes | None = None
|
||||
self, x: str | None, load_dir: str = "", encryption_key: bytes | None = None
|
||||
) -> Any:
|
||||
"""
|
||||
Convert from human-friendly version of a file (string filepath) to a
|
||||
@ -116,7 +116,10 @@ class FileSerializable(Serializable):
|
||||
}
|
||||
|
||||
def deserialize(
|
||||
self, x: Dict, save_dir: str | None = None, encryption_key: bytes | None = None
|
||||
self,
|
||||
x: str | Dict | None,
|
||||
save_dir: str | None = None,
|
||||
encryption_key: bytes | None = None,
|
||||
):
|
||||
"""
|
||||
Convert from serialized representation of a file (base64) to a human-friendly
|
||||
@ -129,21 +132,26 @@ class FileSerializable(Serializable):
|
||||
if x is None:
|
||||
return None
|
||||
if isinstance(x, str):
|
||||
file = processing_utils.decode_base64_to_file(
|
||||
file_name = processing_utils.decode_base64_to_file(
|
||||
x, dir=save_dir, encryption_key=encryption_key
|
||||
)
|
||||
).name
|
||||
elif isinstance(x, dict):
|
||||
if x.get("is_file", False):
|
||||
file = processing_utils.create_tmp_copy_of_file(x["name"], dir=save_dir)
|
||||
if utils.validate_url(x["name"]):
|
||||
file_name = x["name"]
|
||||
else:
|
||||
file_name = processing_utils.create_tmp_copy_of_file(
|
||||
x["name"], dir=save_dir
|
||||
).name
|
||||
else:
|
||||
file = processing_utils.decode_base64_to_file(
|
||||
file_name = processing_utils.decode_base64_to_file(
|
||||
x["data"], dir=save_dir, encryption_key=encryption_key
|
||||
)
|
||||
).name
|
||||
else:
|
||||
raise ValueError(
|
||||
f"A FileSerializable component cannot only deserialize a string or a dict, not a: {type(x)}"
|
||||
)
|
||||
return file.name
|
||||
return file_name
|
||||
|
||||
|
||||
class JSONSerializable(Serializable):
|
||||
|
BIN
gradio/test_data/cheetah1-copy.jpg
Normal file
BIN
gradio/test_data/cheetah1-copy.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
@ -1 +1 @@
|
||||
3.13.2
|
||||
3.13.2
|
@ -3,6 +3,7 @@ import copy
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
@ -1170,3 +1171,19 @@ async def test_queue_when_using_auth():
|
||||
*[run_ws(loop, tm + sleep_time * (i + 1) - 0.3, i) for i in range(3)]
|
||||
)
|
||||
await group
|
||||
|
||||
|
||||
def test_temp_file_sets_get_extended():
|
||||
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
|
||||
with gr.Blocks() as demo1:
|
||||
gr.Video(str(test_file_dir / "video_sample.mp4"))
|
||||
|
||||
with gr.Blocks() as demo2:
|
||||
gr.Audio(str(test_file_dir / "audio_sample.wav"))
|
||||
|
||||
with gr.Blocks() as demo3:
|
||||
demo1.render()
|
||||
demo2.render()
|
||||
|
||||
assert demo3.temp_file_sets == demo1.temp_file_sets + demo2.temp_file_sets
|
||||
|
@ -12,7 +12,7 @@ import shutil
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from difflib import SequenceMatcher
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
@ -591,9 +591,6 @@ class TestImage:
|
||||
image_input = gr.Image(invert_colors=True)
|
||||
assert image_input.preprocess(img) is not None
|
||||
image_input.preprocess(img)
|
||||
with pytest.warns(Warning):
|
||||
file_image = gr.Image(type="file")
|
||||
file_image.preprocess(deepcopy(media_data.BASE64_IMAGE))
|
||||
file_image = gr.Image(type="filepath")
|
||||
assert isinstance(file_image.preprocess(img), str)
|
||||
with pytest.raises(ValueError):
|
||||
@ -627,7 +624,7 @@ class TestImage:
|
||||
image_input = gr.Image()
|
||||
iface = gr.Interface(
|
||||
lambda x: PIL.Image.open(x).rotate(90, expand=True),
|
||||
gr.Image(shape=(30, 10), type="file"),
|
||||
gr.Image(shape=(30, 10), type="filepath"),
|
||||
"image",
|
||||
)
|
||||
output = iface(img)
|
||||
@ -751,7 +748,7 @@ class TestAudio:
|
||||
y_audio = gr.processing_utils.decode_base64_to_file(
|
||||
deepcopy(media_data.BASE64_AUDIO)["data"]
|
||||
)
|
||||
audio_output = gr.Audio(type="file")
|
||||
audio_output = gr.Audio(type="filepath")
|
||||
assert filecmp.cmp(y_audio.name, audio_output.postprocess(y_audio.name)["name"])
|
||||
assert audio_output.get_config() == {
|
||||
"name": "audio",
|
||||
@ -774,6 +771,10 @@ class TestAudio:
|
||||
}
|
||||
).endswith(".wav")
|
||||
|
||||
output1 = audio_output.postprocess(y_audio.name)
|
||||
output2 = audio_output.postprocess(y_audio.name)
|
||||
assert output1 == output2
|
||||
|
||||
def test_tokenize(self):
|
||||
"""
|
||||
Tokenize, get_masked_inputs
|
||||
@ -837,6 +838,11 @@ class TestFile:
|
||||
assert serialized["orig_name"] == "sample_file.pdf"
|
||||
assert output.orig_name == "test/test_files/sample_file.pdf"
|
||||
|
||||
x_file["is_file"] = True
|
||||
input1 = file_input.preprocess(x_file)
|
||||
input2 = file_input.preprocess(x_file)
|
||||
assert input1.name == input2.name
|
||||
|
||||
assert isinstance(file_input.generate_sample(), dict)
|
||||
file_input = gr.File(label="Upload Your File")
|
||||
assert file_input.get_config() == {
|
||||
@ -860,6 +866,10 @@ class TestFile:
|
||||
output = file_input.preprocess(x_file)
|
||||
assert type(output) == bytes
|
||||
|
||||
output1 = file_input.postprocess("test/test_files/sample_file.pdf")
|
||||
output2 = file_input.postprocess("test/test_files/sample_file.pdf")
|
||||
assert output1 == output2
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
"""
|
||||
Interface, process
|
||||
@ -886,6 +896,22 @@ class TestFile:
|
||||
assert iface("hello world").endswith(".txt")
|
||||
|
||||
|
||||
class TestUploadButton:
|
||||
def test_component_functions(self):
|
||||
"""
|
||||
preprocess
|
||||
"""
|
||||
x_file = deepcopy(media_data.BASE64_FILE)
|
||||
upload_input = gr.UploadButton()
|
||||
input = upload_input.preprocess(x_file)
|
||||
assert isinstance(input, tempfile._TemporaryFileWrapper)
|
||||
|
||||
x_file["is_file"] = True
|
||||
input1 = upload_input.preprocess(x_file)
|
||||
input2 = upload_input.preprocess(x_file)
|
||||
assert input1.name == input2.name
|
||||
|
||||
|
||||
class TestDataframe:
|
||||
def test_component_functions(self):
|
||||
"""
|
||||
@ -1120,8 +1146,10 @@ class TestVideo:
|
||||
"""
|
||||
x_video = deepcopy(media_data.BASE64_VIDEO)
|
||||
video_input = gr.Video()
|
||||
output = video_input.preprocess(x_video)
|
||||
assert isinstance(output, str)
|
||||
output1 = video_input.preprocess(x_video)
|
||||
assert isinstance(output1, str)
|
||||
output2 = video_input.preprocess(x_video)
|
||||
assert output1 == output2
|
||||
|
||||
assert isinstance(video_input.generate_sample(), dict)
|
||||
video_input = gr.Video(label="Upload Your Video")
|
||||
@ -1153,7 +1181,11 @@ class TestVideo:
|
||||
# Output functionalities
|
||||
y_vid_path = "test/test_files/video_sample.mp4"
|
||||
video_output = gr.Video()
|
||||
assert video_output.postprocess(y_vid_path)["name"].endswith("mp4")
|
||||
output1 = video_output.postprocess(y_vid_path)["name"]
|
||||
assert output1.endswith("mp4")
|
||||
output2 = video_output.postprocess(y_vid_path)["name"]
|
||||
assert output1 == output2
|
||||
|
||||
assert video_output.deserialize(
|
||||
{
|
||||
"name": None,
|
||||
@ -1182,7 +1214,7 @@ class TestVideo:
|
||||
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
# This file has a playable container but not playable codec
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix="bad_video.mp4"
|
||||
suffix="bad_video.mp4", delete=False
|
||||
) as tmp_not_playable_vid:
|
||||
bad_vid = str(test_file_dir / "bad_video_sample.mp4")
|
||||
assert not processing_utils.video_is_playable(bad_vid)
|
||||
@ -1196,7 +1228,7 @@ class TestVideo:
|
||||
|
||||
# This file has a playable codec but not a playable container
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix="playable_but_bad_container.mkv"
|
||||
suffix="playable_but_bad_container.mkv", delete=False
|
||||
) as tmp_not_playable_vid:
|
||||
bad_vid = str(test_file_dir / "playable_but_bad_container.mkv")
|
||||
assert not processing_utils.video_is_playable(bad_vid)
|
||||
@ -1207,8 +1239,10 @@ class TestVideo:
|
||||
)
|
||||
assert processing_utils.video_is_playable(str(full_path_to_output))
|
||||
|
||||
@patch("os.path.exists", MagicMock(return_value=False))
|
||||
@patch("gradio.components.FFmpeg")
|
||||
def test_video_preprocessing_flips_video_for_webcam(self, mock_ffmpeg):
|
||||
# Ensures that the cached temp video file is not used so that ffmpeg is called for each test
|
||||
x_video = deepcopy(media_data.BASE64_VIDEO)
|
||||
video_input = gr.Video(source="webcam")
|
||||
_ = video_input.preprocess(x_video)
|
||||
@ -1359,8 +1393,8 @@ class TestLabel:
|
||||
|
||||
test_file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
path = str(pathlib.Path(test_file_dir, "test_label_json.json"))
|
||||
label = label_output.postprocess(path)
|
||||
assert label["label"] == "web site"
|
||||
label_dict = label_output.postprocess(path)
|
||||
assert label_dict["label"] == "web site"
|
||||
|
||||
assert label_output.get_config() == {
|
||||
"name": "label",
|
||||
@ -1692,6 +1726,11 @@ class TestModel3D:
|
||||
"style": {},
|
||||
} == component.get_config()
|
||||
|
||||
file = "test/test_files/Box.gltf"
|
||||
output1 = component.postprocess(file)
|
||||
output2 = component.postprocess(file)
|
||||
assert output1 == output2
|
||||
|
||||
def test_in_interface(self):
|
||||
"""
|
||||
Interface, process
|
||||
@ -1787,7 +1826,8 @@ class TestGallery:
|
||||
path = gallery.deserialize(data, tmpdir)
|
||||
assert path.endswith("my-uuid")
|
||||
data_restored = gallery.serialize(path)
|
||||
assert sorted(data) == sorted([d["data"] for d in data_restored])
|
||||
data_restored = [d[0]["data"] for d in data_restored]
|
||||
assert sorted(data) == sorted(data_restored)
|
||||
|
||||
|
||||
class TestState:
|
||||
|
@ -82,7 +82,7 @@ class TestExamples:
|
||||
)
|
||||
|
||||
prediction = await examples.load_from_cache(0)
|
||||
assert prediction[0][0]["data"] == gr.media_data.BASE64_IMAGE
|
||||
assert prediction[0][0][0]["data"] == gr.media_data.BASE64_IMAGE
|
||||
|
||||
|
||||
@patch("gradio.examples.CACHED_FOLDER", tempfile.mkdtemp())
|
||||
|
@ -3,7 +3,7 @@ import pathlib
|
||||
import shutil
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import ffmpy
|
||||
import matplotlib.pyplot as plt
|
||||
@ -11,40 +11,40 @@ import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
import gradio as gr
|
||||
from gradio import media_data
|
||||
from gradio import media_data, processing_utils
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestImagePreprocessing:
|
||||
def test_decode_base64_to_image(self):
|
||||
output_image = gr.processing_utils.decode_base64_to_image(
|
||||
output_image = processing_utils.decode_base64_to_image(
|
||||
deepcopy(media_data.BASE64_IMAGE)
|
||||
)
|
||||
assert isinstance(output_image, Image.Image)
|
||||
|
||||
def test_encode_url_or_file_to_base64(self):
|
||||
output_base64 = gr.processing_utils.encode_url_or_file_to_base64(
|
||||
output_base64 = processing_utils.encode_url_or_file_to_base64(
|
||||
"gradio/test_data/test_image.png"
|
||||
)
|
||||
assert output_base64 == deepcopy(media_data.BASE64_IMAGE)
|
||||
|
||||
def test_encode_file_to_base64(self):
|
||||
output_base64 = gr.processing_utils.encode_file_to_base64(
|
||||
output_base64 = processing_utils.encode_file_to_base64(
|
||||
"gradio/test_data/test_image.png"
|
||||
)
|
||||
assert output_base64 == deepcopy(media_data.BASE64_IMAGE)
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_encode_url_to_base64(self):
|
||||
output_base64 = gr.processing_utils.encode_url_to_base64(
|
||||
output_base64 = processing_utils.encode_url_to_base64(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
|
||||
)
|
||||
assert output_base64 == deepcopy(media_data.BASE64_IMAGE)
|
||||
|
||||
def test_encode_plot_to_base64(self):
|
||||
plt.plot([1, 2, 3, 4])
|
||||
output_base64 = gr.processing_utils.encode_plot_to_base64(plt)
|
||||
output_base64 = processing_utils.encode_plot_to_base64(plt)
|
||||
assert output_base64.startswith(
|
||||
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAo"
|
||||
)
|
||||
@ -53,14 +53,14 @@ class TestImagePreprocessing:
|
||||
img = Image.open("gradio/test_data/test_image.png")
|
||||
img = img.convert("RGB")
|
||||
numpy_data = np.asarray(img, dtype=np.uint8)
|
||||
output_base64 = gr.processing_utils.encode_array_to_base64(numpy_data)
|
||||
output_base64 = processing_utils.encode_array_to_base64(numpy_data)
|
||||
assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
|
||||
|
||||
def test_encode_pil_to_base64(self):
|
||||
img = Image.open("gradio/test_data/test_image.png")
|
||||
img = img.convert("RGB")
|
||||
img.info = {} # Strip metadata
|
||||
output_base64 = gr.processing_utils.encode_pil_to_base64(img)
|
||||
output_base64 = processing_utils.encode_pil_to_base64(img)
|
||||
assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)
|
||||
|
||||
def test_encode_pil_to_base64_keeps_pnginfo(self):
|
||||
@ -68,30 +68,30 @@ class TestImagePreprocessing:
|
||||
input_img = input_img.convert("RGB")
|
||||
input_img.info = {"key1": "value1", "key2": "value2"}
|
||||
|
||||
encoded_image = gr.processing_utils.encode_pil_to_base64(input_img)
|
||||
decoded_image = gr.processing_utils.decode_base64_to_image(encoded_image)
|
||||
encoded_image = processing_utils.encode_pil_to_base64(input_img)
|
||||
decoded_image = processing_utils.decode_base64_to_image(encoded_image)
|
||||
|
||||
assert decoded_image.info == input_img.info
|
||||
|
||||
def test_resize_and_crop(self):
|
||||
img = Image.open("gradio/test_data/test_image.png")
|
||||
new_img = gr.processing_utils.resize_and_crop(img, (20, 20))
|
||||
new_img = processing_utils.resize_and_crop(img, (20, 20))
|
||||
assert new_img.size == (20, 20)
|
||||
with pytest.raises(ValueError):
|
||||
gr.processing_utils.resize_and_crop(
|
||||
processing_utils.resize_and_crop(
|
||||
**{"img": img, "size": (20, 20), "crop_type": "test"}
|
||||
)
|
||||
|
||||
|
||||
class TestAudioPreprocessing:
|
||||
def test_audio_from_file(self):
|
||||
audio = gr.processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
|
||||
audio = processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
|
||||
assert audio[0] == 22050
|
||||
assert isinstance(audio[1], np.ndarray)
|
||||
|
||||
def test_audio_to_file(self):
|
||||
audio = gr.processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
|
||||
gr.processing_utils.audio_to_file(audio[0], audio[1], "test_audio_to_file")
|
||||
audio = processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
|
||||
processing_utils.audio_to_file(audio[0], audio[1], "test_audio_to_file")
|
||||
assert os.path.exists("test_audio_to_file")
|
||||
os.remove("test_audio_to_file")
|
||||
|
||||
@ -102,65 +102,117 @@ class TestAudioPreprocessing:
|
||||
audio[1] = 32766
|
||||
|
||||
audio_ = audio.astype("float64")
|
||||
audio_ = gr.processing_utils.convert_to_16_bit_wav(audio_)
|
||||
audio_ = processing_utils.convert_to_16_bit_wav(audio_)
|
||||
assert np.allclose(audio, audio_)
|
||||
assert audio_.dtype == "int16"
|
||||
|
||||
audio_ = audio.astype("float32")
|
||||
audio_ = gr.processing_utils.convert_to_16_bit_wav(audio_)
|
||||
audio_ = processing_utils.convert_to_16_bit_wav(audio_)
|
||||
assert np.allclose(audio, audio_)
|
||||
assert audio_.dtype == "int16"
|
||||
|
||||
audio_ = gr.processing_utils.convert_to_16_bit_wav(audio)
|
||||
audio_ = processing_utils.convert_to_16_bit_wav(audio)
|
||||
assert np.allclose(audio, audio_)
|
||||
assert audio_.dtype == "int16"
|
||||
|
||||
|
||||
class TestTempFileManager:
|
||||
def test_get_temp_file_path(self):
|
||||
temp_file_manager = processing_utils.TempFileManager()
|
||||
temp_file_manager.hash_file = MagicMock(return_value="")
|
||||
|
||||
filepath = "C:/gradio/test_image.png"
|
||||
temp_filepath = temp_file_manager.get_temp_file_path(filepath)
|
||||
assert "test_image" in temp_filepath
|
||||
assert temp_filepath.endswith(".png")
|
||||
|
||||
filepath = "ABCabc123.csv"
|
||||
temp_filepath = temp_file_manager.get_temp_file_path(filepath)
|
||||
assert "ABCabc123" in temp_filepath
|
||||
assert temp_filepath.endswith(".csv")
|
||||
|
||||
filepath = "lion#1.jpeg"
|
||||
temp_filepath = temp_file_manager.get_temp_file_path(filepath)
|
||||
assert "lion1" in temp_filepath
|
||||
assert temp_filepath.endswith(".jpeg")
|
||||
|
||||
filepath = "%%lio|n#1.jpeg"
|
||||
temp_filepath = temp_file_manager.get_temp_file_path(filepath)
|
||||
assert "lion1" in temp_filepath
|
||||
assert temp_filepath.endswith(".jpeg")
|
||||
|
||||
filepath = "/home/lion--_1.txt"
|
||||
temp_filepath = temp_file_manager.get_temp_file_path(filepath)
|
||||
assert "lion--_1" in temp_filepath
|
||||
assert temp_filepath.endswith(".txt")
|
||||
|
||||
def test_hash_file(self):
|
||||
temp_file_manager = processing_utils.TempFileManager()
|
||||
h1 = temp_file_manager.hash_file("gradio/test_data/cheetah1.jpg")
|
||||
h2 = temp_file_manager.hash_file("gradio/test_data/cheetah1-copy.jpg")
|
||||
h3 = temp_file_manager.hash_file("gradio/test_data/cheetah2.jpg")
|
||||
assert h1 == h2
|
||||
assert h1 != h3
|
||||
|
||||
@patch("shutil.copy2")
|
||||
def test_make_temp_copy_if_needed(self, mock_copy):
|
||||
temp_file_manager = processing_utils.TempFileManager()
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
|
||||
try: # Delete if already exists from before this test
|
||||
os.remove(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
|
||||
assert mock_copy.called
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed("gradio/test_data/cheetah1.jpg")
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.make_temp_copy_if_needed(
|
||||
"gradio/test_data/cheetah1-copy.jpg"
|
||||
)
|
||||
assert len(temp_file_manager.temp_files) == 2
|
||||
|
||||
@pytest.mark.flaky
|
||||
@patch("shutil.copyfileobj")
|
||||
def test_download_temp_copy_if_needed(self, mock_copy):
|
||||
temp_file_manager = processing_utils.TempFileManager()
|
||||
url1 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
|
||||
url2 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg"
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url1)
|
||||
try: # Delete if already exists from before this test
|
||||
os.remove(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url1)
|
||||
assert mock_copy.called
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url1)
|
||||
assert len(temp_file_manager.temp_files) == 1
|
||||
|
||||
f = temp_file_manager.download_temp_copy_if_needed(url2)
|
||||
assert len(temp_file_manager.temp_files) == 2
|
||||
|
||||
|
||||
class TestOutputPreprocessing:
|
||||
def test_decode_base64_to_binary(self):
|
||||
binary = gr.processing_utils.decode_base64_to_binary(
|
||||
binary = processing_utils.decode_base64_to_binary(
|
||||
deepcopy(media_data.BASE64_IMAGE)
|
||||
)
|
||||
assert deepcopy(media_data.BINARY_IMAGE) == binary
|
||||
|
||||
def test_decode_base64_to_file(self):
|
||||
temp_file = gr.processing_utils.decode_base64_to_file(
|
||||
temp_file = processing_utils.decode_base64_to_file(
|
||||
deepcopy(media_data.BASE64_IMAGE)
|
||||
)
|
||||
assert isinstance(temp_file, tempfile._TemporaryFileWrapper)
|
||||
|
||||
def test_create_tmp_copy_of_file(self):
|
||||
f = tempfile.NamedTemporaryFile(delete=False)
|
||||
temp_file = gr.processing_utils.create_tmp_copy_of_file(f.name)
|
||||
assert isinstance(temp_file, tempfile._TemporaryFileWrapper)
|
||||
|
||||
@patch("shutil.copy2")
|
||||
def test_create_tmp_filenames(self, mock_copy2):
|
||||
filepath = "C:/gradio/test_image.png"
|
||||
file_obj = gr.processing_utils.create_tmp_copy_of_file(filepath)
|
||||
assert "test_image" in file_obj.name
|
||||
assert file_obj.name.endswith(".png")
|
||||
|
||||
filepath = "ABCabc123.csv"
|
||||
file_obj = gr.processing_utils.create_tmp_copy_of_file(filepath)
|
||||
assert "ABCabc123" in file_obj.name
|
||||
assert file_obj.name.endswith(".csv")
|
||||
|
||||
filepath = "lion#1.jpeg"
|
||||
file_obj = gr.processing_utils.create_tmp_copy_of_file(filepath)
|
||||
assert "lion1" in file_obj.name
|
||||
assert file_obj.name.endswith(".jpeg")
|
||||
|
||||
filepath = "%%lio|n#1.jpeg"
|
||||
file_obj = gr.processing_utils.create_tmp_copy_of_file(filepath)
|
||||
assert "lion1" in file_obj.name
|
||||
assert file_obj.name.endswith(".jpeg")
|
||||
|
||||
filepath = "/home/lion--_1.txt"
|
||||
file_obj = gr.processing_utils.create_tmp_copy_of_file(filepath)
|
||||
assert "lion--_1" in file_obj.name
|
||||
assert file_obj.name.endswith(".txt")
|
||||
|
||||
float_dtype_list = [
|
||||
float,
|
||||
float,
|
||||
@ -186,7 +238,7 @@ class TestOutputPreprocessing:
|
||||
|
||||
for dtype_in, dtype_out in dtype_combin:
|
||||
x = x.astype(dtype_in)
|
||||
y = gr.processing_utils._convert(x, dtype_out)
|
||||
y = processing_utils._convert(x, dtype_out)
|
||||
assert y.dtype == np.dtype(dtype_out)
|
||||
|
||||
def test_subclass_conversion(self):
|
||||
@ -194,22 +246,22 @@ class TestOutputPreprocessing:
|
||||
x = np.array([-1, 1])
|
||||
for dtype in TestOutputPreprocessing.float_dtype_list:
|
||||
x = x.astype(dtype)
|
||||
y = gr.processing_utils._convert(x, np.floating)
|
||||
y = processing_utils._convert(x, np.floating)
|
||||
assert y.dtype == x.dtype
|
||||
|
||||
|
||||
class TestVideoProcessing:
|
||||
def test_video_has_playable_codecs(self, test_file_dir):
|
||||
assert gr.processing_utils.video_is_playable(
|
||||
assert processing_utils.video_is_playable(
|
||||
str(test_file_dir / "video_sample.mp4")
|
||||
)
|
||||
assert gr.processing_utils.video_is_playable(
|
||||
assert processing_utils.video_is_playable(
|
||||
str(test_file_dir / "video_sample.ogg")
|
||||
)
|
||||
assert gr.processing_utils.video_is_playable(
|
||||
assert processing_utils.video_is_playable(
|
||||
str(test_file_dir / "video_sample.webm")
|
||||
)
|
||||
assert not gr.processing_utils.video_is_playable(
|
||||
assert not processing_utils.video_is_playable(
|
||||
str(test_file_dir / "bad_video_sample.mp4")
|
||||
)
|
||||
|
||||
@ -223,32 +275,38 @@ class TestVideoProcessing:
|
||||
self, exception_to_raise, test_file_dir
|
||||
):
|
||||
with patch("ffmpy.FFprobe.run", side_effect=exception_to_raise):
|
||||
with tempfile.NamedTemporaryFile(suffix="out.avi") as tmp_not_playable_vid:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix="out.avi", delete=False
|
||||
) as tmp_not_playable_vid:
|
||||
shutil.copy(
|
||||
str(test_file_dir / "bad_video_sample.mp4"),
|
||||
tmp_not_playable_vid.name,
|
||||
)
|
||||
assert gr.processing_utils.video_is_playable(tmp_not_playable_vid.name)
|
||||
assert processing_utils.video_is_playable(tmp_not_playable_vid.name)
|
||||
|
||||
def test_convert_video_to_playable_mp4(self, test_file_dir):
|
||||
with tempfile.NamedTemporaryFile(suffix="out.avi") as tmp_not_playable_vid:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix="out.avi", delete=False
|
||||
) as tmp_not_playable_vid:
|
||||
shutil.copy(
|
||||
str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name
|
||||
)
|
||||
playable_vid = gr.processing_utils.convert_video_to_playable_mp4(
|
||||
playable_vid = processing_utils.convert_video_to_playable_mp4(
|
||||
tmp_not_playable_vid.name
|
||||
)
|
||||
assert gr.processing_utils.video_is_playable(playable_vid)
|
||||
assert processing_utils.video_is_playable(playable_vid)
|
||||
|
||||
@patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception)
|
||||
def test_video_conversion_returns_original_video_if_fails(
|
||||
self, mock_run, test_file_dir
|
||||
):
|
||||
with tempfile.NamedTemporaryFile(suffix="out.avi") as tmp_not_playable_vid:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix="out.avi", delete=False
|
||||
) as tmp_not_playable_vid:
|
||||
shutil.copy(
|
||||
str(test_file_dir / "bad_video_sample.mp4"), tmp_not_playable_vid.name
|
||||
)
|
||||
playable_vid = gr.processing_utils.convert_video_to_playable_mp4(
|
||||
playable_vid = processing_utils.convert_video_to_playable_mp4(
|
||||
tmp_not_playable_vid.name
|
||||
)
|
||||
# If the conversion succeeded it'd be .mp4
|
||||
|
Loading…
x
Reference in New Issue
Block a user