Fix output directory of files in client & when calling Blocks as function (#4501)

* output dirs

* remove

* remove

* remove

* changelog

* format

* add tests

* docstring

* changelog

* client

* blocks

* fix test
This commit is contained in:
Abubakar Abid 2023-06-14 17:08:10 -05:00 committed by GitHub
parent fbc8e37e02
commit d65512cb3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 70 additions and 14 deletions

View File

@ -64,6 +64,8 @@ demo.launch()
- Fixes bug where `/proxy` route was being incorrectly constructed by the frontend by [@abidlabs](https://github.com/abidlabs) in [PR 4430](https://github.com/gradio-app/gradio/pull/4430).
- Fix z-index of status component by [@hannahblair](https://github.com/hannahblair) in [PR 4429](https://github.com/gradio-app/gradio/pull/4429)
- Fix video rendering in Safari by [@aliabid94](https://github.com/aliabid94) in [PR 4433](https://github.com/gradio-app/gradio/pull/4433).
- The output directory for files downloaded when calling Blocks as a function is now set to a temporary directory by default (instead of the working directory in some cases) by [@abidlabs](https://github.com/abidlabs) in [PR 4501](https://github.com/gradio-app/gradio/pull/4501)
## Other Changes:

View File

@ -16,6 +16,24 @@ No changes to highlight.
No changes to highlight.
# 0.2.7
## New Features:
- The output directory for files downloaded via the Client can now be set by the `output_dir` parameter in `Client` by [@abidlabs](https://github.com/abidlabs) in [PR 4501](https://github.com/gradio-app/gradio/pull/4501)
## Bug Fixes:
- The output directory for files downloaded via the Client are now set to a temporary directory by default (instead of the working directory in some cases) by [@abidlabs](https://github.com/abidlabs) in [PR 4501](https://github.com/gradio-app/gradio/pull/4501)
## Breaking Changes:
No changes to highlight.
## Full Changelog:
No changes to highlight.
# 0.2.6
## New Features:

View File

@ -3,7 +3,9 @@ from __future__ import annotations
import concurrent.futures
import json
import os
import re
import tempfile
import threading
import time
import urllib.parse
@ -39,6 +41,11 @@ from gradio_client.utils import (
set_documentation_group("py-client")
DEFAULT_TEMP_DIR = os.environ.get("GRADIO_TEMP_DIR") or str(
Path(tempfile.gettempdir()) / "gradio"
)
@document("predict", "submit", "view_api", "duplicate")
class Client:
"""
@ -63,6 +70,7 @@ class Client:
hf_token: str | None = None,
max_workers: int = 40,
serialize: bool = True,
output_dir: str | Path | None = DEFAULT_TEMP_DIR,
verbose: bool = True,
):
"""
@ -71,6 +79,7 @@ class Client:
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
serialize: Whether the client should serialize the inputs and deserialize the outputs of the remote API. If set to False, the client will pass the inputs and outputs as-is, without serializing/deserializing them. E.g. you if you set this to False, you'd submit an image in base64 format instead of a filepath, and you'd get back an image in base64 format from the remote API instead of a filepath.
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
verbose: Whether the client should print statements to the console.
"""
self.verbose = verbose
@ -82,6 +91,7 @@ class Client:
library_version=utils.__version__,
)
self.space_id = None
self.output_dir = output_dir
if src.startswith("http://") or src.startswith("https://"):
_src = src if src.endswith("/") else src + "/"
@ -795,7 +805,12 @@ class Endpoint:
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
outputs = tuple(
[
s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
s.deserialize(
d,
save_dir=self.client.output_dir,
hf_token=self.client.hf_token,
root_url=self.root_url,
)
for s, d in zip(self.deserializers, data)
]
)

View File

@ -383,8 +383,8 @@ def decode_base64_to_file(
dir: str | Path | None = None,
prefix: str | None = None,
):
if dir is not None:
os.makedirs(dir, exist_ok=True)
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
directory.mkdir(exist_ok=True, parents=True)
data, extension = decode_base64_to_binary(encoding)
if file_path is not None and prefix is None:
filename = Path(file_path).name
@ -397,13 +397,15 @@ def decode_base64_to_file(
prefix = strip_invalid_filename_characters(prefix)
if extension is None:
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir)
file_obj = tempfile.NamedTemporaryFile(
delete=False, prefix=prefix, dir=directory
)
else:
file_obj = tempfile.NamedTemporaryFile(
delete=False,
prefix=prefix,
suffix="." + extension,
dir=dir,
dir=directory,
)
file_obj.write(data)
file_obj.flush()

View File

@ -1 +1 @@
0.2.6
0.2.7

View File

@ -7,6 +7,7 @@ import uuid
from concurrent.futures import CancelledError, TimeoutError
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
import gradio as gr
@ -17,6 +18,7 @@ from gradio.networking import Server
from huggingface_hub.utils import RepositoryNotFoundError
from gradio_client import Client
from gradio_client.client import DEFAULT_TEMP_DIR
from gradio_client.serializing import Serializable
from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate
@ -172,7 +174,17 @@ class TestClientPredictions:
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4",
fn_index=0,
)
assert pathlib.Path(job.result()).exists()
assert Path(job.result()).exists()
assert Path(DEFAULT_TEMP_DIR).resolve() in Path(job.result()).resolve().parents
temp_dir = tempfile.mkdtemp()
client = Client(src="gradio/video_component", output_dir=temp_dir)
job = client.submit(
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4",
fn_index=0,
)
assert Path(job.result()).exists()
assert Path(temp_dir).resolve() in Path(job.result()).resolve().parents
def test_progress_updates(self, progress_demo):
with connect(progress_demo) as client:

View File

@ -1134,7 +1134,10 @@ class Blocks(BlockContext):
block, components.IOComponent
), f"{block.__class__} Component with id {output_id} not a valid output component."
deserialized = block.deserialize(
outputs[o], root_url=block.root_url, hf_token=Context.hf_token
outputs[o],
save_dir=block.DEFAULT_TEMP_DIR,
root_url=block.root_url,
hf_token=Context.hf_token,
)
predictions.append(deserialized)

View File

@ -3,7 +3,7 @@ aiohttp
altair>=4.2.0
fastapi
ffmpy
gradio_client>=0.2.6
gradio_client>=0.2.7
httpx
huggingface_hub>=0.14.0
Jinja2

View File

@ -6,12 +6,14 @@ import os
import pathlib
import random
import sys
import tempfile
import time
import unittest.mock as mock
import uuid
import warnings
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from string import capwords
from unittest.mock import patch
@ -148,7 +150,7 @@ class TestBlocksMethods:
inp.submit(fn=update, inputs=inp, outputs=out, api_name="greet")
gr.Image().style(height=54, width=240)
gr.Image(height=54, width=240)
config1 = demo1.get_config_file()
demo2 = gr.Blocks.from_config(config1, [update], "https://fake.hf.space")
@ -482,16 +484,18 @@ class TestTempFile:
return random.sample(images, n_images)
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
gallery = gr.Gallery()
demo = gr.Interface(
create_images,
inputs=[gr.Slider(value=3, minimum=1, maximum=3, step=1)],
outputs=[gr.Gallery().style(grid=2, preview=True)],
inputs="slider",
outputs=gallery,
)
with connect(demo) as client:
path = client.predict(3)
_ = client.predict(3)
_ = client.predict(3)
# only three files created
# only three files created and in temp directory
assert len([f for f in tmp_path.glob("**/*") if f.is_file()]) == 3
assert Path(tempfile.gettempdir()).resolve() in Path(path).resolve().parents
def test_no_empty_image_files(self, tmp_path, connect, monkeypatch):
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")