Set orig_name in python client file uploads (#8371)

* Add code

* add changeset

* URL case

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Hannah <hannahblair@users.noreply.github.com>
This commit is contained in:
Freddy Boulton 2024-05-27 11:26:52 -04:00 committed by GitHub
parent 24ab22d261
commit a373b0edd3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 2 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---
fix:Set orig_name in python client file uploads

View File

@ -1338,6 +1338,7 @@ class Endpoint:
file_path = f
else:
file_path = f["path"]
orig_name = Path(file_path)
if not utils.is_http_url_like(file_path):
component_id = self.dependency["inputs"][data_index]
component_config = next(
@ -1356,7 +1357,7 @@ class Endpoint:
f"set in {component_config.get('label', '') + ''} component."
)
with open(file_path, "rb") as f:
files = [("files", (Path(file_path).name, f))]
files = [("files", (orig_name.name, f))]
r = httpx.post(
self.client.upload_url,
headers=self.client.headers,
@ -1367,7 +1368,14 @@ class Endpoint:
r.raise_for_status()
result = r.json()
file_path = result[0]
return {"path": file_path}
# Only return orig_name if has a suffix because components
# use the suffix of the original name to determine format to save it to in cache.
return {
"path": file_path,
"orig_name": utils.strip_invalid_filename_characters(orig_name.name)
if orig_name.suffix
else None,
}
def _download_file(self, x: dict) -> str:
url_path = self.root_url + "file=" + x["path"]

View File

@ -324,6 +324,21 @@ class TestClientPredictions:
with open(output) as f:
assert f.read() == "Hello file!"
def test_upload_preserves_orig_name(self):
demo = gr.Interface(lambda x: x, "image", "text")
with connect(demo) as client:
test_file = str(Path(__file__).parent / "files" / "cheetah1.jpg")
output = client.endpoints[0]._upload_file({"path": test_file}, data_index=0)
assert output["orig_name"] == "cheetah1.jpg"
output = client.endpoints[0]._upload_file(
{
"path": "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
},
data_index=0,
)
assert output["orig_name"] == "bus.png"
@pytest.mark.flaky
def test_cancel_from_client_queued(self, cancel_from_client_demo):
with connect(cancel_from_client_demo) as client: