mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Fix bug where file examples can be corrupted if has multiple extensions (#4440)
* Fix bug * Add to changelog * Add test * Remove breakpoint * fix test * increment version * update client version req --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
e364f81ffc
commit
4a58ccee39
@ -6,6 +6,25 @@ No changes to highlight.
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
No changes to highlight.
|
||||
|
||||
## Breaking Changes:
|
||||
|
||||
No changes to highlight.
|
||||
|
||||
## Full Changelog:
|
||||
|
||||
No changes to highlight.
|
||||
|
||||
# 0.2.6
|
||||
|
||||
## New Features:
|
||||
|
||||
No changes to highlight.
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
- Fixed bug file deserialization didn't preserve all file extensions by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4440](https://github.com/gradio-app/gradio/pull/4440)
|
||||
- Fixed bug where mounted apps could not be called via the client by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4435](https://github.com/gradio-app/gradio/pull/4435)
|
||||
|
||||
## Breaking Changes:
|
||||
|
@ -286,11 +286,9 @@ class FileSerializable(Serializable):
|
||||
root_url + "file=" + filepath,
|
||||
hf_token=hf_token,
|
||||
dir=save_dir,
|
||||
).name
|
||||
)
|
||||
else:
|
||||
file_name = utils.create_tmp_copy_of_file(
|
||||
filepath, dir=save_dir
|
||||
).name
|
||||
file_name = utils.create_tmp_copy_of_file(filepath, dir=save_dir)
|
||||
else:
|
||||
data = x.get("data")
|
||||
assert data is not None, f"The 'data' field is missing in {x}"
|
||||
|
@ -6,6 +6,7 @@ import json
|
||||
import mimetypes
|
||||
import os
|
||||
import pkgutil
|
||||
import secrets
|
||||
import shutil
|
||||
import tempfile
|
||||
from concurrent.futures import CancelledError
|
||||
@ -273,40 +274,27 @@ async def get_pred_from_ws(
|
||||
|
||||
def download_tmp_copy_of_file(
|
||||
url_path: str, hf_token: str | None = None, dir: str | None = None
|
||||
) -> tempfile._TemporaryFileWrapper:
|
||||
) -> str:
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
headers = {"Authorization": "Bearer " + hf_token} if hf_token else {}
|
||||
prefix = Path(url_path).stem
|
||||
suffix = Path(url_path).suffix
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
dir=dir,
|
||||
)
|
||||
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
|
||||
directory.mkdir(exist_ok=True, parents=True)
|
||||
file_path = directory / Path(url_path).name
|
||||
|
||||
with requests.get(url_path, headers=headers, stream=True) as r, open(
|
||||
file_obj.name, "wb"
|
||||
file_path, "wb"
|
||||
) as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
return file_obj
|
||||
return str(file_path.resolve())
|
||||
|
||||
|
||||
def create_tmp_copy_of_file(
|
||||
file_path: str, dir: str | None = None
|
||||
) -> tempfile._TemporaryFileWrapper:
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
prefix = Path(file_path).stem
|
||||
suffix = Path(file_path).suffix
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
dir=dir,
|
||||
)
|
||||
shutil.copy2(file_path, file_obj.name)
|
||||
return file_obj
|
||||
def create_tmp_copy_of_file(file_path: str, dir: str | None = None) -> str:
|
||||
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
|
||||
directory.mkdir(exist_ok=True, parents=True)
|
||||
dest = directory / Path(file_path).name
|
||||
shutil.copy2(file_path, dest)
|
||||
return str(dest.resolve())
|
||||
|
||||
|
||||
def get_mimetype(filename: str) -> str | None:
|
||||
|
@ -1 +1 @@
|
||||
0.2.5
|
||||
0.2.6
|
||||
|
@ -62,7 +62,7 @@ def test_download_private_file():
|
||||
url_path = "https://gradio-tests-not-actually-private-space.hf.space/file=lion.jpg"
|
||||
hf_token = "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" # Intentionally revealing this key for testing purposes
|
||||
file = utils.download_tmp_copy_of_file(url_path=url_path, hf_token=hf_token)
|
||||
assert file.name.endswith(".jpg")
|
||||
assert Path(file).name.endswith(".jpg")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -3,7 +3,7 @@ aiohttp
|
||||
altair>=4.2.0
|
||||
fastapi
|
||||
ffmpy
|
||||
gradio_client>=0.2.4
|
||||
gradio_client>=0.2.6
|
||||
httpx
|
||||
huggingface_hub>=0.14.0
|
||||
Jinja2
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -372,3 +373,27 @@ async def test_multiple_file_flagging(tmp_path):
|
||||
|
||||
assert len(prediction[0]) == 2
|
||||
assert all(isinstance(d, dict) for d in prediction[0])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_examples_keep_all_suffixes(tmp_path):
|
||||
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
|
||||
file_1 = tmp_path / "foo.bar.txt"
|
||||
file_1.write_text("file 1")
|
||||
file_2 = tmp_path / "file_2"
|
||||
file_2.mkdir(parents=True)
|
||||
file_2 = file_2 / "foo.bar.txt"
|
||||
file_2.write_text("file 2")
|
||||
io = gr.Interface(
|
||||
fn=lambda x: x.name,
|
||||
inputs=gr.File(),
|
||||
outputs=[gr.File()],
|
||||
examples=[[str(file_1)], [str(file_2)]],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
assert Path(prediction[0]["name"]).read_text() == "file 1"
|
||||
assert prediction[0]["orig_name"] == "foo.bar.txt"
|
||||
prediction = await io.examples_handler.load_from_cache(1)
|
||||
assert Path(prediction[0]["name"]).read_text() == "file 2"
|
||||
assert prediction[0]["orig_name"] == "foo.bar.txt"
|
||||
|
Loading…
x
Reference in New Issue
Block a user