mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Remove deprecated parameters from Python Client (#8444)
* deprecate * add changeset * file -> handle_file * more updates * format * add changeset * fix connect * fix tests * fix more tests * remove outdated test * serialize * address review comments * fix dir --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
3dbce6b308
commit
2cd02ff3b7
7
.changeset/neat-trains-repair.md
Normal file
7
.changeset/neat-trains-repair.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/app": minor
|
||||
"gradio": minor
|
||||
"gradio_client": minor
|
||||
---
|
||||
|
||||
feat:Remove deprecated parameters from Python Client
|
@ -1,8 +1,9 @@
|
||||
from gradio_client.client import Client
|
||||
from gradio_client.utils import __version__, file
|
||||
from gradio_client.utils import __version__, file, handle_file
|
||||
|
||||
__all__ = [
|
||||
"Client",
|
||||
"file",
|
||||
"handle_file",
|
||||
"__version__",
|
||||
]
|
||||
|
@ -77,39 +77,26 @@ class Client:
|
||||
src: str,
|
||||
hf_token: str | None = None,
|
||||
max_workers: int = 40,
|
||||
serialize: bool | None = None, # TODO: remove in 1.0
|
||||
output_dir: str
|
||||
| Path = DEFAULT_TEMP_DIR, # Maybe this can be combined with `download_files` in 1.0
|
||||
verbose: bool = True,
|
||||
auth: tuple[str, str] | None = None,
|
||||
*,
|
||||
headers: dict[str, str] | None = None,
|
||||
upload_files: bool = True, # TODO: remove and hardcode to False in 1.0
|
||||
download_files: bool = True, # TODO: consider setting to False in 1.0
|
||||
_skip_components: bool = True, # internal parameter to skip values certain components (e.g. State) that do not need to be displayed to users.
|
||||
download_files: str | Path | Literal[False] = DEFAULT_TEMP_DIR,
|
||||
ssl_verify: bool = True,
|
||||
_skip_components: bool = True, # internal parameter to skip values certain components (e.g. State) that do not need to be displayed to users.
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
src: Either the name of the Hugging Face Space to load, (e.g. "abidlabs/whisper-large-v2") or the full URL (including "http" or "https") of the hosted Gradio app to load (e.g. "http://mydomain.com/app" or "https://bec81a83-5b5c-471e.gradio.live/").
|
||||
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
|
||||
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
|
||||
serialize: Deprecated. Please use the equivalent `upload_files` parameter instead.
|
||||
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
|
||||
verbose: Whether the client should print statements to the console.
|
||||
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
|
||||
upload_files: Whether the client should treat input string filepath as files and upload them to the remote server. If False, the client will treat input string filepaths as strings always and not modify them, and files should be passed in explicitly using `gradio_client.file("path/to/file/or/url")` instead. This parameter will be deleted and False will become the default in a future version.
|
||||
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead.
|
||||
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. This parameter will override the default headers if they have the same keys.
|
||||
download_files: Directory where the client should download output files on the local machine from the remote API. By default, uses the value of the GRADIO_TEMP_DIR environment variable which, if not set by the user, is a temporary directory on your machine. If False, the client does not download files and returns a FileData dataclass object with the filepath on the remote machine instead.
|
||||
ssl_verify: If False, skips certificate validation which allows the client to connect to Gradio apps that are using self-signed certificates.
|
||||
"""
|
||||
self.verbose = verbose
|
||||
self.hf_token = hf_token
|
||||
if serialize is not None:
|
||||
warnings.warn(
|
||||
"The `serialize` parameter is deprecated and will be removed. Please use the equivalent `upload_files` parameter instead."
|
||||
)
|
||||
upload_files = serialize
|
||||
self.upload_files = upload_files
|
||||
self.download_files = download_files
|
||||
self._skip_components = _skip_components
|
||||
self.headers = build_hf_headers(
|
||||
@ -122,9 +109,14 @@ class Client:
|
||||
self.ssl_verify = ssl_verify
|
||||
self.space_id = None
|
||||
self.cookies: dict[str, str] = {}
|
||||
self.output_dir = (
|
||||
str(output_dir) if isinstance(output_dir, Path) else output_dir
|
||||
)
|
||||
if isinstance(self.download_files, (str, Path)):
|
||||
if not os.path.exists(self.download_files):
|
||||
os.makedirs(self.download_files, exist_ok=True)
|
||||
if not os.path.isdir(self.download_files):
|
||||
raise ValueError(f"Path: {self.download_files} is not a directory.")
|
||||
self.output_dir = str(self.download_files)
|
||||
else:
|
||||
self.output_dir = DEFAULT_TEMP_DIR
|
||||
|
||||
if src.startswith("http://") or src.startswith("https://"):
|
||||
_src = src if src.endswith("/") else src + "/"
|
||||
@ -554,10 +546,7 @@ class Client:
|
||||
return job
|
||||
|
||||
def _get_api_info(self):
|
||||
if self.upload_files:
|
||||
api_info_url = urllib.parse.urljoin(self.src, utils.API_INFO_URL)
|
||||
else:
|
||||
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)
|
||||
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)
|
||||
if self.app_version > version.Version("3.36.1"):
|
||||
r = httpx.get(
|
||||
api_info_url,
|
||||
@ -574,7 +563,7 @@ class Client:
|
||||
utils.SPACE_FETCHER_URL,
|
||||
json={
|
||||
"config": json.dumps(self.config),
|
||||
"serialize": self.upload_files,
|
||||
"serialize": False,
|
||||
},
|
||||
)
|
||||
if fetch.is_success:
|
||||
@ -737,7 +726,7 @@ class Client:
|
||||
default_value = info.get("parameter_default")
|
||||
default_value = utils.traverse(
|
||||
default_value,
|
||||
lambda x: f"file(\"{x['url']}\")",
|
||||
lambda x: f"handle_file(\"{x['url']}\")",
|
||||
utils.is_file_obj_with_meta,
|
||||
)
|
||||
default_info = (
|
||||
@ -1273,20 +1262,11 @@ class Endpoint:
|
||||
def process_input_files(self, *data) -> tuple:
|
||||
data_ = []
|
||||
for i, d in enumerate(data):
|
||||
if self.client.upload_files and self.input_component_types[i].value_is_file:
|
||||
d = utils.traverse(
|
||||
d,
|
||||
partial(self._upload_file, data_index=i),
|
||||
lambda f: utils.is_filepath(f)
|
||||
or utils.is_file_obj_with_meta(f)
|
||||
or utils.is_http_url_like(f),
|
||||
)
|
||||
elif not self.client.upload_files:
|
||||
d = utils.traverse(
|
||||
d,
|
||||
partial(self._upload_file, data_index=i),
|
||||
utils.is_file_obj_with_meta,
|
||||
)
|
||||
d = utils.traverse(
|
||||
d,
|
||||
partial(self._upload_file, data_index=i),
|
||||
utils.is_file_obj_with_meta,
|
||||
)
|
||||
data_.append(d)
|
||||
return tuple(data_)
|
||||
|
||||
@ -1329,15 +1309,7 @@ class Endpoint:
|
||||
return data
|
||||
|
||||
def _upload_file(self, f: str | dict, data_index: int) -> dict[str, str]:
|
||||
if isinstance(f, str):
|
||||
warnings.warn(
|
||||
f'The Client is treating: "{f}" as a file path. In future versions, this behavior will not happen automatically. '
|
||||
f'\n\nInstead, please provide file path or URLs like this: gradio_client.file("{f}"). '
|
||||
"\n\nNote: to stop treating strings as filepaths unless file() is used, set upload_files=False in Client()."
|
||||
)
|
||||
file_path = f
|
||||
else:
|
||||
file_path = f["path"]
|
||||
file_path = f["path"]
|
||||
orig_name = Path(file_path)
|
||||
if not utils.is_http_url_like(file_path):
|
||||
component_id = self.dependency["inputs"][data_index]
|
||||
|
@ -61,8 +61,7 @@ class EndpointV3Compatibility:
|
||||
if not self.is_valid:
|
||||
raise utils.InvalidAPIEndpointError()
|
||||
data = self.insert_state(*data)
|
||||
if self.client.upload_files:
|
||||
data = self.serialize(*data)
|
||||
data = self.serialize(*data)
|
||||
predictions = _predict(*data)
|
||||
predictions = self.process_predictions(*predictions)
|
||||
# Append final output only if not already present
|
||||
|
@ -1080,7 +1080,7 @@ SKIP_COMPONENTS = {
|
||||
}
|
||||
|
||||
|
||||
def file(filepath_or_url: str | Path):
|
||||
def handle_file(filepath_or_url: str | Path):
|
||||
s = str(filepath_or_url)
|
||||
data = {"path": s, "meta": {"_type": "gradio.FileData"}}
|
||||
if is_http_url_like(s):
|
||||
@ -1093,6 +1093,13 @@ def file(filepath_or_url: str | Path):
|
||||
)
|
||||
|
||||
|
||||
def file(filepath_or_url: str | Path):
|
||||
warnings.warn(
|
||||
"file() is deprecated and will be removed in a future version. Use handle_file() instead."
|
||||
)
|
||||
return handle_file(filepath_or_url)
|
||||
|
||||
|
||||
def construct_args(
|
||||
parameters_info: list[ParameterInfo] | None, args: tuple, kwargs: dict
|
||||
) -> list:
|
||||
|
@ -20,7 +20,7 @@ from gradio.http_server import Server
|
||||
from huggingface_hub import HfFolder
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from gradio_client import Client, file
|
||||
from gradio_client import Client, handle_file
|
||||
from gradio_client.client import DEFAULT_TEMP_DIR
|
||||
from gradio_client.exceptions import AppError, AuthenticationError
|
||||
from gradio_client.utils import (
|
||||
@ -37,13 +37,12 @@ HF_TOKEN = os.getenv("HF_TOKEN") or HfFolder.get_token()
|
||||
@contextmanager
|
||||
def connect(
|
||||
demo: gr.Blocks,
|
||||
serialize: bool = True,
|
||||
output_dir: str = DEFAULT_TEMP_DIR,
|
||||
download_files: str = DEFAULT_TEMP_DIR,
|
||||
**kwargs,
|
||||
):
|
||||
_, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
|
||||
try:
|
||||
yield Client(local_url, serialize=serialize, output_dir=output_dir)
|
||||
yield Client(local_url, download_files=download_files)
|
||||
finally:
|
||||
# A more verbose version of .close()
|
||||
# because we should set a timeout
|
||||
@ -92,11 +91,11 @@ class TestClientPredictions:
|
||||
with connect(max_file_size_demo, max_file_size="15kb") as client:
|
||||
with pytest.raises(ValueError, match="exceeds the maximum file size"):
|
||||
client.predict(
|
||||
file(Path(__file__).parent / "files" / "cheetah1.jpg"),
|
||||
handle_file(Path(__file__).parent / "files" / "cheetah1.jpg"),
|
||||
api_name="/upload_1b",
|
||||
)
|
||||
client.predict(
|
||||
file(Path(__file__).parent / "files" / "alphabet.txt"),
|
||||
handle_file(Path(__file__).parent / "files" / "alphabet.txt"),
|
||||
api_name="/upload_1b",
|
||||
)
|
||||
|
||||
@ -254,17 +253,11 @@ class TestClientPredictions:
|
||||
job = client.submit("foo", "add", 9, fn_index=0)
|
||||
job.result()
|
||||
|
||||
def test_raises_exception_no_queue(self, sentiment_classification_demo):
|
||||
with pytest.raises(Exception):
|
||||
with connect(sentiment_classification_demo) as client:
|
||||
job = client.submit([5], api_name="/sleep")
|
||||
job.result()
|
||||
|
||||
def test_job_output_video(self, video_component):
|
||||
with connect(video_component) as client:
|
||||
job = client.submit(
|
||||
{
|
||||
"video": file(
|
||||
"video": handle_file(
|
||||
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
|
||||
)
|
||||
},
|
||||
@ -277,10 +270,10 @@ class TestClientPredictions:
|
||||
)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with connect(video_component, output_dir=temp_dir) as client:
|
||||
with connect(video_component, download_files=temp_dir) as client:
|
||||
job = client.submit(
|
||||
{
|
||||
"video": file(
|
||||
"video": handle_file(
|
||||
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
|
||||
)
|
||||
},
|
||||
@ -430,13 +423,15 @@ class TestClientPredictions:
|
||||
def test_stream_audio(self, stream_audio):
|
||||
with connect(stream_audio) as client:
|
||||
job1 = client.submit(
|
||||
file("https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4"),
|
||||
handle_file(
|
||||
"https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4"
|
||||
),
|
||||
api_name="/predict",
|
||||
)
|
||||
assert Path(job1.result()).exists()
|
||||
|
||||
job2 = client.submit(
|
||||
file(
|
||||
handle_file(
|
||||
"https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav"
|
||||
),
|
||||
api_name="/predict",
|
||||
@ -552,13 +547,6 @@ class TestClientPredictions:
|
||||
client.submit(1, "foo", f.name, fn_index=0).result()
|
||||
serialize.assert_called_once_with(1, "foo", f.name)
|
||||
|
||||
def test_state_without_serialize(self, stateful_chatbot):
|
||||
with connect(stateful_chatbot, serialize=False) as client:
|
||||
initial_history = [["", None]]
|
||||
message = "Hello"
|
||||
ret = client.predict(message, initial_history, api_name="/submit")
|
||||
assert ret == ("", [["", None], ["Hello", "I love you"]])
|
||||
|
||||
def test_does_not_upload_dir(self, stateful_chatbot):
|
||||
with connect(stateful_chatbot) as client:
|
||||
initial_history = [["", None]]
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio.components.base import Component
|
||||
@ -102,7 +102,7 @@ class SimpleImage(Component):
|
||||
return FileData(path=str(value), orig_name=Path(value).name)
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return file(
|
||||
return handle_file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
)
|
||||
|
||||
|
@ -7,7 +7,7 @@ from typing import Any, List
|
||||
import gradio_client.utils as client_utils
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio import processing_utils, utils
|
||||
@ -217,7 +217,7 @@ class AnnotatedImage(Component):
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return {
|
||||
"image": file(
|
||||
"image": handle_file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
),
|
||||
"annotations": [],
|
||||
|
@ -8,7 +8,7 @@ from typing import Any, Callable, Literal
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document
|
||||
|
||||
@ -186,7 +186,7 @@ class Audio(
|
||||
)
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return file(
|
||||
return handle_file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"
|
||||
)
|
||||
|
||||
|
@ -6,7 +6,7 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal
|
||||
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio.components.base import Component
|
||||
@ -104,7 +104,7 @@ class DownloadButton(Component):
|
||||
return FileData(path=str(value))
|
||||
|
||||
def example_payload(self) -> dict:
|
||||
return file(
|
||||
return handle_file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
)
|
||||
|
||||
|
@ -8,7 +8,7 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import gradio_client.utils as client_utils
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio import processing_utils
|
||||
@ -203,12 +203,12 @@ class File(Component):
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
if self.file_count == "single":
|
||||
return file(
|
||||
return handle_file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
)
|
||||
else:
|
||||
return [
|
||||
file(
|
||||
handle_file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
)
|
||||
]
|
||||
|
@ -9,7 +9,7 @@ from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client.documentation import document
|
||||
from gradio_client.utils import is_http_url_like
|
||||
|
||||
@ -231,7 +231,7 @@ class Gallery(Component):
|
||||
def example_payload(self) -> Any:
|
||||
return [
|
||||
{
|
||||
"image": file(
|
||||
"image": handle_file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
)
|
||||
},
|
||||
|
@ -8,7 +8,7 @@ from typing import Any, Literal, cast
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client.documentation import document
|
||||
from PIL import ImageOps
|
||||
|
||||
@ -216,7 +216,7 @@ class Image(StreamingInput, Component):
|
||||
)
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return file(
|
||||
return handle_file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
)
|
||||
|
||||
|
@ -10,7 +10,7 @@ from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client.documentation import document
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@ -398,7 +398,7 @@ class ImageEditor(Component):
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return {
|
||||
"background": file(
|
||||
"background": handle_file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
),
|
||||
"layers": [],
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio.components.base import Component
|
||||
@ -121,7 +121,7 @@ class Model3D(Component):
|
||||
return Path(input_data).name if input_data else ""
|
||||
|
||||
def example_payload(self):
|
||||
return file(
|
||||
return handle_file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/demo/model3D/files/Fox.gltf"
|
||||
)
|
||||
|
||||
|
@ -8,7 +8,7 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import gradio_client.utils as client_utils
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio import processing_utils
|
||||
@ -118,12 +118,12 @@ class UploadButton(Component):
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
if self.file_count == "single":
|
||||
return file(
|
||||
return handle_file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
)
|
||||
else:
|
||||
return [
|
||||
file(
|
||||
handle_file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
)
|
||||
]
|
||||
|
@ -7,7 +7,7 @@ import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
from gradio_client import file
|
||||
from gradio_client import handle_file
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document
|
||||
|
||||
@ -357,7 +357,7 @@ class Video(Component):
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return {
|
||||
"video": file(
|
||||
"video": handle_file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
|
||||
),
|
||||
}
|
||||
|
@ -440,7 +440,6 @@ def from_spaces_blocks(space: str, hf_token: str | None) -> Blocks:
|
||||
client = Client(
|
||||
space,
|
||||
hf_token=hf_token,
|
||||
upload_files=False,
|
||||
download_files=False,
|
||||
_skip_components=False,
|
||||
)
|
||||
|
@ -49,7 +49,7 @@
|
||||
<div bind:this={python_code}>
|
||||
<pre><span class="highlight">from</span> gradio_client <span
|
||||
class="highlight">import</span
|
||||
> Client{#if has_file_path}, file{/if}
|
||||
> Client{#if has_file_path}, handle_file{/if}
|
||||
|
||||
client = Client(<span class="token string">"{root}"</span>)
|
||||
result = client.<span class="highlight">predict</span
|
||||
|
@ -77,7 +77,7 @@ function replace_file_data_with_file_function(obj: any): any {
|
||||
"meta" in obj &&
|
||||
obj.meta?._type === "gradio.FileData"
|
||||
) {
|
||||
return `file('${obj.url}')`;
|
||||
return `handle_file('${obj.url}')`;
|
||||
}
|
||||
}
|
||||
if (Array.isArray(obj)) {
|
||||
@ -101,15 +101,15 @@ function stringify_except_file_function(obj: any): string {
|
||||
}
|
||||
if (
|
||||
typeof value === "string" &&
|
||||
value.startsWith("file(") &&
|
||||
value.startsWith("handle_file(") &&
|
||||
value.endsWith(")")
|
||||
) {
|
||||
return `UNQUOTED${value}`; // Flag the special strings
|
||||
}
|
||||
return value;
|
||||
});
|
||||
const regex = /"UNQUOTEDfile\(([^)]*)\)"/g;
|
||||
jsonString = jsonString.replace(regex, (match, p1) => `file(${p1})`);
|
||||
const regex = /"UNQUOTEDhandle_file\(([^)]*)\)"/g;
|
||||
jsonString = jsonString.replace(regex, (match, p1) => `handle_file(${p1})`);
|
||||
const regexNone = /"UNQUOTEDNone"/g;
|
||||
return jsonString.replace(regexNone, "None");
|
||||
}
|
||||
|
@ -44,10 +44,10 @@ def io_components():
|
||||
@pytest.fixture
|
||||
def connect():
|
||||
@contextmanager
|
||||
def _connect(demo: gr.Blocks, serialize=True, **kwargs):
|
||||
def _connect(demo: gr.Blocks, **kwargs):
|
||||
_, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
|
||||
try:
|
||||
client = Client(local_url, serialize=serialize)
|
||||
client = Client(local_url)
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
@ -384,7 +384,7 @@ class TestTempFile:
|
||||
|
||||
def test_no_empty_image_files(self, gradio_temp_dir, connect):
|
||||
file_dir = pathlib.Path(__file__).parent / "test_files"
|
||||
image = str(file_dir / "bus.png")
|
||||
image = grc.handle_file(str(file_dir / "bus.png"))
|
||||
|
||||
demo = gr.Interface(
|
||||
lambda x: x,
|
||||
@ -400,7 +400,7 @@ class TestTempFile:
|
||||
|
||||
@pytest.mark.parametrize("component", [gr.UploadButton, gr.File])
|
||||
def test_file_component_uploads(self, component, connect, gradio_temp_dir):
|
||||
code_file = str(pathlib.Path(__file__))
|
||||
code_file = grc.handle_file(str(pathlib.Path(__file__)))
|
||||
demo = gr.Interface(lambda x: x.name, component(), gr.File())
|
||||
with connect(demo) as client:
|
||||
_ = client.predict(code_file, api_name="/predict")
|
||||
@ -413,7 +413,7 @@ class TestTempFile:
|
||||
|
||||
def test_no_empty_video_files(self, gradio_temp_dir, connect):
|
||||
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
video = str(file_dir / "video_sample.mp4")
|
||||
video = grc.handle_file(str(file_dir / "video_sample.mp4"))
|
||||
demo = gr.Interface(lambda x: x, gr.Video(), gr.Video())
|
||||
with connect(demo) as client:
|
||||
_ = client.predict({"video": video}, api_name="/predict")
|
||||
@ -423,7 +423,7 @@ class TestTempFile:
|
||||
|
||||
def test_no_empty_audio_files(self, gradio_temp_dir, connect):
|
||||
file_dir = pathlib.Path(pathlib.Path(__file__).parent, "test_files")
|
||||
audio = str(file_dir / "audio_sample.wav")
|
||||
audio = grc.handle_file(str(file_dir / "audio_sample.wav"))
|
||||
|
||||
def reverse_audio(audio):
|
||||
sr, data = audio
|
||||
@ -1716,7 +1716,7 @@ def test_static_files_single_app(connect, gradio_temp_dir):
|
||||
assert len(list(gradio_temp_dir.glob("**/*.*"))) == 0
|
||||
|
||||
with connect(demo) as client:
|
||||
client.predict("test/test_files/bus.png")
|
||||
client.predict(grc.handle_file("test/test_files/bus.png"))
|
||||
|
||||
# Input/Output got saved to cache
|
||||
assert len(list(gradio_temp_dir.glob("**/*.*"))) == 2
|
||||
|
Loading…
Reference in New Issue
Block a user