mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-01 11:45:36 +08:00
Clean up backend of File
and UploadButton
and change the return type of preprocess()
from TemporaryFIle to string filepath (#6060)
* changes * add changeset * upload button * file * add changeset * valid types * fix tests * address review --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
9053c95a10
commit
447dfe06bf
5
.changeset/quick-shirts-turn.md
Normal file
5
.changeset/quick-shirts-turn.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Clean up backend of `File` and `UploadButton` and change the return type of `preprocess()` from TemporaryFIle to string filepath
|
@ -13,6 +13,7 @@ from gradio_client.documentation import document, set_documentation_group
|
||||
from gradio.components.base import Component
|
||||
from gradio.data_classes import FileData, GradioRootModel
|
||||
from gradio.events import Events
|
||||
from gradio.utils import NamedString
|
||||
|
||||
set_documentation_group("component")
|
||||
|
||||
@ -39,7 +40,7 @@ class File(Component):
|
||||
*,
|
||||
file_count: Literal["single", "multiple", "directory"] = "single",
|
||||
file_types: list[str] | None = None,
|
||||
type: Literal["file", "binary"] = "file",
|
||||
type: Literal["filepath", "binary"] = "filepath",
|
||||
label: str | None = None,
|
||||
every: float | None = None,
|
||||
show_label: bool | None = None,
|
||||
@ -86,7 +87,7 @@ class File(Component):
|
||||
f"Parameter file_types must be a list. Received {file_types.__class__.__name__}"
|
||||
)
|
||||
valid_types = [
|
||||
"file",
|
||||
"filepath",
|
||||
"binary",
|
||||
]
|
||||
if type not in valid_types:
|
||||
@ -116,38 +117,31 @@ class File(Component):
|
||||
self.type = type
|
||||
self.height = height
|
||||
|
||||
@staticmethod
|
||||
def _process_single_file(
|
||||
f, type: Literal["file", "bytes", "binary"], cache_dir: str
|
||||
) -> bytes | tempfile._TemporaryFileWrapper:
|
||||
def _process_single_file(self, f: dict[str, Any]) -> bytes | str:
|
||||
file_name, data, is_file = (
|
||||
f["name"],
|
||||
f["data"],
|
||||
f.get("is_file", False),
|
||||
)
|
||||
if type == "file":
|
||||
file = tempfile.NamedTemporaryFile(delete=False, dir=cache_dir)
|
||||
if self.type == "filepath":
|
||||
file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
|
||||
file.name = file_name
|
||||
file.orig_name = file_name # type: ignore
|
||||
return file
|
||||
elif type in {"bytes", "binary"}:
|
||||
return NamedString(file.name)
|
||||
elif self.type == "binary":
|
||||
if is_file:
|
||||
with open(file_name, "rb") as file_data:
|
||||
return file_data.read()
|
||||
return client_utils.decode_base64_to_binary(data)[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type: " + str(type) + ". Please choose from: 'file', 'bytes'."
|
||||
"Unknown type: "
|
||||
+ str(type)
|
||||
+ ". Please choose from: 'filepath', 'binary'."
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self, x: list[dict[str, Any]] | None
|
||||
) -> (
|
||||
bytes
|
||||
| tempfile._TemporaryFileWrapper
|
||||
| list[bytes | tempfile._TemporaryFileWrapper]
|
||||
| None
|
||||
):
|
||||
) -> bytes | str | list[bytes | str] | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
|
||||
@ -159,16 +153,14 @@ class File(Component):
|
||||
|
||||
if self.file_count == "single":
|
||||
if isinstance(x, list):
|
||||
return self._process_single_file(
|
||||
x[0], type=self.type, cache_dir=self.GRADIO_CACHE # type: ignore
|
||||
)
|
||||
return self._process_single_file(x[0])
|
||||
else:
|
||||
return self._process_single_file(x, type=self.type, cache_dir=self.GRADIO_CACHE) # type: ignore
|
||||
return self._process_single_file(x)
|
||||
else:
|
||||
if isinstance(x, list):
|
||||
return [self._process_single_file(f, type=self.type, cache_dir=self.GRADIO_CACHE) for f in x] # type: ignore
|
||||
return [self._process_single_file(f) for f in x]
|
||||
else:
|
||||
return self._process_single_file(x, type=self.type, cache_dir=self.GRADIO_CACHE) # type: ignore
|
||||
return [self._process_single_file(x)]
|
||||
|
||||
def postprocess(self, y: str | list[str] | None) -> ListFiles | FileData | None:
|
||||
"""
|
||||
|
@ -6,12 +6,13 @@ import tempfile
|
||||
import warnings
|
||||
from typing import Any, Callable, List, Literal
|
||||
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
from gradio.components.base import Component
|
||||
from gradio.components.file import File
|
||||
from gradio.data_classes import FileData, GradioRootModel
|
||||
from gradio.events import Events
|
||||
from gradio.utils import NamedString
|
||||
|
||||
set_documentation_group("component")
|
||||
|
||||
@ -49,7 +50,7 @@ class UploadButton(Component):
|
||||
render: bool = True,
|
||||
root_url: str | None = None,
|
||||
_skip_init_processing: bool = False,
|
||||
type: Literal["file", "bytes"] = "file",
|
||||
type: Literal["filepath", "bytes"] = "filepath",
|
||||
file_count: Literal["single", "multiple", "directory"] = "single",
|
||||
file_types: list[str] | None = None,
|
||||
):
|
||||
@ -72,6 +73,14 @@ class UploadButton(Component):
|
||||
file_count: if single, allows user to upload one file. If "multiple", user uploads multiple files. If "directory", user uploads all files in selected directory. Return type will be list for each file in case of "multiple" or "directory".
|
||||
file_types: List of type of files to be uploaded. "file" allows any file to be uploaded, "image" allows only image files to be uploaded, "audio" allows only audio files to be uploaded, "video" allows only video files to be uploaded, "text" allows only text files to be uploaded.
|
||||
"""
|
||||
valid_types = [
|
||||
"filepath",
|
||||
"binary",
|
||||
]
|
||||
if type not in valid_types:
|
||||
raise ValueError(
|
||||
f"Invalid value for parameter `type`: {type}. Please choose from one of: {valid_types}"
|
||||
)
|
||||
self.type = type
|
||||
self.file_count = file_count
|
||||
if file_count == "directory" and file_types is not None:
|
||||
@ -115,14 +124,31 @@ class UploadButton(Component):
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
]
|
||||
|
||||
def _process_single_file(self, f: dict[str, Any]) -> bytes | str:
|
||||
file_name, data, is_file = (
|
||||
f["name"],
|
||||
f["data"],
|
||||
f.get("is_file", False),
|
||||
)
|
||||
if self.type == "filepath":
|
||||
file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
|
||||
file.name = file_name
|
||||
return NamedString(file.name)
|
||||
elif self.type == "binary":
|
||||
if is_file:
|
||||
with open(file_name, "rb") as file_data:
|
||||
return file_data.read()
|
||||
return client_utils.decode_base64_to_binary(data)[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type: "
|
||||
+ str(type)
|
||||
+ ". Please choose from: 'filepath', 'binary'."
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self, x: list[dict[str, Any]] | None
|
||||
) -> (
|
||||
bytes
|
||||
| tempfile._TemporaryFileWrapper
|
||||
| list[bytes | tempfile._TemporaryFileWrapper]
|
||||
| None
|
||||
):
|
||||
) -> bytes | str | list[bytes | str] | None:
|
||||
"""
|
||||
Parameters:
|
||||
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
|
||||
@ -134,25 +160,14 @@ class UploadButton(Component):
|
||||
|
||||
if self.file_count == "single":
|
||||
if isinstance(x, list):
|
||||
return File._process_single_file(
|
||||
x[0], type=self.type, cache_dir=self.GRADIO_CACHE # type: ignore
|
||||
)
|
||||
return self._process_single_file(x[0])
|
||||
else:
|
||||
return File._process_single_file(
|
||||
x, type=self.type, cache_dir=self.GRADIO_CACHE # type: ignore
|
||||
)
|
||||
return self._process_single_file(x)
|
||||
else:
|
||||
if isinstance(x, list):
|
||||
return [
|
||||
File._process_single_file(
|
||||
f, type=self.type, cache_dir=self.GRADIO_CACHE # type: ignore
|
||||
)
|
||||
for f in x
|
||||
]
|
||||
return [self._process_single_file(f) for f in x]
|
||||
else:
|
||||
return File._process_single_file(
|
||||
x, type=self.type, cache_dir=self.GRADIO_CACHE # type: ignore
|
||||
)
|
||||
return [self._process_single_file(x)]
|
||||
|
||||
def postprocess(self, y):
|
||||
return super().postprocess(y)
|
||||
|
@ -421,7 +421,7 @@ class Files(components.File):
|
||||
value: str | list[str] | Callable | None = None,
|
||||
*,
|
||||
file_count: Literal["multiple"] = "multiple",
|
||||
type: Literal["file", "binary"] = "file",
|
||||
type: Literal["filepath", "binary"] = "filepath",
|
||||
label: str | None = None,
|
||||
show_label: bool = True,
|
||||
interactive: bool | None = None,
|
||||
|
@ -959,3 +959,16 @@ def recover_kwargs(config: dict, additional_keys_to_ignore: list[str] | None = N
|
||||
for k, v in config.items()
|
||||
if k not in not_kwargs and k not in (additional_keys_to_ignore or [])
|
||||
}
|
||||
|
||||
|
||||
class NamedString(str):
|
||||
"""
|
||||
Subclass of str that includes a .name attribute equal to the value of the string itself. This class is used when returning
|
||||
a value from the `.preprocess()` methods of the File and UploadButton components. Before Gradio 4.0, these methods returned a file
|
||||
object which was then converted to a string filepath using the `.name` attribute. In Gradio 4.0, these methods now return a str
|
||||
filepath directly, but to maintain backwards compatibility, we use this class instead of a regular str.
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__()
|
||||
self.name = str(self) if args else ""
|
||||
|
@ -882,13 +882,14 @@ class TestFile:
|
||||
x_file = deepcopy(media_data.BASE64_FILE)
|
||||
file_input = gr.File()
|
||||
output = file_input.preprocess(x_file)
|
||||
assert isinstance(output, tempfile._TemporaryFileWrapper)
|
||||
assert isinstance(output, str)
|
||||
|
||||
x_file["is_file"] = True
|
||||
input1 = file_input.preprocess(x_file)
|
||||
input2 = file_input.preprocess(x_file)
|
||||
assert input1.name == input2.name
|
||||
assert Path(input1.name).name == "sample_file.pdf"
|
||||
assert input1 == input1.name # Testing backwards compatibility
|
||||
assert input1 == input2
|
||||
assert Path(input1).name == "sample_file.pdf"
|
||||
|
||||
file_input = gr.File(label="Upload Your File")
|
||||
assert file_input.get_config() == {
|
||||
@ -908,7 +909,7 @@ class TestFile:
|
||||
"root_url": None,
|
||||
"selectable": False,
|
||||
"height": None,
|
||||
"type": "file",
|
||||
"type": "filepath",
|
||||
}
|
||||
assert file_input.preprocess(None) is None
|
||||
x_file["is_example"] = True
|
||||
@ -966,12 +967,13 @@ class TestUploadButton:
|
||||
x_file = deepcopy(media_data.BASE64_FILE)
|
||||
upload_input = gr.UploadButton()
|
||||
input = upload_input.preprocess(x_file)
|
||||
assert isinstance(input, tempfile._TemporaryFileWrapper)
|
||||
assert isinstance(input, str)
|
||||
|
||||
x_file["is_file"] = True
|
||||
input1 = upload_input.preprocess(x_file)
|
||||
input2 = upload_input.preprocess(x_file)
|
||||
assert input1.name == input2.name
|
||||
assert input1 == input1.name # Testing backwards compatibility
|
||||
assert input1 == input2
|
||||
|
||||
def test_raises_if_file_types_is_not_list(self):
|
||||
with pytest.raises(
|
||||
|
Loading…
Reference in New Issue
Block a user