Clean up backend of File and UploadButton and change the return type of preprocess() from TemporaryFIle to string filepath (#6060)

* changes

* add changeset

* upload button

* file

* add changeset

* valid types

* fix tests

* address review

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2023-10-23 12:11:32 -07:00 committed by GitHub
parent 9053c95a10
commit 447dfe06bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 81 additions and 54 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Clean up backend of `File` and `UploadButton` and change the return type of `preprocess()` from TemporaryFIle to string filepath

View File

@ -13,6 +13,7 @@ from gradio_client.documentation import document, set_documentation_group
from gradio.components.base import Component
from gradio.data_classes import FileData, GradioRootModel
from gradio.events import Events
from gradio.utils import NamedString
set_documentation_group("component")
@ -39,7 +40,7 @@ class File(Component):
*,
file_count: Literal["single", "multiple", "directory"] = "single",
file_types: list[str] | None = None,
type: Literal["file", "binary"] = "file",
type: Literal["filepath", "binary"] = "filepath",
label: str | None = None,
every: float | None = None,
show_label: bool | None = None,
@ -86,7 +87,7 @@ class File(Component):
f"Parameter file_types must be a list. Received {file_types.__class__.__name__}"
)
valid_types = [
"file",
"filepath",
"binary",
]
if type not in valid_types:
@ -116,38 +117,31 @@ class File(Component):
self.type = type
self.height = height
@staticmethod
def _process_single_file(
f, type: Literal["file", "bytes", "binary"], cache_dir: str
) -> bytes | tempfile._TemporaryFileWrapper:
def _process_single_file(self, f: dict[str, Any]) -> bytes | str:
file_name, data, is_file = (
f["name"],
f["data"],
f.get("is_file", False),
)
if type == "file":
file = tempfile.NamedTemporaryFile(delete=False, dir=cache_dir)
if self.type == "filepath":
file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
file.name = file_name
file.orig_name = file_name # type: ignore
return file
elif type in {"bytes", "binary"}:
return NamedString(file.name)
elif self.type == "binary":
if is_file:
with open(file_name, "rb") as file_data:
return file_data.read()
return client_utils.decode_base64_to_binary(data)[0]
else:
raise ValueError(
"Unknown type: " + str(type) + ". Please choose from: 'file', 'bytes'."
"Unknown type: "
+ str(type)
+ ". Please choose from: 'filepath', 'binary'."
)
def preprocess(
self, x: list[dict[str, Any]] | None
) -> (
bytes
| tempfile._TemporaryFileWrapper
| list[bytes | tempfile._TemporaryFileWrapper]
| None
):
) -> bytes | str | list[bytes | str] | None:
"""
Parameters:
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
@ -159,16 +153,14 @@ class File(Component):
if self.file_count == "single":
if isinstance(x, list):
return self._process_single_file(
x[0], type=self.type, cache_dir=self.GRADIO_CACHE # type: ignore
)
return self._process_single_file(x[0])
else:
return self._process_single_file(x, type=self.type, cache_dir=self.GRADIO_CACHE) # type: ignore
return self._process_single_file(x)
else:
if isinstance(x, list):
return [self._process_single_file(f, type=self.type, cache_dir=self.GRADIO_CACHE) for f in x] # type: ignore
return [self._process_single_file(f) for f in x]
else:
return self._process_single_file(x, type=self.type, cache_dir=self.GRADIO_CACHE) # type: ignore
return [self._process_single_file(x)]
def postprocess(self, y: str | list[str] | None) -> ListFiles | FileData | None:
"""

View File

