Fix bug where file examples can be corrupted if they have multiple extensions (#4440)

* Fix bug

* Add to changelog

* Add test

* Remove breakpoint

* fix test

* increment version

* update client version req

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-06-08 05:05:01 +09:00 committed by GitHub
parent e364f81ffc
commit 4a58ccee39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 63 additions and 33 deletions

View File

@ -6,6 +6,25 @@ No changes to highlight.
## Bug Fixes:
No changes to highlight.
## Breaking Changes:
No changes to highlight.
## Full Changelog:
No changes to highlight.
# 0.2.6
## New Features:
No changes to highlight.
## Bug Fixes:
- Fixed bug where file deserialization didn't preserve all file extensions by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4440](https://github.com/gradio-app/gradio/pull/4440)
- Fixed bug where mounted apps could not be called via the client by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4435](https://github.com/gradio-app/gradio/pull/4435)
## Breaking Changes:

View File

@ -286,11 +286,9 @@ class FileSerializable(Serializable):
root_url + "file=" + filepath,
hf_token=hf_token,
dir=save_dir,
).name
)
else:
file_name = utils.create_tmp_copy_of_file(
filepath, dir=save_dir
).name
file_name = utils.create_tmp_copy_of_file(filepath, dir=save_dir)
else:
data = x.get("data")
assert data is not None, f"The 'data' field is missing in {x}"

View File

@ -6,6 +6,7 @@ import json
import mimetypes
import os
import pkgutil
import secrets
import shutil
import tempfile
from concurrent.futures import CancelledError
@ -273,40 +274,27 @@ async def get_pred_from_ws(
# Download the content at `url_path` into a local file and return a handle/path to it.
# NOTE(review): this span shows two variants back to back (two return annotations, a
# NamedTemporaryFile-based variant and a directory-based variant, two `open` targets,
# two `return` statements). It reads like the removed and added lines of a diff rendered
# without +/- markers -- confirm which lines are current before relying on it.
def download_tmp_copy_of_file(
url_path: str, hf_token: str | None = None, dir: str | None = None
) -> tempfile._TemporaryFileWrapper:
) -> str:
# Make sure the caller-supplied target directory exists before writing into it.
if dir is not None:
os.makedirs(dir, exist_ok=True)
# The token is sent as a Bearer Authorization header only when one is supplied.
headers = {"Authorization": "Bearer " + hf_token} if hf_token else {}
# Variant A: temp file named <stem><random characters><last suffix>. `Path.suffix` is
# only the final suffix, so "foo.bar.txt" becomes "foo.bar" + random + ".txt".
prefix = Path(url_path).stem
suffix = Path(url_path).suffix
file_obj = tempfile.NamedTemporaryFile(
delete=False,
prefix=prefix,
suffix=suffix,
dir=dir,
)
# Variant B: a randomly named directory (40 hex characters) under `dir`, or under the
# system temp directory when `dir` is falsy; the file keeps the last path component of
# `url_path` as its name.
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
directory.mkdir(exist_ok=True, parents=True)
file_path = directory / Path(url_path).name
# Stream the response and copy its raw body straight into the destination file.
with requests.get(url_path, headers=headers, stream=True) as r, open(
file_obj.name, "wb"
file_path, "wb"
) as f:
shutil.copyfileobj(r.raw, f)
# Variant A returns the temp-file object; variant B returns the resolved path as a string.
return file_obj
return str(file_path.resolve())
# Copy `file_path` into a NamedTemporaryFile created with delete=False and return the
# file object; the serializer code earlier in this page reads the path from its `.name`.
# NOTE(review): a function with this same name is defined again directly below with a
# `-> str` return type; this one looks like the removed side of a diff -- confirm before
# treating it as live code.
def create_tmp_copy_of_file(
file_path: str, dir: str | None = None
) -> tempfile._TemporaryFileWrapper:
# Make sure the caller-supplied target directory exists before creating the temp file.
if dir is not None:
os.makedirs(dir, exist_ok=True)
# `stem` is the name minus its last suffix and `suffix` is only the last suffix, so for
# "foo.bar.txt" the temp name is "foo.bar" + random characters + ".txt".
prefix = Path(file_path).stem
suffix = Path(file_path).suffix
file_obj = tempfile.NamedTemporaryFile(
delete=False,
prefix=prefix,
suffix=suffix,
dir=dir,
)
# Overwrite the freshly created temp file with the source file's contents.
shutil.copy2(file_path, file_obj.name)
return file_obj
def create_tmp_copy_of_file(file_path: str, dir: str | None = None) -> str:
    """Copy ``file_path`` into a new, randomly named directory and return the copy's path.

    The copy keeps the source file's full name (every suffix included); uniqueness
    comes from the enclosing directory, whose name is a random 40-character hex token.

    Args:
        file_path: Path of the file to copy.
        dir: Parent directory for the random directory. When falsy, the system
            temporary directory is used instead.

    Returns:
        The resolved path of the copied file, as a string.
    """
    parent = dir or tempfile.gettempdir()
    unique_dir = Path(parent) / secrets.token_hex(20)
    unique_dir.mkdir(parents=True, exist_ok=True)
    copy_path = unique_dir / Path(file_path).name
    shutil.copy2(file_path, copy_path)
    return str(copy_path.resolve())
def get_mimetype(filename: str) -> str | None:

View File

@ -1 +1 @@
0.2.5
0.2.6

View File

@ -62,7 +62,7 @@ def test_download_private_file():
url_path = "https://gradio-tests-not-actually-private-space.hf.space/file=lion.jpg"
hf_token = "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" # Intentionally revealing this key for testing purposes
file = utils.download_tmp_copy_of_file(url_path=url_path, hf_token=hf_token)
assert file.name.endswith(".jpg")
assert Path(file).name.endswith(".jpg")
@pytest.mark.parametrize(

View File

@ -3,7 +3,7 @@ aiohttp
altair>=4.2.0
fastapi
ffmpy
gradio_client>=0.2.4
gradio_client>=0.2.6
httpx
huggingface_hub>=0.14.0
Jinja2

View File

@ -1,5 +1,6 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
@ -372,3 +373,27 @@ async def test_multiple_file_flagging(tmp_path):
assert len(prediction[0]) == 2
assert all(isinstance(d, dict) for d in prediction[0])
# Regression test for example caching with file names that contain more than one suffix:
# two different files both named "foo.bar.txt" are used as examples, and each cached
# prediction must read back its own contents and report "foo.bar.txt" as "orig_name".
@pytest.mark.asyncio
async def test_examples_keep_all_suffixes(tmp_path):
# Point gradio.helpers.CACHED_FOLDER at the per-test tmp_path while the patch is active.
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
file_1 = tmp_path / "foo.bar.txt"
file_1.write_text("file 1")
# `file_2` is first a directory named "file_2", then rebound to the file inside it, so
# both examples share a basename but live in different directories.
file_2 = tmp_path / "file_2"
file_2.mkdir(parents=True)
file_2 = file_2 / "foo.bar.txt"
file_2.write_text("file 2")
# The interface function just returns the `.name` of the file it receives.
io = gr.Interface(
fn=lambda x: x.name,
inputs=gr.File(),
outputs=[gr.File()],
examples=[[str(file_1)], [str(file_2)]],
cache_examples=True,
)
# presumably the argument is the example's index in `examples` -- the expected contents
# below ("file 1", then "file 2") follow that order.
prediction = await io.examples_handler.load_from_cache(0)
assert Path(prediction[0]["name"]).read_text() == "file 1"
assert prediction[0]["orig_name"] == "foo.bar.txt"
prediction = await io.examples_handler.load_from_cache(1)
assert Path(prediction[0]["name"]).read_text() == "file 2"
assert prediction[0]["orig_name"] == "foo.bar.txt"