Files should now be supplied as file(...) in the Client, and some fixes to gr.load() as well (#7575)

* more fixes for gr.load()

* client

* add changeset

* format

* docstring

* add assertion

* warning

* add changeset

* add changeset

* changes

* fixes

* more fixes

* fix files

* add test for dir

* add changeset

* Delete .changeset/giant-bears-check.md

* add changeset

* changes

* add changeset

* print

* format

* add changeset

* docs

* add to tests

* format

* add changeset

* move compatibility code out

* fixed

* changes

* changes

* factory method

* add changeset

* changes

* changes

* sse v2.1

* file()

* changes

* typing

* changes

* cleanup

* changes

* changes

* changes

* fixes

* changes

* fix

* add changeset

* changes

* more changes

* abc

* test

* add payloads

* lint

* test

* lint

* changes

* payload

* fixes

* fix tests

* fix

* clean

* fix frontend

* lint

* add changeset

* cleanup

* format

* get examples to show up in loaded spaces

* add filedata prop to frontend

* add skip component parameter

* address feedback

* with meta

* load

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-03-08 12:29:02 -08:00 committed by GitHub
parent 7c66a29dea
commit d0688b3c25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 800 additions and 605 deletions

View File

@ -0,0 +1,8 @@
---
"@gradio/app": minor
"@gradio/client": minor
"gradio": minor
"gradio_client": minor
---
fix:Files should now be supplied as `file(...)` in the Client, and some fixes to `gr.load()` as well

View File

@ -755,7 +755,11 @@ export function api_factory(
}
}
};
} else if (protocol == "sse_v1" || protocol == "sse_v2") {
} else if (
protocol == "sse_v1" ||
protocol == "sse_v2" ||
protocol == "sse_v2.1"
) {
// latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter.
fire_event({
type: "status",
@ -849,7 +853,10 @@ export function api_factory(
endpoint: _endpoint,
fn_index
});
if (data && protocol === "sse_v2") {
if (
data &&
(protocol === "sse_v2" || protocol === "sse_v2.1")
) {
apply_diff_stream(event_id!, data);
}
}

View File

@ -20,7 +20,7 @@ export interface Config {
show_api: boolean;
stylesheets: string[];
path: string;
protocol?: "sse_v2" | "sse_v1" | "sse" | "ws";
protocol?: "sse_v2.1" | "sse_v2" | "sse_v1" | "sse" | "ws";
}
export interface Payload {

View File

@ -69,6 +69,7 @@ export class FileData {
is_stream?: boolean;
mime_type?: string;
alt_text?: string;
readonly meta = { _type: "gradio.FileData" };
constructor({
path,

View File

@ -1,7 +1,8 @@
from gradio_client.client import Client
from gradio_client.utils import __version__
from gradio_client.utils import __version__, file
__all__ = [
"Client",
"file",
"__version__",
]

View File

@ -21,7 +21,6 @@ from typing import Any, Callable, Literal
import httpx
import huggingface_hub
import websockets
from huggingface_hub import CommitOperationAdd, SpaceHardware, SpaceStage
from huggingface_hub.utils import (
RepositoryNotFoundError,
@ -30,9 +29,10 @@ from huggingface_hub.utils import (
)
from packaging import version
from gradio_client import serializing, utils
from gradio_client import utils
from gradio_client.compatibility import EndpointV3Compatibility
from gradio_client.documentation import document
from gradio_client.exceptions import AuthenticationError, SerializationSetupError
from gradio_client.exceptions import AuthenticationError
from gradio_client.utils import (
Communicator,
JobStatus,
@ -71,14 +71,16 @@ class Client:
src: str,
hf_token: str | None = None,
max_workers: int = 40,
serialize: bool | None = None,
output_dir: str | Path = DEFAULT_TEMP_DIR,
serialize: bool | None = None, # TODO: remove in 1.0
output_dir: str
| Path = DEFAULT_TEMP_DIR, # Maybe this can be combined with `download_files` in 1.0
verbose: bool = True,
auth: tuple[str, str] | None = None,
*,
headers: dict[str, str] | None = None,
upload_files: bool = True,
download_files: bool = True,
upload_files: bool = True, # TODO: remove and hardcode to False in 1.0
download_files: bool = True, # TODO: consider setting to False in 1.0
_skip_components: bool = True, # internal parameter to skip values for certain components (e.g. State) that do not need to be displayed to users.
):
"""
Parameters:
@ -89,8 +91,8 @@ class Client:
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
verbose: Whether the client should print statements to the console.
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
upload_files: Whether the client should treat input string filepath as files and upload them to the remote server. If False, the client will treat input string filepaths as strings always and not modify them.
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will a FileData dataclass object with the filepath on the remote machine instead.
upload_files: Whether the client should treat input string filepath as files and upload them to the remote server. If False, the client will treat input string filepaths as strings always and not modify them, and files should be passed in explicitly using `gradio_client.file("path/to/file/or/url")` instead. This parameter will be deleted and False will become the default in a future version.
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead.
"""
self.verbose = verbose
self.hf_token = hf_token
@ -101,6 +103,7 @@ class Client:
upload_files = serialize
self.upload_files = upload_files
self.download_files = download_files
self._skip_components = _skip_components
self.headers = build_hf_headers(
token=hf_token,
library_name="gradio_client",
@ -143,7 +146,9 @@ class Client:
self._login(auth)
self.config = self._get_config()
self.protocol: str = self.config.get("protocol", "ws")
self.protocol: Literal[
"ws", "sse", "sse_v1", "sse_v2", "sse_v2.1"
] = self.config.get("protocol", "ws")
self.api_url = urllib.parse.urljoin(self.src, utils.API_URL)
self.sse_url = urllib.parse.urljoin(
self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL
@ -445,6 +450,7 @@ class Client:
"sse",
"sse_v1",
"sse_v2",
"sse_v2.1",
):
helper = self.new_helper(inferred_fn_index)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
@ -972,8 +978,7 @@ class Endpoint:
# This is still hacky as it does not tell us which part of the payload is a file.
# If a component has a complex payload, part of which is a file, this will simply
# return True, which means that all parts of the payload will be uploaded as files
# if they are valid file paths. The better approach would be to traverse the
# component's api_info and figure out exactly which part of the payload is a file.
# if they are valid file paths. We will deprecate this in 1.0.
if "api_info" not in component:
return False
return utils.value_is_file(component["api_info"])
@ -990,11 +995,12 @@ class Endpoint:
def _inner(*data):
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
data = self.insert_state(*data)
if self.client.upload_files:
data = self.serialize(*data)
data = self.insert_empty_state(*data)
data = self.process_input_files(*data)
predictions = _predict(*data)
predictions = self.process_predictions(*predictions)
# Append final output only if not already present
# for consistency between generators and not generators
if helper:
@ -1022,7 +1028,7 @@ class Endpoint:
result = utils.synchronize_async(
self._sse_fn_v0, data, hash_data, helper
)
elif self.protocol in ("sse_v1", "sse_v2"):
elif self.protocol in ("sse_v1", "sse_v2", "sse_v2.1"):
event_id = utils.synchronize_async(
self.client.send_data, data, hash_data
)
@ -1059,68 +1065,52 @@ class Endpoint:
return _predict
def _predict_resolve(self, *data) -> Any:
"""Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
outputs = self.make_predict()(*data)
if len(self.dependency["outputs"]) == 1:
return outputs[0]
return outputs
def _upload(
self, file_paths: list[str | list[str]]
) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
if not file_paths:
return []
# Put all the filepaths in one file
# but then keep track of which index in the
# original list they came from so we can recreate
# the original structure
files = []
indices = []
for i, fs in enumerate(file_paths):
if not isinstance(fs, list):
fs = [fs]
for f in fs:
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
indices.append(i)
r = httpx.post(
self.client.upload_url,
headers=self.client.headers,
cookies=self.client.cookies,
files=files,
)
if r.status_code != 200:
uploaded = file_paths
else:
uploaded = []
result = r.json()
for i, fs in enumerate(file_paths):
if isinstance(fs, list):
output = [o for ix, o in enumerate(result) if indices[ix] == i]
res = [
{
"path": o,
"orig_name": Path(f).name,
}
for f, o in zip(fs, output)
]
else:
o = next(o for ix, o in enumerate(result) if indices[ix] == i)
res = {
"path": o,
"orig_name": Path(fs).name,
}
uploaded.append(res)
return uploaded
def insert_state(self, *data) -> tuple:
def insert_empty_state(self, *data) -> tuple:
data = list(data)
for i, input_component_type in enumerate(self.input_component_types):
if input_component_type.is_state:
data.insert(i, None)
return tuple(data)
def process_input_files(self, *data) -> tuple:
data_ = []
for i, d in enumerate(data):
if self.client.upload_files and self.input_component_types[i].value_is_file:
d = utils.traverse(
d,
self._upload_file,
lambda f: utils.is_filepath(f)
or utils.is_file_obj_with_meta(f)
or utils.is_http_url_like(f),
)
elif not self.client.upload_files:
d = utils.traverse(d, self._upload_file, utils.is_file_obj_with_meta)
data_.append(d)
return tuple(data_)
def process_predictions(self, *predictions):
# If self.client.download_files is True, we assume that the user is using the Client directly (as opposed
# to within gr.load) and therefore, download any files generated by the server and skip values for
# components that the user likely does not want to see (e.g. gr.State, gr.Tab).
if self.client.download_files:
predictions = self.download_files(*predictions)
if self.client._skip_components:
predictions = self.remove_skipped_components(*predictions)
predictions = self.reduce_singleton_output(*predictions)
return predictions
def download_files(self, *data) -> tuple:
data_ = list(data)
if self.client.protocol == "sse_v2.1":
data_ = utils.traverse(
data_, self._download_file, utils.is_file_obj_with_meta
)
else:
data_ = utils.traverse(data_, self._download_file, utils.is_file_obj)
return tuple(data_)
def remove_skipped_components(self, *data) -> tuple:
""""""
data = [d for d, oct in zip(data, self.output_component_types) if not oct.skip]
return tuple(data)
@ -1130,88 +1120,32 @@ class Endpoint:
else:
return data
def _gather_files(self, *data):
file_list = []
def get_file(d):
if utils.is_file_obj(d):
file_list.append(d["path"])
else:
file_list.append(d)
return ReplaceMe(len(file_list) - 1)
def handle_url(s):
return {"path": s, "orig_name": s.split("/")[-1]}
new_data = []
for i, d in enumerate(data):
if self.input_component_types[i].value_is_file:
# Check file dicts and filepaths to upload
# file dict is a corner case but still needed for completeness
# most users should be using filepaths
d = utils.traverse(
d, get_file, lambda s: utils.is_file_obj(s) or utils.is_filepath(s)
)
# Handle URLs here since we don't upload them
d = utils.traverse(d, handle_url, lambda s: utils.is_url(s))
new_data.append(d)
return file_list, new_data
def _add_uploaded_files_to_data(self, data: list[Any], files: list[Any]):
def replace(d: ReplaceMe) -> dict:
return files[d.index]
new_data = []
for d in data:
d = utils.traverse(
d, replace, is_root=lambda node: isinstance(node, ReplaceMe)
def _upload_file(self, f: str | dict):
if isinstance(f, str):
warnings.warn(
f'The Client is treating: "{f}" as a file path. In future versions, this behavior will not happen automatically. '
f'\n\nInstead, please provide file path or URLs like this: gradio_client.file("{f}"). '
"\n\nNote: to stop treating strings as filepaths unless file() is used, set upload_files=False in Client()."
)
new_data.append(d)
return new_data
def serialize(self, *data) -> tuple:
files, new_data = self._gather_files(*data)
uploaded_files = self._upload(files)
data = list(new_data)
data = self._add_uploaded_files_to_data(data, uploaded_files)
o = tuple(data)
return o
def download_file(self, file_data: dict) -> str | None:
if file_data is None:
return None
if isinstance(file_data, str):
file_name = utils.decode_base64_to_file(
file_data, dir=self.client.output_dir
).name
elif isinstance(file_data, dict):
filepath = file_data.get("path")
if not filepath:
raise ValueError(f"The 'path' field is missing in {file_data}")
file_name = utils.download_file(
self.root_url + "file=" + filepath,
save_dir=self.client.output_dir,
file_path = f
else:
file_path = f["path"]
if not utils.is_http_url_like(file_path):
file_path = utils.upload_file(
file_path=file_path,
upload_url=self.client.upload_url,
headers=self.client.headers,
cookies=self.client.cookies,
)
return {"path": file_path}
else:
raise ValueError(
f"A FileSerializable component can only deserialize a string or a dict, not a {type(file_name)}: {file_name}"
)
return file_name
def deserialize(self, *data) -> tuple:
data_ = list(data)
data_: list[Any] = utils.traverse(data_, self.download_file, utils.is_file_obj)
return tuple(data_)
def process_predictions(self, *predictions):
if self.client.download_files:
predictions = self.deserialize(*predictions)
predictions = self.remove_skipped_components(*predictions)
predictions = self.reduce_singleton_output(*predictions)
return predictions
def _download_file(self, x: dict) -> str | None:
return utils.download_file(
self.root_url + "file=" + x["path"],
save_dir=self.client.output_dir,
headers=self.client.headers,
cookies=self.client.cookies,
)
async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
@ -1227,7 +1161,10 @@ class Endpoint:
)
async def _sse_fn_v1_v2(
self, helper: Communicator, event_id: str, protocol: Literal["sse_v1", "sse_v2"]
self,
helper: Communicator,
event_id: str,
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
):
return await utils.get_pred_from_sse_v1_v2(
helper,
@ -1239,313 +1176,6 @@ class Endpoint:
)
class EndpointV3Compatibility:
    """Endpoint class for connecting to v3 endpoints. Backwards compatibility."""

    def __init__(self, client: Client, fn_index: int, dependency: dict, *_args):
        # Parameters:
        #   client: the owning Client (supplies config, URLs, headers, session).
        #   fn_index: index of this endpoint's function in the app config.
        #   dependency: config dict describing the endpoint (inputs, outputs, api_name, ...).
        self.client: Client = client
        self.fn_index = fn_index
        self.dependency = dependency
        api_name = dependency.get("api_name")
        # Normalize string api_names to a leading-slash route; False/None pass through
        # (False means the developer explicitly disabled this API endpoint).
        self.api_name: str | Literal[False] | None = (
            "/" + api_name if isinstance(api_name, str) else api_name
        )
        self.use_ws = self._use_websocket(self.dependency)
        self.protocol = "ws" if self.use_ws else "http"
        self.input_component_types = []
        self.output_component_types = []
        # Ensure the root URL ends with "/" so relative paths join cleanly.
        self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
        self.is_continuous = dependency.get("types", {}).get("continuous", False)
        try:
            # Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid,
            # and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
            self.serializers, self.deserializers = self._setup_serializers()
            self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
        except SerializationSetupError:
            # Unknown component/serializer: mark this endpoint unusable rather than
            # failing construction of the whole client.
            self.is_valid = False
        self.backend_fn = dependency.get("backend_fn")
        self.show_api = True
def __repr__(self):
return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}"
def __str__(self):
return self.__repr__()
    def make_end_to_end_fn(self, helper: Communicator | None = None):
        """Build a callable that serializes inputs, runs the prediction, and post-processes outputs.

        Parameters:
            helper: optional Communicator used to report the final output back to a Job.
        """
        _predict = self.make_predict(helper)

        def _inner(*data):
            if not self.is_valid:
                raise utils.InvalidAPIEndpointError()
            data = self.insert_state(*data)
            if self.client.upload_files:
                data = self.serialize(*data)
            predictions = _predict(*data)
            predictions = self.process_predictions(*predictions)
            # Append final output only if not already present
            # for consistency between generators and not generators
            if helper:
                with helper.lock:
                    if not helper.job.outputs:
                        helper.job.outputs.append(predictions)
            return predictions

        return _inner
    def make_predict(self, helper: Communicator | None = None):
        """Build the raw prediction callable for this endpoint (websocket queue or HTTP POST)."""

        def _predict(*data) -> tuple:
            # JSON payload for the prediction request.
            data = json.dumps(
                {
                    "data": data,
                    "fn_index": self.fn_index,
                    "session_hash": self.client.session_hash,
                }
            )
            # Separate payload identifying the session/function for queue hashing.
            hash_data = json.dumps(
                {
                    "fn_index": self.fn_index,
                    "session_hash": self.client.session_hash,
                }
            )
            if self.use_ws:
                result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
                if "error" in result:
                    raise ValueError(result["error"])
            else:
                response = httpx.post(
                    self.client.api_url, headers=self.client.headers, json=data
                )
                result = json.loads(response.content.decode("utf-8"))
            try:
                output = result["data"]
            except KeyError as ke:
                # No "data" key: surface a more helpful error. Rate-limiting (429) on a
                # public Space gets a dedicated exception suggesting Client.duplicate().
                is_public_space = (
                    self.client.space_id
                    and not huggingface_hub.space_info(self.client.space_id).private
                )
                if "error" in result and "429" in result["error"] and is_public_space:
                    raise utils.TooManyRequestsError(
                        f"Too many requests to the API, please try again later. To avoid being rate-limited, "
                        f"please duplicate the Space using Client.duplicate({self.client.space_id}) "
                        f"and pass in your Hugging Face token."
                    ) from None
                elif "error" in result:
                    raise ValueError(result["error"]) from None
                raise KeyError(
                    f"Could not find 'data' key in response. Response received: {result}"
                ) from ke
            return tuple(output)

        return _predict
def _predict_resolve(self, *data) -> Any:
"""Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
outputs = self.make_predict()(*data)
if len(self.dependency["outputs"]) == 1:
return outputs[0]
return outputs
    def _upload(
        self, file_paths: list[str | list[str]]
    ) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
        """Upload local files to the server and return server-side file dicts.

        Parameters:
            file_paths: one entry per file-typed input; an entry may itself be a
                list of paths (e.g. multi-file upload components).
        Returns:
            The same structure with each path replaced by a v3-format file dict,
            or the original paths unchanged if the upload request fails.
        """
        if not file_paths:
            return []
        # Put all the filepaths in one upload request,
        # but then keep track of which index in the
        # original list they came from so we can recreate
        # the original structure
        files = []
        indices = []
        for i, fs in enumerate(file_paths):
            if not isinstance(fs, list):
                fs = [fs]
            for f in fs:
                files.append(("files", (Path(f).name, open(f, "rb"))))  # noqa: SIM115
                indices.append(i)
        r = httpx.post(self.client.upload_url, headers=self.client.headers, files=files)
        if r.status_code != 200:
            # Best-effort fallback: on failure, hand back the local paths untouched.
            uploaded = file_paths
        else:
            uploaded = []
            result = r.json()
            for i, fs in enumerate(file_paths):
                if isinstance(fs, list):
                    # Collect every uploaded path that originated from this list entry.
                    output = [o for ix, o in enumerate(result) if indices[ix] == i]
                    res = [
                        {
                            "is_file": True,
                            "name": o,
                            "orig_name": Path(f).name,
                            "data": None,
                        }
                        for f, o in zip(fs, output)
                    ]
                else:
                    o = next(o for ix, o in enumerate(result) if indices[ix] == i)
                    res = {
                        "is_file": True,
                        "name": o,
                        "orig_name": Path(fs).name,
                        "data": None,
                    }
                uploaded.append(res)
        return uploaded
def _add_uploaded_files_to_data(
self,
files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]],
data: list[Any],
) -> None:
"""Helper function to modify the input data with the uploaded files."""
file_counter = 0
for i, t in enumerate(self.input_component_types):
if t in ["file", "uploadbutton"]:
data[i] = files[file_counter]
file_counter += 1
def insert_state(self, *data) -> tuple:
data = list(data)
for i, input_component_type in enumerate(self.input_component_types):
if input_component_type == utils.STATE_COMPONENT:
data.insert(i, None)
return tuple(data)
def remove_skipped_components(self, *data) -> tuple:
data = [
d
for d, oct in zip(data, self.output_component_types)
if oct not in utils.SKIP_COMPONENTS
]
return tuple(data)
def reduce_singleton_output(self, *data) -> Any:
if (
len(
[
oct
for oct in self.output_component_types
if oct not in utils.SKIP_COMPONENTS
]
)
== 1
):
return data[0]
else:
return data
def serialize(self, *data) -> tuple:
if len(data) != len(self.serializers):
raise ValueError(
f"Expected {len(self.serializers)} arguments, got {len(data)}"
)
files = [
f
for f, t in zip(data, self.input_component_types)
if t in ["file", "uploadbutton"]
]
uploaded_files = self._upload(files)
data = list(data)
self._add_uploaded_files_to_data(uploaded_files, data)
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
return o
def deserialize(self, *data) -> tuple:
if len(data) != len(self.deserializers):
raise ValueError(
f"Expected {len(self.deserializers)} outputs, got {len(data)}"
)
outputs = tuple(
[
s.deserialize(
d,
save_dir=self.client.output_dir,
hf_token=self.client.hf_token,
root_url=self.root_url,
)
for s, d in zip(self.deserializers, data)
]
)
return outputs
def process_predictions(self, *predictions):
if self.client.download_files:
predictions = self.deserialize(*predictions)
predictions = self.remove_skipped_components(*predictions)
predictions = self.reduce_singleton_output(*predictions)
return predictions
    def _setup_serializers(
        self,
    ) -> tuple[list[serializing.Serializable], list[serializing.Serializable]]:
        """Look up a serializer for every input and a deserializer for every output.

        Also populates self.input_component_types / self.output_component_types
        as a side effect.
        Raises:
            SerializationSetupError: if a component or serializer name is unknown
                (typically the installed gradio_client is too old for the app).
        """
        inputs = self.dependency["inputs"]
        serializers = []
        for i in inputs:
            for component in self.client.config["components"]:
                if component["id"] == i:
                    component_name = component["type"]
                    self.input_component_types.append(component_name)
                    if component.get("serializer"):
                        # Component declares an explicit serializer by name.
                        serializer_name = component["serializer"]
                        if serializer_name not in serializing.SERIALIZER_MAPPING:
                            raise SerializationSetupError(
                                f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
                            )
                        serializer = serializing.SERIALIZER_MAPPING[serializer_name]
                    elif component_name in serializing.COMPONENT_MAPPING:
                        # Fall back to the default serializer for the component type.
                        serializer = serializing.COMPONENT_MAPPING[component_name]
                    else:
                        raise SerializationSetupError(
                            f"Unknown component: {component_name}, you may need to update your gradio_client version."
                        )
                    serializers.append(serializer())  # type: ignore
        outputs = self.dependency["outputs"]
        deserializers = []
        for i in outputs:
            for component in self.client.config["components"]:
                if component["id"] == i:
                    component_name = component["type"]
                    self.output_component_types.append(component_name)
                    if component.get("serializer"):
                        serializer_name = component["serializer"]
                        if serializer_name not in serializing.SERIALIZER_MAPPING:
                            raise SerializationSetupError(
                                f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
                            )
                        deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
                    elif component_name in utils.SKIP_COMPONENTS:
                        # Skipped components (e.g. state) pass their values through unchanged.
                        deserializer = serializing.SimpleSerializable
                    elif component_name in serializing.COMPONENT_MAPPING:
                        deserializer = serializing.COMPONENT_MAPPING[component_name]
                    else:
                        raise SerializationSetupError(
                            f"Unknown component: {component_name}, you may need to update your gradio_client version."
                        )
                    deserializers.append(deserializer())  # type: ignore
        return serializers, deserializers
def _use_websocket(self, dependency: dict) -> bool:
queue_enabled = self.client.config.get("enable_queue", False)
queue_uses_websocket = version.parse(
self.client.config.get("version", "2.0")
) >= version.Version("3.2")
dependency_uses_queue = dependency.get("queue", False) is not False
return queue_enabled and queue_uses_websocket and dependency_uses_queue
    async def _ws_fn(self, data, hash_data, helper: Communicator):
        """Open a websocket to the app and stream the prediction for `data`."""
        async with websockets.connect(  # type: ignore
            self.client.ws_url,
            open_timeout=10,
            extra_headers=self.client.headers,
            # Allow up to 1 GiB messages (large media outputs).
            max_size=1024 * 1024 * 1024,
        ) as websocket:
            return await utils.get_pred_from_ws(websocket, data, hash_data, helper)
@document("result", "outputs", "status")
class Job(Future):
"""

View File

@ -0,0 +1,327 @@
""" This module contains the EndpointV3Compatibility class, which is used to connect to Gradio apps running 3.x.x versions of Gradio."""
from __future__ import annotations
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
import httpx
import huggingface_hub
import websockets
from packaging import version
from gradio_client import serializing, utils
from gradio_client.exceptions import SerializationSetupError
from gradio_client.utils import (
Communicator,
)
if TYPE_CHECKING:
from gradio_client import Client
class EndpointV3Compatibility:
"""Endpoint class for connecting to v3 endpoints. Backwards compatibility."""
def __init__(self, client: Client, fn_index: int, dependency: dict, *_args):
self.client: Client = client
self.fn_index = fn_index
self.dependency = dependency
api_name = dependency.get("api_name")
self.api_name: str | Literal[False] | None = (
"/" + api_name if isinstance(api_name, str) else api_name
)
self.use_ws = self._use_websocket(self.dependency)
self.protocol = "ws" if self.use_ws else "http"
self.input_component_types = []
self.output_component_types = []
self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
self.is_continuous = dependency.get("types", {}).get("continuous", False)
try:
# Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid,
# and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
self.serializers, self.deserializers = self._setup_serializers()
self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
except SerializationSetupError:
self.is_valid = False
self.backend_fn = dependency.get("backend_fn")
self.show_api = True
def __repr__(self):
return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}"
def __str__(self):
return self.__repr__()
def make_end_to_end_fn(self, helper: Communicator | None = None):
_predict = self.make_predict(helper)
def _inner(*data):
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
data = self.insert_state(*data)
if self.client.upload_files:
data = self.serialize(*data)
predictions = _predict(*data)
predictions = self.process_predictions(*predictions)
# Append final output only if not already present
# for consistency between generators and not generators
if helper:
with helper.lock:
if not helper.job.outputs:
helper.job.outputs.append(predictions)
return predictions
return _inner
def make_predict(self, helper: Communicator | None = None):
def _predict(*data) -> tuple:
data = json.dumps(
{
"data": data,
"fn_index": self.fn_index,
"session_hash": self.client.session_hash,
}
)
hash_data = json.dumps(
{
"fn_index": self.fn_index,
"session_hash": self.client.session_hash,
}
)
if self.use_ws:
result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
if "error" in result:
raise ValueError(result["error"])
else:
response = httpx.post(
self.client.api_url, headers=self.client.headers, json=data
)
result = json.loads(response.content.decode("utf-8"))
try:
output = result["data"]
except KeyError as ke:
is_public_space = (
self.client.space_id
and not huggingface_hub.space_info(self.client.space_id).private
)
if "error" in result and "429" in result["error"] and is_public_space:
raise utils.TooManyRequestsError(
f"Too many requests to the API, please try again later. To avoid being rate-limited, "
f"please duplicate the Space using Client.duplicate({self.client.space_id}) "
f"and pass in your Hugging Face token."
) from None
elif "error" in result:
raise ValueError(result["error"]) from None
raise KeyError(
f"Could not find 'data' key in response. Response received: {result}"
) from ke
return tuple(output)
return _predict
def _predict_resolve(self, *data) -> Any:
"""Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
outputs = self.make_predict()(*data)
if len(self.dependency["outputs"]) == 1:
return outputs[0]
return outputs
def _upload(
self, file_paths: list[str | list[str]]
) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
if not file_paths:
return []
# Put all the filepaths in one file
# but then keep track of which index in the
# original list they came from so we can recreate
# the original structure
files = []
indices = []
for i, fs in enumerate(file_paths):
if not isinstance(fs, list):
fs = [fs]
for f in fs:
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
indices.append(i)
r = httpx.post(self.client.upload_url, headers=self.client.headers, files=files)
if r.status_code != 200:
uploaded = file_paths
else:
uploaded = []
result = r.json()
for i, fs in enumerate(file_paths):
if isinstance(fs, list):
output = [o for ix, o in enumerate(result) if indices[ix] == i]
res = [
{
"is_file": True,
"name": o,
"orig_name": Path(f).name,
"data": None,
}
for f, o in zip(fs, output)
]
else:
o = next(o for ix, o in enumerate(result) if indices[ix] == i)
res = {
"is_file": True,
"name": o,
"orig_name": Path(fs).name,
"data": None,
}
uploaded.append(res)
return uploaded
def _add_uploaded_files_to_data(
self,
files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]],
data: list[Any],
) -> None:
"""Helper function to modify the input data with the uploaded files."""
file_counter = 0
for i, t in enumerate(self.input_component_types):
if t in ["file", "uploadbutton"]:
data[i] = files[file_counter]
file_counter += 1
def insert_state(self, *data) -> tuple:
data = list(data)
for i, input_component_type in enumerate(self.input_component_types):
if input_component_type == utils.STATE_COMPONENT:
data.insert(i, None)
return tuple(data)
def remove_skipped_components(self, *data) -> tuple:
data = [
d
for d, oct in zip(data, self.output_component_types)
if oct not in utils.SKIP_COMPONENTS
]
return tuple(data)
def reduce_singleton_output(self, *data) -> Any:
if (
len(
[
oct
for oct in self.output_component_types
if oct not in utils.SKIP_COMPONENTS
]
)
== 1
):
return data[0]
else:
return data
def serialize(self, *data) -> tuple:
if len(data) != len(self.serializers):
raise ValueError(
f"Expected {len(self.serializers)} arguments, got {len(data)}"
)
files = [
f
for f, t in zip(data, self.input_component_types)
if t in ["file", "uploadbutton"]
]
uploaded_files = self._upload(files)
data = list(data)
self._add_uploaded_files_to_data(uploaded_files, data)
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
return o
def deserialize(self, *data) -> tuple:
if len(data) != len(self.deserializers):
raise ValueError(
f"Expected {len(self.deserializers)} outputs, got {len(data)}"
)
outputs = tuple(
[
s.deserialize(
d,
save_dir=self.client.output_dir,
hf_token=self.client.hf_token,
root_url=self.root_url,
)
for s, d in zip(self.deserializers, data)
]
)
return outputs
def process_predictions(self, *predictions):
if self.client.download_files:
predictions = self.deserialize(*predictions)
predictions = self.remove_skipped_components(*predictions)
predictions = self.reduce_singleton_output(*predictions)
return predictions
def _setup_serializers(
    self,
) -> tuple[list[serializing.Serializable], list[serializing.Serializable]]:
    """Build the (serializers, deserializers) pairs for this endpoint.

    Walks the dependency's input and output component ids, records each
    component's type on `self.input_component_types` /
    `self.output_component_types`, and resolves the matching serializer
    class from the component config. Outputs additionally allow
    skip-components (state, layout, ...), which get a `SimpleSerializable`.

    Raises:
        SerializationSetupError: for an unknown serializer name or an
            unknown component type (typically a version mismatch).
    """

    def resolve(component: dict, allow_skip: bool):
        # Resolve the serializer class for one component config entry.
        component_name = component["type"]
        if component.get("serializer"):
            serializer_name = component["serializer"]
            if serializer_name not in serializing.SERIALIZER_MAPPING:
                raise SerializationSetupError(
                    f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
                )
            return serializing.SERIALIZER_MAPPING[serializer_name]
        if allow_skip and component_name in utils.SKIP_COMPONENTS:
            return serializing.SimpleSerializable
        if component_name in serializing.COMPONENT_MAPPING:
            return serializing.COMPONENT_MAPPING[component_name]
        raise SerializationSetupError(
            f"Unknown component: {component_name}, you may need to update your gradio_client version."
        )

    serializers = []
    for input_id in self.dependency["inputs"]:
        for component in self.client.config["components"]:
            if component["id"] == input_id:
                self.input_component_types.append(component["type"])
                serializers.append(resolve(component, allow_skip=False)())  # type: ignore

    deserializers = []
    for output_id in self.dependency["outputs"]:
        for component in self.client.config["components"]:
            if component["id"] == output_id:
                self.output_component_types.append(component["type"])
                deserializers.append(resolve(component, allow_skip=True)())  # type: ignore

    return serializers, deserializers
def _use_websocket(self, dependency: dict) -> bool:
    """Whether predictions for this dependency should go over a websocket.

    True only when the app has its queue enabled, the server is new enough
    (>= 3.2) to serve the queue over websockets, and the dependency itself
    has not opted out of the queue.
    """
    config = self.client.config
    if not config.get("enable_queue", False):
        return False
    server_version = version.parse(config.get("version", "2.0"))
    if server_version < version.Version("3.2"):
        return False
    return dependency.get("queue", False) is not False
async def _ws_fn(self, data, hash_data, helper: Communicator):
    """Open a websocket to the app and await the prediction result.

    The generous max_size (1 GiB) allows large media payloads; the helper
    is used by `get_pred_from_ws` to communicate progress/cancellation.
    """
    connect_kwargs = dict(
        open_timeout=10,
        extra_headers=self.client.headers,
        max_size=1024 * 1024 * 1024,
    )
    async with websockets.connect(  # type: ignore
        self.client.ws_url, **connect_kwargs
    ) as ws:
        return await utils.get_pred_from_ws(ws, data, hash_data, helper)

View File

@ -246,10 +246,12 @@ class Communicator:
########################
def is_http_url_like(possible_url) -> bool:
    """
    Check if the given value is a string that looks like an HTTP(S) URL.

    Non-string values (None, Path objects, numbers, ...) return False
    rather than raising, since callers probe arbitrary payload values.
    """
    # Guard first: startswith would raise on non-strings.
    if not isinstance(possible_url, str):
        return False
    return possible_url.startswith(("http://", "https://"))
@ -390,7 +392,7 @@ async def get_pred_from_sse_v1_v2(
cookies: dict[str, str] | None,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2"],
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
@ -510,7 +512,7 @@ async def stream_sse_v1_v2(
helper: Communicator,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2"],
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
) -> dict[str, Any]:
try:
pending_messages = pending_messages_per_event[event_id]
@ -546,10 +548,10 @@ async def stream_sse_v1_v2(
log=log_message,
)
output = msg.get("output", {}).get("data", [])
if (
msg["msg"] == ServerMessage.process_generating
and protocol == "sse_v2"
):
if msg["msg"] == ServerMessage.process_generating and protocol in [
"sse_v2",
"sse_v2.1",
]:
if pending_responses_for_diffs is None:
pending_responses_for_diffs = list(output)
else:
@ -623,6 +625,20 @@ def apply_diff(obj, diff):
########################
def upload_file(
    file_path: str,
    upload_url: str,
    headers: dict[str, str] | None = None,
    cookies: dict[str, str] | None = None,
):
    """Upload a single local file to `upload_url` via multipart POST.

    Returns the first element of the server's JSON list response
    (presumably the server-side path of the uploaded file — the endpoint
    returns one entry per uploaded file).

    Raises:
        httpx.HTTPStatusError: if the server responds with an error status.
    """
    name = Path(file_path).name
    with open(file_path, "rb") as file_obj:
        # The file handle must stay open for the duration of the POST.
        response = httpx.post(
            upload_url,
            headers=headers,
            cookies=cookies,
            files=[("files", (name, file_obj))],
        )
    response.raise_for_status()
    return response.json()[0]
def download_file(
url_path: str,
save_dir: str,
@ -898,13 +914,18 @@ def get_type(schema: dict):
raise APIInfoParseError(f"Cannot parse type for {schema}")
# String renderings of the FileData payload schema, as produced by
# _json_schema_to_python_type, across Gradio versions. These are matched
# textually against rendered API type strings to detect file-typed values.
# Legacy single-format constants (kept for backwards compatibility):
OLD_FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)"
FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool)"
# All known format versions, oldest first; newer Gradio adds `is_stream`
# and then a `meta` dict to the FileData model.
FILE_DATA_FORMATS = [
    "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)",
    "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool)",
    "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool, meta: Dict())",
]
# The format emitted by the current FileData model (the newest entry).
CURRENT_FILE_DATA_FORMAT = FILE_DATA_FORMATS[-1]
def json_schema_to_python_type(schema: Any) -> str:
    """Render a JSON schema as a human-readable Python type string.

    Occurrences of the current FileData payload schema are collapsed to the
    friendlier alias "filepath" for display in API docs.
    """
    type_ = _json_schema_to_python_type(schema, schema.get("$defs"))
    return type_.replace(CURRENT_FILE_DATA_FORMAT, "filepath")
def _json_schema_to_python_type(schema: Any, defs) -> str:
@ -980,7 +1001,7 @@ def _json_schema_to_python_type(schema: Any, defs) -> str:
raise APIInfoParseError(f"Cannot parse schema {schema}")
def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:
def traverse(json_obj: Any, func: Callable, is_root: Callable[..., bool]) -> Any:
if is_root(json_obj):
return func(json_obj)
elif isinstance(json_obj, dict):
@ -999,27 +1020,57 @@ def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:
def value_is_file(api_info: dict) -> bool:
    """Whether the schema in `api_info` describes a file-typed value.

    Matches against every known rendering of the FileData payload schema so
    that apps built with older Gradio versions are still detected.
    """
    info = _json_schema_to_python_type(api_info, api_info.get("$defs"))
    return any(file_data_format in info for file_data_format in FILE_DATA_FORMATS)
def is_filepath(s) -> bool:
    """
    Check if the given value is a valid str or Path filepath on the local filesystem, e.g. "path/to/file".

    Directories and non-path values return False; only existing regular
    files qualify.
    """
    if not isinstance(s, (str, Path)):
        return False
    candidate = Path(s)
    return candidate.exists() and candidate.is_file()
def is_url(s):
return isinstance(s, str) and is_http_url_like(s)
def is_file_obj(d) -> bool:
    """
    Check if the given value is a valid FileData object dictionary in versions of Gradio<=4.20, e.g.
    {
        "path": "path/to/file",
    }
    """
    if not isinstance(d, dict):
        return False
    # The legacy format only requires a string "path" key.
    return isinstance(d.get("path"), str)
def is_file_obj_with_meta(d) -> bool:
    """
    Check if the given value is a valid FileData object dictionary in newer versions of Gradio
    where the file objects include a specific "meta" key, e.g.
    {
        "path": "path/to/file",
        "meta": {"_type": "gradio.FileData"}
    }
    """
    if not isinstance(d, dict):
        return False
    if not isinstance(d.get("path"), str):
        return False
    # Guard against a non-dict "meta" value so probing arbitrary payloads
    # never raises.
    meta = d.get("meta")
    return isinstance(meta, dict) and meta.get("_type", "") == "gradio.FileData"
def is_file_obj_with_url(d) -> bool:
    """
    Check if the given value is a valid FileData object dictionary in newer versions of Gradio
    where the file objects include a specific "meta" key, and ALSO include a "url" key, e.g.
    {
        "path": "path/to/file",
        "url": "/file=path/to/file",
        "meta": {"_type": "gradio.FileData"}
    }
    """
    if not is_file_obj_with_meta(d):
        return False
    return isinstance(d.get("url"), str)
SKIP_COMPONENTS = {
"state",
"row",
@ -1034,3 +1085,16 @@ SKIP_COMPONENTS = {
"interpretation",
"dataset",
}
def file(filepath_or_url: str | Path):
    """Wrap a local filepath or HTTP(S) URL in the FileData payload format.

    Returns a dict with "path", "orig_name", a FileData "meta" marker, and
    (for URLs) a "url" key, as expected by Gradio apps for file inputs.

    Raises:
        ValueError: if the value is neither an existing local path nor an
            HTTP(S) URL.
    """
    s = str(filepath_or_url)
    payload = {"path": s, "meta": {"_type": "gradio.FileData"}}
    if is_http_url_like(s):
        payload["orig_name"] = s.split("/")[-1]
        payload["url"] = s
        return payload
    if Path(s).exists():
        payload["orig_name"] = Path(s).name
        return payload
    raise ValueError(
        f"File {s} does not exist on local filesystem and is not a valid URL."
    )

View File

@ -19,7 +19,7 @@ from gradio.networking import Server
from huggingface_hub import HfFolder
from huggingface_hub.utils import RepositoryNotFoundError
from gradio_client import Client
from gradio_client import Client, file
from gradio_client.client import DEFAULT_TEMP_DIR
from gradio_client.exceptions import AuthenticationError
from gradio_client.utils import (
@ -255,7 +255,9 @@ class TestClientPredictions:
with connect(video_component) as client:
job = client.submit(
{
"video": "https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
"video": file(
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
)
},
fn_index=0,
)
@ -269,7 +271,9 @@ class TestClientPredictions:
with connect(video_component, output_dir=temp_dir) as client:
job = client.submit(
{
"video": "https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
"video": file(
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
)
},
fn_index=0,
)
@ -370,13 +374,15 @@ class TestClientPredictions:
def test_stream_audio(self, stream_audio):
with connect(stream_audio) as client:
job1 = client.submit(
"https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4",
file("https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4"),
api_name="/predict",
)
assert Path(job1.result()).exists()
job2 = client.submit(
"https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav",
file(
"https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav"
),
api_name="/predict",
)
assert Path(job2.result()).exists()
@ -497,6 +503,13 @@ class TestClientPredictions:
ret = client.predict(message, initial_history, api_name="/submit")
assert ret == ("", [["", None], ["Hello", "I love you"]])
def test_does_not_upload_dir(self, stateful_chatbot):
with connect(stateful_chatbot) as client:
initial_history = [["", None]]
message = "Hello"
ret = client.predict(message, initial_history, api_name="/submit")
assert ret == ("", [["", None], ["Hello", "I love you"]])
def test_can_call_mounted_app_via_api(self):
def greet(name):
return "Hello " + name + "!"
@ -1004,13 +1017,13 @@ class TestAPIInfo:
"description": "",
}
assert isinstance(inputs[0]["example_input"], list)
assert isinstance(inputs[0]["example_input"][0], str)
assert isinstance(inputs[0]["example_input"][0], dict)
assert inputs[1]["python_type"] == {
"type": "filepath",
"description": "",
}
assert isinstance(inputs[1]["example_input"], str)
assert isinstance(inputs[1]["example_input"], dict)
assert outputs[0]["python_type"] == {
"type": "List[filepath]",
@ -1158,43 +1171,6 @@ class TestEndpoints:
"file7",
]
@pytest.mark.flaky
def test_upload_v4(self):
client = Client(
src="gradio-tests/not-actually-private-file-uploadv4-sse",
)
response = MagicMock(status_code=200)
response.json.return_value = [
"file1",
"file2",
"file3",
"file4",
"file5",
"file6",
"file7",
]
with patch("httpx.post", MagicMock(return_value=response)):
with patch("builtins.open", MagicMock()):
with patch.object(pathlib.Path, "name") as mock_name:
mock_name.side_effect = lambda x: x
results = client.endpoints[0]._upload(
["pre1", ["pre2", "pre3", "pre4"], ["pre5", "pre6"], "pre7"]
)
res = []
for re in results:
if isinstance(re, list):
res.append([r["path"] for r in re])
else:
res.append(re["path"])
assert res == [
"file1",
["file2", "file3", "file4"],
["file5", "file6"],
"file7",
]
cpu = huggingface_hub.SpaceHardware.CPU_BASIC

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any
from gradio_client import file
from gradio_client.documentation import document
from gradio.components.base import Component
@ -98,7 +99,9 @@ class SimpleImage(Component):
return FileData(path=str(value), orig_name=Path(value).name)
def example_payload(self) -> Any:
return "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
return file(
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
)
def example_value(self) -> Any:
return "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"

View File

@ -121,6 +121,9 @@ class Block:
self.state_session_capacity = 10000
self.temp_files: set[str] = set()
self.GRADIO_CACHE = get_upload_folder()
# Keeps track of files that should not be deleted when the delete_cache parameter is set
# These files are the default value of the component and files that are used in examples
self.keep_in_cache = set()
if render:
self.render()
@ -253,9 +256,16 @@ class Block:
else:
url_or_file_path = str(utils.abspath(url_or_file_path))
if not utils.is_in_or_equal(url_or_file_path, self.GRADIO_CACHE):
temp_file_path = processing_utils.save_file_to_cache(
url_or_file_path, cache_dir=self.GRADIO_CACHE
)
try:
temp_file_path = processing_utils.save_file_to_cache(
url_or_file_path, cache_dir=self.GRADIO_CACHE
)
except FileNotFoundError:
# This can happen when using gr.load() and the file is on a remote Space
# but the file is not the `value` of the component. For example, if the file
# is the `avatar_image` of the `Chatbot` component. In this case, we skip
# copying the file to the cache and just use the remote file path.
return url_or_file_path
else:
temp_file_path = url_or_file_path
self.temp_files.add(temp_file_path)
@ -753,15 +763,18 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
# targets field
_targets = dependency.pop("targets")
trigger = dependency.pop("trigger", None)
targets = [
getattr(
original_mapping[
target if isinstance(target, int) else target[0]
],
trigger if isinstance(target, int) else target[1],
)
for target in _targets
]
is_then_event = False
# This assumes that you cannot combine multiple .then() events in a single
# gr.on() event, which is true for now. If this changes, we will need to
# update this code.
if not isinstance(_targets[0], int) and _targets[0][1] == "then":
if len(_targets) != 1:
raise ValueError(
"This logic assumes that .then() events are not combined with other events in a single gr.on() event"
)
is_then_event = True
dependency.pop("backend_fn")
dependency.pop("documentation", None)
dependency["inputs"] = [
@ -773,12 +786,30 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
dependency.pop("status_tracker", None)
dependency["preprocess"] = False
dependency["postprocess"] = False
targets = [
EventListenerMethod(
t.__self__ if t.has_trigger else None, t.event_name
if is_then_event:
targets = [EventListenerMethod(None, "then")]
dependency["trigger_after"] = dependency.pop("trigger_after")
dependency["trigger_only_on_success"] = dependency.pop(
"trigger_only_on_success"
)
for t in targets
]
dependency["no_target"] = True
else:
targets = [
getattr(
original_mapping[
target if isinstance(target, int) else target[0]
],
trigger if isinstance(target, int) else target[1],
)
for target in _targets
]
targets = [
EventListenerMethod(
t.__self__ if t.has_trigger else None,
t.event_name, # type: ignore
)
for t in targets
]
dependency = blocks.set_event_trigger(
targets=targets, fn=fn, **dependency
)[0]
@ -866,7 +897,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
show_progress: whether to show progress animation while running.
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None (default), the name of the function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use this event.
js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
no_target: if True, sets "targets" to [], used for Blocks "load" event
no_target: if True, sets "targets" to [], used for the Blocks.load() event and .then() events
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
batch: whether this function takes in a batch of inputs
max_batch_size: the maximum batch size to send to the function
@ -884,7 +915,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
# Support for singular parameter
_targets = [
(
target.block._id if target.block and not no_target else None,
target.block._id if not no_target and target.block else None,
target.event_name,
)
for target in targets
@ -1256,7 +1287,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
serialized_input = client_utils.traverse(
inputs[i],
format_file,
lambda s: client_utils.is_filepath(s) or client_utils.is_url(s),
lambda s: client_utils.is_filepath(s)
or client_utils.is_http_url_like(s),
)
else:
serialized_input = inputs[i]
@ -1735,7 +1767,7 @@ Received outputs:
"is_colab": utils.colab_check(),
"stylesheets": self.stylesheets,
"theme": self.theme.name,
"protocol": "sse_v2",
"protocol": "sse_v2.1",
"body_css": {
"body_background_fill": self.theme._get_computed_value(
"body_background_fill"

View File

@ -7,6 +7,7 @@ from typing import Any, List
import gradio_client.utils as client_utils
import numpy as np
import PIL.Image
from gradio_client import file
from gradio_client.documentation import document
from gradio import processing_utils, utils
@ -205,7 +206,9 @@ class AnnotatedImage(Component):
def example_payload(self) -> Any:
return {
"image": "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",
"image": file(
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
),
"annotations": [],
}

View File

@ -8,6 +8,7 @@ from typing import Any, Callable, Literal
import httpx
import numpy as np
from gradio_client import file
from gradio_client import utils as client_utils
from gradio_client.documentation import document
@ -182,7 +183,9 @@ class Audio(
)
def example_payload(self) -> Any:
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"
return file(
"https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"
)
def example_value(self) -> Any:
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"

View File

@ -190,9 +190,6 @@ class Component(ComponentBase, Block):
self.scale = scale
self.min_width = min_width
self.interactive = interactive
# Keep tracks of files that should not be deleted when the delete_cache parmaeter is set
# These files are the default value of the component and files that are used in examples
self.keep_in_cache = set()
# load_event is set in the Blocks.attach_load_events method
self.load_event: None | dict[str, Any] = None
@ -203,6 +200,7 @@ class Component(ComponentBase, Block):
initial_value,
self, # type: ignore
postprocess=True,
keep_in_cache=True,
)
if is_file_obj(self.value):
self.keep_in_cache.add(self.value["path"])

View File

@ -5,7 +5,6 @@ from __future__ import annotations
from typing import Any, Literal
from gradio_client.documentation import document
from gradio_client.utils import is_file_obj
from gradio import processing_utils
from gradio.components.base import (
@ -95,12 +94,9 @@ class Dataset(Component):
# use the previous name to be backwards-compatible with previously-created
# custom components
example[i] = component.as_example(ex)
example[i] = processing_utils.move_files_to_cache(
example[i],
component,
)
if is_file_obj(example[i]):
self.keep_in_cache.add(example[i]["path"])
example[i] = processing_utils.move_files_to_cache(
example[i], component, keep_in_cache=True
)
self.type = type
self.label = label
if headers is not None:

View File

@ -6,6 +6,7 @@ import tempfile
from pathlib import Path
from typing import Callable, Literal
from gradio_client import file
from gradio_client.documentation import document
from gradio.components.base import Component
@ -99,8 +100,10 @@ class DownloadButton(Component):
return None
return FileData(path=str(value))
def example_payload(self) -> str:
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
def example_payload(self) -> dict:
return file(
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
)
def example_value(self) -> str:
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"

View File

@ -8,6 +8,7 @@ from pathlib import Path
from typing import Any, Callable, Literal
import gradio_client.utils as client_utils
from gradio_client import file
from gradio_client.documentation import document
from gradio import processing_utils
@ -199,10 +200,14 @@ class File(Component):
def example_payload(self) -> Any:
if self.file_count == "single":
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
return file(
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
)
else:
return [
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
file(
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
)
]
def example_value(self) -> Any:

View File

@ -9,6 +9,7 @@ from urllib.parse import urlparse
import numpy as np
import PIL.Image
from gradio_client import file
from gradio_client.documentation import document
from gradio_client.utils import is_http_url_like
@ -223,7 +224,9 @@ class Gallery(Component):
def example_payload(self) -> Any:
return [
{
"image": "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
"image": file(
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
)
},
]

View File

@ -8,6 +8,7 @@ from typing import Any, Literal, cast
import numpy as np
import PIL.Image
from gradio_client import file
from gradio_client.documentation import document
from PIL import ImageOps
@ -209,7 +210,9 @@ class Image(StreamingInput, Component):
)
def example_payload(self) -> Any:
return "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
return file(
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
)
def example_value(self) -> Any:
return "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"

View File

@ -9,6 +9,7 @@ from typing import Any, Iterable, List, Literal, Optional, TypedDict, Union, cas
import numpy as np
import PIL.Image
from gradio_client import file
from gradio_client.documentation import document
from gradio import image_utils, utils
@ -216,9 +217,7 @@ class ImageEditor(Component):
) -> np.ndarray | PIL.Image.Image | str | None:
if file is None:
return None
im = PIL.Image.open(file.path)
if file.orig_name:
p = Path(file.orig_name)
name = p.stem
@ -317,7 +316,9 @@ class ImageEditor(Component):
def example_payload(self) -> Any:
return {
"background": "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",
"background": file(
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
),
"layers": [],
"composite": None,
}

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Callable
from gradio_client import file
from gradio_client.documentation import document
from gradio.components.base import Component
@ -118,7 +119,9 @@ class Model3D(Component):
return Path(input_data).name if input_data else ""
def example_payload(self):
return "https://raw.githubusercontent.com/gradio-app/gradio/main/demo/model3D/files/Fox.gltf"
return file(
"https://raw.githubusercontent.com/gradio-app/gradio/main/demo/model3D/files/Fox.gltf"
)
def example_value(self):
return "https://raw.githubusercontent.com/gradio-app/gradio/main/demo/model3D/files/Fox.gltf"

View File

@ -8,6 +8,7 @@ from pathlib import Path
from typing import Any, Callable, Literal
import gradio_client.utils as client_utils
from gradio_client import file
from gradio_client.documentation import document
from gradio import processing_utils
@ -114,10 +115,14 @@ class UploadButton(Component):
def example_payload(self) -> Any:
if self.file_count == "single":
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
return file(
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
)
else:
return [
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
file(
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
)
]
def example_value(self) -> Any:

View File

@ -7,6 +7,7 @@ import warnings
from pathlib import Path
from typing import Any, Callable, Literal, Optional
from gradio_client import file
from gradio_client import utils as client_utils
from gradio_client.documentation import document
@ -353,7 +354,9 @@ class Video(Component):
def example_payload(self) -> Any:
return {
"video": "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4",
"video": file(
"https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
),
}
def example_value(self) -> Any:

View File

@ -164,6 +164,7 @@ class FileData(GradioModel):
orig_name: Optional[str] = None # original filename
mime_type: Optional[str] = None
is_stream: bool = False
meta: dict = {"_type": "gradio.FileData"}
@property
def is_none(self):

View File

@ -418,7 +418,13 @@ def from_spaces(
def from_spaces_blocks(space: str, hf_token: str | None) -> Blocks:
client = Client(space, hf_token=hf_token, download_files=False)
client = Client(
space,
hf_token=hf_token,
upload_files=False,
download_files=False,
_skip_components=False,
)
# We set deserialize to False to avoid downloading output files from the server.
# Instead, we serve them as URLs using the /proxy/ endpoint directly from the server.

View File

@ -241,6 +241,7 @@ def move_files_to_cache(
block: Block,
postprocess: bool = False,
check_in_upload_folder=False,
keep_in_cache=False,
) -> dict:
"""Move any files in `data` to cache and (optionally), adds URL prefixes (/file=...) needed to access the cached file.
Also handles the case where the file is on an external Gradio app (/proxy=...).
@ -252,6 +253,7 @@ def move_files_to_cache(
block: The component whose data is being processed
postprocess: Whether its running from postprocessing
check_in_upload_folder: If True, instead of moving the file to cache, checks if the file is in already in cache (exception if not).
keep_in_cache: If True, the file will not be deleted from cache when the server is shut down.
"""
def _move_to_cache(d: dict):
@ -278,6 +280,8 @@ def move_files_to_cache(
if temp_file_path is None:
raise ValueError("Did not determine a file path for the resource.")
payload.path = temp_file_path
if keep_in_cache:
block.keep_in_cache.add(payload.path)
url_prefix = "/stream/" if payload.is_stream else "/file="
if block.proxy_url:

View File

@ -11,10 +11,10 @@ Using the `gradio_client` library, we can easily use the Gradio as an API to tra
Here's the entire code to do it:
```python
from gradio_client import Client
from gradio_client import Client, file
client = Client("abidlabs/whisper")
client.predict("audio_sample.wav")
client.predict(file("audio_sample.wav"))
>> "This is a test of the whisper speech recognition model."
```
@ -25,12 +25,12 @@ The Gradio client works with any hosted Gradio app, whether it be an image gener
## Installation
If you already have a recent version of `gradio`, then the `gradio_client` is included as a dependency.
If you already have a recent version of `gradio`, then the `gradio_client` is included as a dependency. But note that this documentation reflects the latest version of the `gradio_client`, so upgrade if you're not sure!
Otherwise, the lightweight `gradio_client` package can be installed from pip (or pip3) and is tested to work with Python versions 3.9 or higher:
The lightweight `gradio_client` package can be installed from pip (or pip3) and is tested to work with Python versions 3.9 or higher:
```bash
$ pip install gradio_client
$ pip install --upgrade gradio_client
```
## Connecting to a running Gradio App
@ -62,12 +62,12 @@ The `gradio_client` includes a class method: `Client.duplicate()` to make this p
```python
import os
from gradio_client import Client
from gradio_client import Client, file
HF_TOKEN = os.environ.get("HF_TOKEN")
client = Client.duplicate("abidlabs/whisper", hf_token=HF_TOKEN)
client.predict("audio_sample.wav")
client.predict(file("audio_sample.wav"))
>> "This is a test of the whisper speech recognition model."
```
@ -130,13 +130,13 @@ client.predict(4, "add", 5)
>> 9.0
```
For certain inputs, such as images, you should pass in the filepath or URL to the file. Likewise, for the corresponding output types, you will get a filepath or URL returned.
For when working with files (e.g. image files), you should pass in the filepath or URL to the file enclosed within `gradio_client.file()`.
```python
from gradio_client import Client
from gradio_client import Client, file
client = Client("abidlabs/whisper")
client.predict("https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3")
client.predict(file("https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3"))
>> "My thought I have nobody by a beauty and will as you poured. Mr. Rochester is serve in that so don't find simpus, and devoted abode, to at might in a r—"
```
@ -202,8 +202,8 @@ The `Job` class also has a `.cancel()` instance method that cancels jobs that ha
```py
client = Client("abidlabs/whisper")
job1 = client.submit("audio_sample1.wav")
job2 = client.submit("audio_sample2.wav")
job1 = client.submit(file("audio_sample1.wav"))
job2 = client.submit(file("audio_sample2.wav"))
job1.cancel() # will return False, assuming the job has started
job2.cancel() # will return True, indicating that the job has been canceled
```

View File

@ -1,10 +1,19 @@
<script lang="ts">
import type { ComponentMeta, Dependency } from "../types";
import CopyButton from "./CopyButton.svelte";
import { represent_value } from "./utils";
import { represent_value, is_potentially_nested_file_data } from "./utils";
import { Block } from "@gradio/atoms";
import EndpointDetail from "./EndpointDetail.svelte";
interface EndpointParameter {
label: string;
type: string;
python_type: { type: string };
component: string;
example_input: string;
serializer: string;
}
export let dependency: Dependency;
export let dependency_index: number;
export let root: string;
@ -18,19 +27,12 @@
let python_code: HTMLElement;
let js_code: HTMLElement;
let has_file_path = endpoint_parameters.some((param: EndpointParameter) =>
is_potentially_nested_file_data(param.example_input)
);
let blob_components = ["Audio", "File", "Image", "Video"];
let blob_examples: any[] = endpoint_parameters.filter(
(param: {
label: string;
type: string;
python_type: {
type: string;
description: string;
};
component: string;
example_input: string;
serializer: string;
}) => blob_components.includes(param.component)
(param: EndpointParameter) => blob_components.includes(param.component)
);
</script>
@ -47,7 +49,7 @@
<CopyButton code={python_code?.innerText} />
</div>
<div bind:this={python_code}>
<pre>from gradio_client import Client
<pre>from gradio_client import Client{#if has_file_path}, file{/if}
client = Client(<span class="token string">"{root}"</span>)
result = client.predict(<!--
@ -64,7 +66,8 @@ result = client.predict(<!--
-->{/if}<!--
--><span class="desc"
><!--
--> # {python_type.type} {#if python_type.description}({python_type.description}){/if}<!----> in '{label}' <!--
--> # {python_type.type} {#if python_type.description}({python_type.description})
{/if}<!---->in '{label}' <!--
-->{component} component<!--
--></span
><!--
@ -84,7 +87,7 @@ print(result)</pre>
<pre>import &lbrace; client &rbrace; from "@gradio/client";
{#each blob_examples as { label, type, python_type, component, example_input, serializer }, i}<!--
-->
const response_{i} = await fetch("{example_input}");
const response_{i} = await fetch("{example_input.url}");
const example{component} = await response_{i}.blob();
{/each}<!--
-->

View File

@ -7,7 +7,7 @@ export function represent_value(
if (type === undefined) {
return lang === "py" ? "None" : null;
}
if (type === "string" || type === "str" || type == "filepath") {
if (type === "string" || type === "str") {
return lang === null ? value : '"' + value + '"';
} else if (type === "number") {
return lang === null ? parseFloat(value) : value;
@ -35,5 +35,72 @@ export function represent_value(
}
return value;
}
return JSON.stringify(value);
if (lang === "py") {
value = replace_file_data_with_file_function(value);
}
return stringify_except_file_function(value);
}
export function is_potentially_nested_file_data(obj: any): boolean {
	// Non-objects (and null) can never be, or contain, a FileData payload.
	if (typeof obj !== "object" || obj === null) {
		return false;
	}
	// Direct hit: the object itself carries the gradio.FileData marker.
	if (obj.hasOwnProperty("path") && obj.hasOwnProperty("meta")) {
		const meta = obj.meta;
		if (
			typeof meta === "object" &&
			meta !== null &&
			meta._type === "gradio.FileData"
		) {
			return true;
		}
	}
	// Otherwise, search every object-valued property (arrays included) recursively.
	for (const key in obj) {
		if (
			typeof obj[key] === "object" &&
			is_potentially_nested_file_data(obj[key])
		) {
			return true;
		}
	}
	return false;
}
function replace_file_data_with_file_function(obj: any): any {
	const is_plain_object =
		typeof obj === "object" && obj !== null && !Array.isArray(obj);
	// A plain object carrying the gradio.FileData marker is replaced by a
	// `file('<path>')` call string.
	if (
		is_plain_object &&
		"path" in obj &&
		"meta" in obj &&
		obj.meta?._type === "gradio.FileData"
	) {
		return `file('${obj.path}')`;
	}
	if (Array.isArray(obj)) {
		// Rewrite object-valued elements in place.
		for (let i = 0; i < obj.length; i++) {
			if (typeof obj[i] === "object" && obj[i] !== null) {
				obj[i] = replace_file_data_with_file_function(obj[i]);
			}
		}
	} else if (is_plain_object) {
		// Rewrite every property in place.
		for (const key of Object.keys(obj)) {
			obj[key] = replace_file_data_with_file_function(obj[key]);
		}
	}
	return obj;
}
function stringify_except_file_function(obj: any): string {
	// Tag every `file(...)` string during serialization so it can be located
	// in the JSON text afterwards.
	const tag_file_strings = (_key: string, value: any): any =>
		typeof value === "string" &&
		value.startsWith("file(") &&
		value.endsWith(")")
			? `UNQUOTED${value}`
			: value;
	const serialized = JSON.stringify(obj, tag_file_strings);
	// Strip the tag and the surrounding quotes, leaving a bare file(...) call.
	return serialized.replace(
		/"UNQUOTEDfile\(([^)]*)\)"/g,
		(_match, inner) => `file(${inner})`
	);
}

View File

@ -37,7 +37,7 @@ from gradio.components.dataframe import DataframeData
from gradio.components.file_explorer import FileExplorerData
from gradio.components.image_editor import EditorData
from gradio.components.video import VideoData
from gradio.data_classes import FileData, ListFiles
from gradio.data_classes import FileData, GradioModel, GradioRootModel, ListFiles
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -1543,6 +1543,7 @@ class TestVideo:
"size": None,
"url": None,
"is_stream": False,
"meta": {"_type": "gradio.FileData"},
},
"subtitles": None,
}
@ -1555,6 +1556,7 @@ class TestVideo:
"size": None,
"url": None,
"is_stream": False,
"meta": {"_type": "gradio.FileData"},
},
"subtitles": {
"path": "s1.srt",
@ -1563,6 +1565,7 @@ class TestVideo:
"size": None,
"url": None,
"is_stream": False,
"meta": {"_type": "gradio.FileData"},
},
}
postprocessed_video["video"]["path"] = os.path.basename(
@ -2262,6 +2265,7 @@ class TestGallery:
"size": None,
"url": url,
"is_stream": False,
"meta": {"_type": "gradio.FileData"},
},
"caption": None,
}
@ -2290,6 +2294,7 @@ class TestGallery:
"size": None,
"url": None,
"is_stream": False,
"meta": {"_type": "gradio.FileData"},
},
"caption": "foo_caption",
},
@ -2301,6 +2306,7 @@ class TestGallery:
"size": None,
"url": None,
"is_stream": False,
"meta": {"_type": "gradio.FileData"},
},
"caption": "bar_caption",
},
@ -2312,6 +2318,7 @@ class TestGallery:
"size": None,
"url": None,
"is_stream": False,
"meta": {"_type": "gradio.FileData"},
},
"caption": None,
},
@ -2323,6 +2330,7 @@ class TestGallery:
"size": None,
"url": None,
"is_stream": False,
"meta": {"_type": "gradio.FileData"},
},
"caption": None,
},
@ -2977,3 +2985,25 @@ def test_component_example_values(io_components):
else:
c: Component = component()
c.postprocess(c.example_value())
def test_component_example_payloads(io_components):
    """Each IO component's example payload should survive the cache move and preprocess()."""
    plot_components = (gr.BarPlot, gr.LinePlot, gr.ScatterPlot)
    for component_cls in io_components:
        if component_cls == PDF:
            # PDF is excluded from this round-trip check.
            continue
        if component_cls in plot_components:
            c: Component = component_cls(x="x", y="y")
        else:
            c: Component = component_cls()
        payload = c.example_payload()
        payload = processing_utils.move_files_to_cache(
            payload,
            c,
            check_in_upload_folder=False,
        )
        data_model = getattr(c, "data_model", None)
        if data_model and payload is not None:
            # Wrap the raw payload in the component's declared data model, the
            # same way the server does before calling preprocess().
            if issubclass(data_model, GradioModel):  # type: ignore
                payload = data_model(**payload)  # type: ignore
            elif issubclass(data_model, GradioRootModel):  # type: ignore
                payload = data_model(root=payload)  # type: ignore
        c.preprocess(payload)

View File

@ -397,6 +397,9 @@ class TestLoadInterfaceWithExamples:
demo = gr.load("spaces/gradio-tests/test-calculator-2v4-sse")
assert demo(2, "add", 4) == 6
def test_loading_chatbot_with_avatar_images_does_not_raise_errors(self):
    # Regression test: loading the multimodal chatbot Space (whose avatar
    # images arrive as FileData payloads) must not raise inside gr.load().
    # Network-dependent: hits the hosted "gradio/chatbot_multimodal" Space.
    gr.load("gradio/chatbot_multimodal", src="spaces")
def test_get_tabular_examples_replaces_nan_with_str_nan():
readme = """

View File

@ -339,10 +339,12 @@ def test_add_root_url():
"file": {
"path": "path",
"url": "/file=path",
"meta": {"_type": "gradio.FileData"},
},
"file2": {
"path": "path2",
"url": "https://www.gradio.app",
"meta": {"_type": "gradio.FileData"},
},
}
root_url = "http://localhost:7860"
@ -350,10 +352,12 @@ def test_add_root_url():
"file": {
"path": "path",
"url": f"{root_url}/file=path",
"meta": {"_type": "gradio.FileData"},
},
"file2": {
"path": "path2",
"url": "https://www.gradio.app",
"meta": {"_type": "gradio.FileData"},
},
}
assert processing_utils.add_root_url(data, root_url, None) == expected
@ -362,10 +366,12 @@ def test_add_root_url():
"file": {
"path": "path",
"url": f"{new_root_url}/file=path",
"meta": {"_type": "gradio.FileData"},
},
"file2": {
"path": "path2",
"url": "https://www.gradio.app",
"meta": {"_type": "gradio.FileData"},
},
}
assert (