@ -6,12 +6,13 @@ import tempfile
import warnings
from typing import Any, Callable, List, Literal
from gradio_client import utils as client_utils
from gradio_client.documentation import document, set_documentation_group
from gradio.components.base import Component
from gradio.components.file import File
from gradio.data_classes import FileData, GradioRootModel
from gradio.events import Events
from gradio.utils import NamedString
set_documentation_group("component")
@ -49,7 +50,7 @@ class UploadButton(Component):
render: bool = True,
root_url: str | None = None,
_skip_init_processing: bool = False,
type: Literal["file", "bytes"] = "file",
type: Literal["filepath", "bytes"] = "filepath",
file_count: Literal["single", "multiple", "directory"] = "single",
file_types: list[str] | None = None,
):
@ -72,6 +73,14 @@ class UploadButton(Component):
file_count: if single, allows user to upload one file. If "multiple", user uploads multiple files. If "directory", user uploads all files in selected directory. Return type will be list for each file in case of "multiple" or "directory".
file_types: List of type of files to be uploaded. "file" allows any file to be uploaded, "image" allows only image files to be uploaded, "audio" allows only audio files to be uploaded, "video" allows only video files to be uploaded, "text" allows only text files to be uploaded.
"""
valid_types = [
"filepath",
"binary",
]
if type not in valid_types:
raise ValueError(
f"Invalid value for parameter `type`: {type}. Please choose from one of: {valid_types}"
)
self.type = type
self.file_count = file_count
if file_count == "directory" and file_types is not None:
@ -115,14 +124,31 @@ class UploadButton(Component):
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
]
def _process_single_file(self, f: dict[str, Any]) -> bytes | str:
file_name, data, is_file = (
f["name"],
f["data"],
f.get("is_file", False),
)
if self.type == "filepath":
file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
file.name = file_name
return NamedString(file.name)
elif self.type == "binary":
if is_file:
with open(file_name, "rb") as file_data:
return file_data.read()
return client_utils.decode_base64_to_binary(data)[0]
else:
raise ValueError(
"Unknown type: "
+ str(type)
+ ". Please choose from: 'filepath', 'binary'."
)
def preprocess(
self, x: list[dict[str, Any]] | None
) -> (
bytes
| tempfile._TemporaryFileWrapper
| list[bytes | tempfile._TemporaryFileWrapper]
| None
):
) -> bytes | str | list[bytes | str] | None:
"""
Parameters:
x: List of JSON objects with filename as 'name' property and base64 data as 'data' property
@ -134,25 +160,14 @@ class UploadButton(Component):
if self.file_count == "single":
if isinstance(x, list):
return File._process_single_file(
x[0], type=self.type, cache_dir=self.GRADIO_CACHE # type: ignore
)
return self._process_single_file(x[0])
else:
return File._process_single_file(
x, type=self.type, cache_dir=self.GRADIO_CACHE # type: ignore
)
return self._process_single_file(x)
else:
if isinstance(x, list):
return [
File._process_single_file(
f, type=self.type, cache_dir=self.GRADIO_CACHE # type: ignore
)
for f in x
]
return [self._process_single_file(f) for f in x]
else:
return File._process_single_file(
x, type=self.type, cache_dir=self.GRADIO_CACHE # type: ignore
)
return [self._process_single_file(x)]
def postprocess(self, y):
return super().postprocess(y)

View File

@ -421,7 +421,7 @@ class Files(components.File):
value: str | list[str] | Callable | None = None,
*,
file_count: Literal["multiple"] = "multiple",
type: Literal["file", "binary"] = "file",
type: Literal["filepath", "binary"] = "filepath",
label: str | None = None,
show_label: bool = True,
interactive: bool | None = None,

View File

@ -959,3 +959,16 @@ def recover_kwargs(config: dict, additional_keys_to_ignore: list[str] | None = N
for k, v in config.items()
if k not in not_kwargs and k not in (additional_keys_to_ignore or [])
}
class NamedString(str):
"""
Subclass of str that includes a .name attribute equal to the value of the string itself. This class is used when returning
a value from the `.preprocess()` methods of the File and UploadButton components. Before Gradio 4.0, these methods returned a file
object which was then converted to a string filepath using the `.name` attribute. In Gradio 4.0, these methods now return a str
filepath directly, but to maintain backwards compatibility, we use this class instead of a regular str.
"""
def __init__(self, *args):
super().__init__()
self.name = str(self) if args else ""

View File

@ -882,13 +882,14 @@ class TestFile:
x_file = deepcopy(media_data.BASE64_FILE)
file_input = gr.File()
output = file_input.preprocess(x_file)
assert isinstance(output, tempfile._TemporaryFileWrapper)
assert isinstance(output, str)
x_file["is_file"] = True
input1 = file_input.preprocess(x_file)
input2 = file_input.preprocess(x_file)
assert input1.name == input2.name
assert Path(input1.name).name == "sample_file.pdf"
assert input1 == input1.name # Testing backwards compatibility
assert input1 == input2
assert Path(input1).name == "sample_file.pdf"
file_input = gr.File(label="Upload Your File")
assert file_input.get_config() == {
@ -908,7 +909,7 @@ class TestFile:
"root_url": None,
"selectable": False,
"height": None,
"type": "file",
"type": "filepath",
}
assert file_input.preprocess(None) is None
x_file["is_example"] = True
@ -966,12 +967,13 @@ class TestUploadButton:
x_file = deepcopy(media_data.BASE64_FILE)
upload_input = gr.UploadButton()
input = upload_input.preprocess(x_file)
assert isinstance(input, tempfile._TemporaryFileWrapper)
assert isinstance(input, str)
x_file["is_file"] = True
input1 = upload_input.preprocess(x_file)
input2 = upload_input.preprocess(x_file)
assert input1.name == input2.name
assert input1 == input1.name # Testing backwards compatibility
assert input1 == input2
def test_raises_if_file_types_is_not_list(self):
with pytest.raises(