mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Fix output directory of files in client & when calling Blocks as function (#4501)
* output dirs * remove * remove * remove * changelog * format * add tests * docstring * changelog * client * blocks * fix test
This commit is contained in:
parent
fbc8e37e02
commit
d65512cb3a
@ -64,6 +64,8 @@ demo.launch()
|
||||
- Fixes bug where `/proxy` route was being incorrectly constructed by the frontend by [@abidlabs](https://github.com/abidlabs) in [PR 4430](https://github.com/gradio-app/gradio/pull/4430).
|
||||
- Fix z-index of status component by [@hannahblair](https://github.com/hannahblair) in [PR 4429](https://github.com/gradio-app/gradio/pull/4429)
|
||||
- Fix video rendering in Safari by [@aliabid94](https://github.com/aliabid94) in [PR 4433](https://github.com/gradio-app/gradio/pull/4433).
|
||||
- The output directory for files downloaded when calling Blocks as a function is now set to a temporary directory by default (instead of the working directory in some cases) by [@abidlabs](https://github.com/abidlabs) in [PR 4501](https://github.com/gradio-app/gradio/pull/4501)
|
||||
|
||||
|
||||
## Other Changes:
|
||||
|
||||
|
@ -16,6 +16,24 @@ No changes to highlight.
|
||||
|
||||
No changes to highlight.
|
||||
|
||||
# 0.2.7
|
||||
|
||||
## New Features:
|
||||
|
||||
- The output directory for files downloaded via the Client can now be set by the `output_dir` parameter in `Client` by [@abidlabs](https://github.com/abidlabs) in [PR 4501](https://github.com/gradio-app/gradio/pull/4501)
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
- The output directory for files downloaded via the Client are now set to a temporary directory by default (instead of the working directory in some cases) by [@abidlabs](https://github.com/abidlabs) in [PR 4501](https://github.com/gradio-app/gradio/pull/4501)
|
||||
|
||||
## Breaking Changes:
|
||||
|
||||
No changes to highlight.
|
||||
|
||||
## Full Changelog:
|
||||
|
||||
No changes to highlight.
|
||||
|
||||
# 0.2.6
|
||||
|
||||
## New Features:
|
||||
|
@ -3,7 +3,9 @@ from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import urllib.parse
|
||||
@ -39,6 +41,11 @@ from gradio_client.utils import (
|
||||
set_documentation_group("py-client")
|
||||
|
||||
|
||||
DEFAULT_TEMP_DIR = os.environ.get("GRADIO_TEMP_DIR") or str(
|
||||
Path(tempfile.gettempdir()) / "gradio"
|
||||
)
|
||||
|
||||
|
||||
@document("predict", "submit", "view_api", "duplicate")
|
||||
class Client:
|
||||
"""
|
||||
@ -63,6 +70,7 @@ class Client:
|
||||
hf_token: str | None = None,
|
||||
max_workers: int = 40,
|
||||
serialize: bool = True,
|
||||
output_dir: str | Path | None = DEFAULT_TEMP_DIR,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""
|
||||
@ -71,6 +79,7 @@ class Client:
|
||||
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
|
||||
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
|
||||
serialize: Whether the client should serialize the inputs and deserialize the outputs of the remote API. If set to False, the client will pass the inputs and outputs as-is, without serializing/deserializing them. E.g. you if you set this to False, you'd submit an image in base64 format instead of a filepath, and you'd get back an image in base64 format from the remote API instead of a filepath.
|
||||
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
|
||||
verbose: Whether the client should print statements to the console.
|
||||
"""
|
||||
self.verbose = verbose
|
||||
@ -82,6 +91,7 @@ class Client:
|
||||
library_version=utils.__version__,
|
||||
)
|
||||
self.space_id = None
|
||||
self.output_dir = output_dir
|
||||
|
||||
if src.startswith("http://") or src.startswith("https://"):
|
||||
_src = src if src.endswith("/") else src + "/"
|
||||
@ -795,7 +805,12 @@ class Endpoint:
|
||||
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
|
||||
outputs = tuple(
|
||||
[
|
||||
s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
|
||||
s.deserialize(
|
||||
d,
|
||||
save_dir=self.client.output_dir,
|
||||
hf_token=self.client.hf_token,
|
||||
root_url=self.root_url,
|
||||
)
|
||||
for s, d in zip(self.deserializers, data)
|
||||
]
|
||||
)
|
||||
|
@ -383,8 +383,8 @@ def decode_base64_to_file(
|
||||
dir: str | Path | None = None,
|
||||
prefix: str | None = None,
|
||||
):
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
|
||||
directory.mkdir(exist_ok=True, parents=True)
|
||||
data, extension = decode_base64_to_binary(encoding)
|
||||
if file_path is not None and prefix is None:
|
||||
filename = Path(file_path).name
|
||||
@ -397,13 +397,15 @@ def decode_base64_to_file(
|
||||
prefix = strip_invalid_filename_characters(prefix)
|
||||
|
||||
if extension is None:
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir)
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False, prefix=prefix, dir=directory
|
||||
)
|
||||
else:
|
||||
file_obj = tempfile.NamedTemporaryFile(
|
||||
delete=False,
|
||||
prefix=prefix,
|
||||
suffix="." + extension,
|
||||
dir=dir,
|
||||
dir=directory,
|
||||
)
|
||||
file_obj.write(data)
|
||||
file_obj.flush()
|
||||
|
@ -1 +1 @@
|
||||
0.2.6
|
||||
0.2.7
|
||||
|
@ -7,6 +7,7 @@ import uuid
|
||||
from concurrent.futures import CancelledError, TimeoutError
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import gradio as gr
|
||||
@ -17,6 +18,7 @@ from gradio.networking import Server
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from gradio_client import Client
|
||||
from gradio_client.client import DEFAULT_TEMP_DIR
|
||||
from gradio_client.serializing import Serializable
|
||||
from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate
|
||||
|
||||
@ -172,7 +174,17 @@ class TestClientPredictions:
|
||||
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4",
|
||||
fn_index=0,
|
||||
)
|
||||
assert pathlib.Path(job.result()).exists()
|
||||
assert Path(job.result()).exists()
|
||||
assert Path(DEFAULT_TEMP_DIR).resolve() in Path(job.result()).resolve().parents
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
client = Client(src="gradio/video_component", output_dir=temp_dir)
|
||||
job = client.submit(
|
||||
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4",
|
||||
fn_index=0,
|
||||
)
|
||||
assert Path(job.result()).exists()
|
||||
assert Path(temp_dir).resolve() in Path(job.result()).resolve().parents
|
||||
|
||||
def test_progress_updates(self, progress_demo):
|
||||
with connect(progress_demo) as client:
|
||||
|
@ -1134,7 +1134,10 @@ class Blocks(BlockContext):
|
||||
block, components.IOComponent
|
||||
), f"{block.__class__} Component with id {output_id} not a valid output component."
|
||||
deserialized = block.deserialize(
|
||||
outputs[o], root_url=block.root_url, hf_token=Context.hf_token
|
||||
outputs[o],
|
||||
save_dir=block.DEFAULT_TEMP_DIR,
|
||||
root_url=block.root_url,
|
||||
hf_token=Context.hf_token,
|
||||
)
|
||||
predictions.append(deserialized)
|
||||
|
||||
|
@ -3,7 +3,7 @@ aiohttp
|
||||
altair>=4.2.0
|
||||
fastapi
|
||||
ffmpy
|
||||
gradio_client>=0.2.6
|
||||
gradio_client>=0.2.7
|
||||
httpx
|
||||
huggingface_hub>=0.14.0
|
||||
Jinja2
|
||||
|
@ -6,12 +6,14 @@ import os
|
||||
import pathlib
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest.mock as mock
|
||||
import uuid
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from string import capwords
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -148,7 +150,7 @@ class TestBlocksMethods:
|
||||
|
||||
inp.submit(fn=update, inputs=inp, outputs=out, api_name="greet")
|
||||
|
||||
gr.Image().style(height=54, width=240)
|
||||
gr.Image(height=54, width=240)
|
||||
|
||||
config1 = demo1.get_config_file()
|
||||
demo2 = gr.Blocks.from_config(config1, [update], "https://fake.hf.space")
|
||||
@ -482,16 +484,18 @@ class TestTempFile:
|
||||
return random.sample(images, n_images)
|
||||
|
||||
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
|
||||
gallery = gr.Gallery()
|
||||
demo = gr.Interface(
|
||||
create_images,
|
||||
inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)],
|
||||
outputs=[gr.Gallery().style(grid=2, preview=True)],
|
||||
inputs="slider",
|
||||
outputs=gallery,
|
||||
)
|
||||
with connect(demo) as client:
|
||||
path = client.predict(3)
|
||||
_ = client.predict(3)
|
||||
_ = client.predict(3)
|
||||
# only three files created
|
||||
# only three files created and in temp directory
|
||||
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
|
||||
assert Path(tempfile.gettempdir()).resolve() in Path(path).resolve().parents
|
||||
|
||||
def test_no_empty_image_files(self, tmp_path, connect, monkeypatch):
|
||||
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
|
Loading…
Reference in New Issue
Block a user