mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Files should now be supplied as file(...)
in the Client, and some fixes to gr.load()
as well (#7575)
* more fixes for gr.load() * client * add changeset * format * docstring * add assertion * warning * add changeset * add changeset * changes * fixes * more fixes * fix files * add test for dir * add changeset * Delete .changeset/giant-bears-check.md * add changeset * changes * add changeset * print * format * add changeset * docs * add to tests * format * add changeset * move compatibility code out * fixed * changes * changes * factory method * add changeset * changes * changes * sse v2.1 * file() * changes * typing * changes * cleanup * changes * changes * changes * fixes * changes * fix * add changeset * changes * more changes * abc * test * add payloads * lint * test * lint * changes * payload * fixes * fix tests * fix * clean * fix frontend * lint * add changeset * cleanup * format * get examples to show up in loaded spaces * add filedata prop to frontend * add skip component parameter * address feedback * with meta * load --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
7c66a29dea
commit
d0688b3c25
8
.changeset/hungry-donuts-nail.md
Normal file
8
.changeset/hungry-donuts-nail.md
Normal file
@ -0,0 +1,8 @@
|
||||
---
|
||||
"@gradio/app": minor
|
||||
"@gradio/client": minor
|
||||
"gradio": minor
|
||||
"gradio_client": minor
|
||||
---
|
||||
|
||||
fix:Files should now be supplied as `file(...)` in the Client, and some fixes to `gr.load()` as well
|
@ -755,7 +755,11 @@ export function api_factory(
|
||||
}
|
||||
}
|
||||
};
|
||||
} else if (protocol == "sse_v1" || protocol == "sse_v2") {
|
||||
} else if (
|
||||
protocol == "sse_v1" ||
|
||||
protocol == "sse_v2" ||
|
||||
protocol == "sse_v2.1"
|
||||
) {
|
||||
// latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter.
|
||||
fire_event({
|
||||
type: "status",
|
||||
@ -849,7 +853,10 @@ export function api_factory(
|
||||
endpoint: _endpoint,
|
||||
fn_index
|
||||
});
|
||||
if (data && protocol === "sse_v2") {
|
||||
if (
|
||||
data &&
|
||||
(protocol === "sse_v2" || protocol === "sse_v2.1")
|
||||
) {
|
||||
apply_diff_stream(event_id!, data);
|
||||
}
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ export interface Config {
|
||||
show_api: boolean;
|
||||
stylesheets: string[];
|
||||
path: string;
|
||||
protocol?: "sse_v2" | "sse_v1" | "sse" | "ws";
|
||||
protocol?: "sse_v2.1" | "sse_v2" | "sse_v1" | "sse" | "ws";
|
||||
}
|
||||
|
||||
export interface Payload {
|
||||
|
@ -69,6 +69,7 @@ export class FileData {
|
||||
is_stream?: boolean;
|
||||
mime_type?: string;
|
||||
alt_text?: string;
|
||||
readonly meta = { _type: "gradio.FileData" };
|
||||
|
||||
constructor({
|
||||
path,
|
||||
|
@ -1,7 +1,8 @@
|
||||
from gradio_client.client import Client
|
||||
from gradio_client.utils import __version__
|
||||
from gradio_client.utils import __version__, file
|
||||
|
||||
__all__ = [
|
||||
"Client",
|
||||
"file",
|
||||
"__version__",
|
||||
]
|
||||
|
@ -21,7 +21,6 @@ from typing import Any, Callable, Literal
|
||||
|
||||
import httpx
|
||||
import huggingface_hub
|
||||
import websockets
|
||||
from huggingface_hub import CommitOperationAdd, SpaceHardware, SpaceStage
|
||||
from huggingface_hub.utils import (
|
||||
RepositoryNotFoundError,
|
||||
@ -30,9 +29,10 @@ from huggingface_hub.utils import (
|
||||
)
|
||||
from packaging import version
|
||||
|
||||
from gradio_client import serializing, utils
|
||||
from gradio_client import utils
|
||||
from gradio_client.compatibility import EndpointV3Compatibility
|
||||
from gradio_client.documentation import document
|
||||
from gradio_client.exceptions import AuthenticationError, SerializationSetupError
|
||||
from gradio_client.exceptions import AuthenticationError
|
||||
from gradio_client.utils import (
|
||||
Communicator,
|
||||
JobStatus,
|
||||
@ -71,14 +71,16 @@ class Client:
|
||||
src: str,
|
||||
hf_token: str | None = None,
|
||||
max_workers: int = 40,
|
||||
serialize: bool | None = None,
|
||||
output_dir: str | Path = DEFAULT_TEMP_DIR,
|
||||
serialize: bool | None = None, # TODO: remove in 1.0
|
||||
output_dir: str
|
||||
| Path = DEFAULT_TEMP_DIR, # Maybe this can be combined with `download_files` in 1.0
|
||||
verbose: bool = True,
|
||||
auth: tuple[str, str] | None = None,
|
||||
*,
|
||||
headers: dict[str, str] | None = None,
|
||||
upload_files: bool = True,
|
||||
download_files: bool = True,
|
||||
upload_files: bool = True, # TODO: remove and hardcode to False in 1.0
|
||||
download_files: bool = True, # TODO: consider setting to False in 1.0
|
||||
_skip_components: bool = True, # internal parameter to skip values certain components (e.g. State) that do not need to be displayed to users.
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@ -89,8 +91,8 @@ class Client:
|
||||
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
|
||||
verbose: Whether the client should print statements to the console.
|
||||
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
|
||||
upload_files: Whether the client should treat input string filepath as files and upload them to the remote server. If False, the client will treat input string filepaths as strings always and not modify them.
|
||||
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will a FileData dataclass object with the filepath on the remote machine instead.
|
||||
upload_files: Whether the client should treat input string filepath as files and upload them to the remote server. If False, the client will treat input string filepaths as strings always and not modify them, and files should be passed in explicitly using `gradio_client.file("path/to/file/or/url")` instead. This parameter will be deleted and False will become the default in a future version.
|
||||
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead.
|
||||
"""
|
||||
self.verbose = verbose
|
||||
self.hf_token = hf_token
|
||||
@ -101,6 +103,7 @@ class Client:
|
||||
upload_files = serialize
|
||||
self.upload_files = upload_files
|
||||
self.download_files = download_files
|
||||
self._skip_components = _skip_components
|
||||
self.headers = build_hf_headers(
|
||||
token=hf_token,
|
||||
library_name="gradio_client",
|
||||
@ -143,7 +146,9 @@ class Client:
|
||||
self._login(auth)
|
||||
|
||||
self.config = self._get_config()
|
||||
self.protocol: str = self.config.get("protocol", "ws")
|
||||
self.protocol: Literal[
|
||||
"ws", "sse", "sse_v1", "sse_v2", "sse_v2.1"
|
||||
] = self.config.get("protocol", "ws")
|
||||
self.api_url = urllib.parse.urljoin(self.src, utils.API_URL)
|
||||
self.sse_url = urllib.parse.urljoin(
|
||||
self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL
|
||||
@ -445,6 +450,7 @@ class Client:
|
||||
"sse",
|
||||
"sse_v1",
|
||||
"sse_v2",
|
||||
"sse_v2.1",
|
||||
):
|
||||
helper = self.new_helper(inferred_fn_index)
|
||||
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
|
||||
@ -972,8 +978,7 @@ class Endpoint:
|
||||
# This is still hacky as it does not tell us which part of the payload is a file.
|
||||
# If a component has a complex payload, part of which is a file, this will simply
|
||||
# return True, which means that all parts of the payload will be uploaded as files
|
||||
# if they are valid file paths. The better approach would be to traverse the
|
||||
# component's api_info and figure out exactly which part of the payload is a file.
|
||||
# if they are valid file paths. We will deprecate this 1.0.
|
||||
if "api_info" not in component:
|
||||
return False
|
||||
return utils.value_is_file(component["api_info"])
|
||||
@ -990,11 +995,12 @@ class Endpoint:
|
||||
def _inner(*data):
|
||||
if not self.is_valid:
|
||||
raise utils.InvalidAPIEndpointError()
|
||||
data = self.insert_state(*data)
|
||||
if self.client.upload_files:
|
||||
data = self.serialize(*data)
|
||||
|
||||
data = self.insert_empty_state(*data)
|
||||
data = self.process_input_files(*data)
|
||||
predictions = _predict(*data)
|
||||
predictions = self.process_predictions(*predictions)
|
||||
|
||||
# Append final output only if not already present
|
||||
# for consistency between generators and not generators
|
||||
if helper:
|
||||
@ -1022,7 +1028,7 @@ class Endpoint:
|
||||
result = utils.synchronize_async(
|
||||
self._sse_fn_v0, data, hash_data, helper
|
||||
)
|
||||
elif self.protocol in ("sse_v1", "sse_v2"):
|
||||
elif self.protocol in ("sse_v1", "sse_v2", "sse_v2.1"):
|
||||
event_id = utils.synchronize_async(
|
||||
self.client.send_data, data, hash_data
|
||||
)
|
||||
@ -1059,68 +1065,52 @@ class Endpoint:
|
||||
|
||||
return _predict
|
||||
|
||||
def _predict_resolve(self, *data) -> Any:
|
||||
"""Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
|
||||
outputs = self.make_predict()(*data)
|
||||
if len(self.dependency["outputs"]) == 1:
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
def _upload(
|
||||
self, file_paths: list[str | list[str]]
|
||||
) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
|
||||
if not file_paths:
|
||||
return []
|
||||
# Put all the filepaths in one file
|
||||
# but then keep track of which index in the
|
||||
# original list they came from so we can recreate
|
||||
# the original structure
|
||||
files = []
|
||||
indices = []
|
||||
for i, fs in enumerate(file_paths):
|
||||
if not isinstance(fs, list):
|
||||
fs = [fs]
|
||||
for f in fs:
|
||||
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
|
||||
indices.append(i)
|
||||
r = httpx.post(
|
||||
self.client.upload_url,
|
||||
headers=self.client.headers,
|
||||
cookies=self.client.cookies,
|
||||
files=files,
|
||||
)
|
||||
if r.status_code != 200:
|
||||
uploaded = file_paths
|
||||
else:
|
||||
uploaded = []
|
||||
result = r.json()
|
||||
for i, fs in enumerate(file_paths):
|
||||
if isinstance(fs, list):
|
||||
output = [o for ix, o in enumerate(result) if indices[ix] == i]
|
||||
res = [
|
||||
{
|
||||
"path": o,
|
||||
"orig_name": Path(f).name,
|
||||
}
|
||||
for f, o in zip(fs, output)
|
||||
]
|
||||
else:
|
||||
o = next(o for ix, o in enumerate(result) if indices[ix] == i)
|
||||
res = {
|
||||
"path": o,
|
||||
"orig_name": Path(fs).name,
|
||||
}
|
||||
uploaded.append(res)
|
||||
return uploaded
|
||||
|
||||
def insert_state(self, *data) -> tuple:
|
||||
def insert_empty_state(self, *data) -> tuple:
|
||||
data = list(data)
|
||||
for i, input_component_type in enumerate(self.input_component_types):
|
||||
if input_component_type.is_state:
|
||||
data.insert(i, None)
|
||||
return tuple(data)
|
||||
|
||||
def process_input_files(self, *data) -> tuple:
|
||||
data_ = []
|
||||
for i, d in enumerate(data):
|
||||
if self.client.upload_files and self.input_component_types[i].value_is_file:
|
||||
d = utils.traverse(
|
||||
d,
|
||||
self._upload_file,
|
||||
lambda f: utils.is_filepath(f)
|
||||
or utils.is_file_obj_with_meta(f)
|
||||
or utils.is_http_url_like(f),
|
||||
)
|
||||
elif not self.client.upload_files:
|
||||
d = utils.traverse(d, self._upload_file, utils.is_file_obj_with_meta)
|
||||
data_.append(d)
|
||||
return tuple(data_)
|
||||
|
||||
def process_predictions(self, *predictions):
|
||||
# If self.download_file is True, we assume that that the user is using the Client directly (as opposed
|
||||
# within gr.load) and therefore, download any files generated by the server and skip values for
|
||||
# components that the user likely does not want to see (e.g. gr.State, gr.Tab).
|
||||
if self.client.download_files:
|
||||
predictions = self.download_files(*predictions)
|
||||
if self.client._skip_components:
|
||||
predictions = self.remove_skipped_components(*predictions)
|
||||
predictions = self.reduce_singleton_output(*predictions)
|
||||
return predictions
|
||||
|
||||
def download_files(self, *data) -> tuple:
|
||||
data_ = list(data)
|
||||
if self.client.protocol == "sse_v2.1":
|
||||
data_ = utils.traverse(
|
||||
data_, self._download_file, utils.is_file_obj_with_meta
|
||||
)
|
||||
else:
|
||||
data_ = utils.traverse(data_, self._download_file, utils.is_file_obj)
|
||||
return tuple(data_)
|
||||
|
||||
def remove_skipped_components(self, *data) -> tuple:
|
||||
""""""
|
||||
data = [d for d, oct in zip(data, self.output_component_types) if not oct.skip]
|
||||
return tuple(data)
|
||||
|
||||
@ -1130,88 +1120,32 @@ class Endpoint:
|
||||
else:
|
||||
return data
|
||||
|
||||
def _gather_files(self, *data):
|
||||
file_list = []
|
||||
|
||||
def get_file(d):
|
||||
if utils.is_file_obj(d):
|
||||
file_list.append(d["path"])
|
||||
else:
|
||||
file_list.append(d)
|
||||
return ReplaceMe(len(file_list) - 1)
|
||||
|
||||
def handle_url(s):
|
||||
return {"path": s, "orig_name": s.split("/")[-1]}
|
||||
|
||||
new_data = []
|
||||
for i, d in enumerate(data):
|
||||
if self.input_component_types[i].value_is_file:
|
||||
# Check file dicts and filepaths to upload
|
||||
# file dict is a corner case but still needed for completeness
|
||||
# most users should be using filepaths
|
||||
d = utils.traverse(
|
||||
d, get_file, lambda s: utils.is_file_obj(s) or utils.is_filepath(s)
|
||||
)
|
||||
# Handle URLs here since we don't upload them
|
||||
d = utils.traverse(d, handle_url, lambda s: utils.is_url(s))
|
||||
new_data.append(d)
|
||||
return file_list, new_data
|
||||
|
||||
def _add_uploaded_files_to_data(self, data: list[Any], files: list[Any]):
|
||||
def replace(d: ReplaceMe) -> dict:
|
||||
return files[d.index]
|
||||
|
||||
new_data = []
|
||||
for d in data:
|
||||
d = utils.traverse(
|
||||
d, replace, is_root=lambda node: isinstance(node, ReplaceMe)
|
||||
def _upload_file(self, f: str | dict):
|
||||
if isinstance(f, str):
|
||||
warnings.warn(
|
||||
f'The Client is treating: "{f}" as a file path. In future versions, this behavior will not happen automatically. '
|
||||
f'\n\nInstead, please provide file path or URLs like this: gradio_client.file("{f}"). '
|
||||
"\n\nNote: to stop treating strings as filepaths unless file() is used, set upload_files=False in Client()."
|
||||
)
|
||||
new_data.append(d)
|
||||
return new_data
|
||||
|
||||
def serialize(self, *data) -> tuple:
|
||||
files, new_data = self._gather_files(*data)
|
||||
uploaded_files = self._upload(files)
|
||||
data = list(new_data)
|
||||
data = self._add_uploaded_files_to_data(data, uploaded_files)
|
||||
o = tuple(data)
|
||||
return o
|
||||
|
||||
def download_file(self, file_data: dict) -> str | None:
|
||||
if file_data is None:
|
||||
return None
|
||||
if isinstance(file_data, str):
|
||||
file_name = utils.decode_base64_to_file(
|
||||
file_data, dir=self.client.output_dir
|
||||
).name
|
||||
elif isinstance(file_data, dict):
|
||||
filepath = file_data.get("path")
|
||||
if not filepath:
|
||||
raise ValueError(f"The 'path' field is missing in {file_data}")
|
||||
file_name = utils.download_file(
|
||||
self.root_url + "file=" + filepath,
|
||||
save_dir=self.client.output_dir,
|
||||
file_path = f
|
||||
else:
|
||||
file_path = f["path"]
|
||||
if not utils.is_http_url_like(file_path):
|
||||
file_path = utils.upload_file(
|
||||
file_path=file_path,
|
||||
upload_url=self.client.upload_url,
|
||||
headers=self.client.headers,
|
||||
cookies=self.client.cookies,
|
||||
)
|
||||
return {"path": file_path}
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"A FileSerializable component can only deserialize a string or a dict, not a {type(file_name)}: {file_name}"
|
||||
)
|
||||
return file_name
|
||||
|
||||
def deserialize(self, *data) -> tuple:
|
||||
data_ = list(data)
|
||||
data_: list[Any] = utils.traverse(data_, self.download_file, utils.is_file_obj)
|
||||
return tuple(data_)
|
||||
|
||||
def process_predictions(self, *predictions):
|
||||
if self.client.download_files:
|
||||
predictions = self.deserialize(*predictions)
|
||||
predictions = self.remove_skipped_components(*predictions)
|
||||
predictions = self.reduce_singleton_output(*predictions)
|
||||
return predictions
|
||||
def _download_file(self, x: dict) -> str | None:
|
||||
return utils.download_file(
|
||||
self.root_url + "file=" + x["path"],
|
||||
save_dir=self.client.output_dir,
|
||||
headers=self.client.headers,
|
||||
cookies=self.client.cookies,
|
||||
)
|
||||
|
||||
async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
|
||||
@ -1227,7 +1161,10 @@ class Endpoint:
|
||||
)
|
||||
|
||||
async def _sse_fn_v1_v2(
|
||||
self, helper: Communicator, event_id: str, protocol: Literal["sse_v1", "sse_v2"]
|
||||
self,
|
||||
helper: Communicator,
|
||||
event_id: str,
|
||||
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
|
||||
):
|
||||
return await utils.get_pred_from_sse_v1_v2(
|
||||
helper,
|
||||
@ -1239,313 +1176,6 @@ class Endpoint:
|
||||
)
|
||||
|
||||
|
||||
class EndpointV3Compatibility:
|
||||
"""Endpoint class for connecting to v3 endpoints. Backwards compatibility."""
|
||||
|
||||
def __init__(self, client: Client, fn_index: int, dependency: dict, *_args):
|
||||
self.client: Client = client
|
||||
self.fn_index = fn_index
|
||||
self.dependency = dependency
|
||||
api_name = dependency.get("api_name")
|
||||
self.api_name: str | Literal[False] | None = (
|
||||
"/" + api_name if isinstance(api_name, str) else api_name
|
||||
)
|
||||
self.use_ws = self._use_websocket(self.dependency)
|
||||
self.protocol = "ws" if self.use_ws else "http"
|
||||
self.input_component_types = []
|
||||
self.output_component_types = []
|
||||
self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
|
||||
self.is_continuous = dependency.get("types", {}).get("continuous", False)
|
||||
try:
|
||||
# Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid,
|
||||
# and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
|
||||
self.serializers, self.deserializers = self._setup_serializers()
|
||||
self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
|
||||
except SerializationSetupError:
|
||||
self.is_valid = False
|
||||
self.backend_fn = dependency.get("backend_fn")
|
||||
self.show_api = True
|
||||
|
||||
def __repr__(self):
|
||||
return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def make_end_to_end_fn(self, helper: Communicator | None = None):
|
||||
_predict = self.make_predict(helper)
|
||||
|
||||
def _inner(*data):
|
||||
if not self.is_valid:
|
||||
raise utils.InvalidAPIEndpointError()
|
||||
data = self.insert_state(*data)
|
||||
if self.client.upload_files:
|
||||
data = self.serialize(*data)
|
||||
predictions = _predict(*data)
|
||||
predictions = self.process_predictions(*predictions)
|
||||
# Append final output only if not already present
|
||||
# for consistency between generators and not generators
|
||||
if helper:
|
||||
with helper.lock:
|
||||
if not helper.job.outputs:
|
||||
helper.job.outputs.append(predictions)
|
||||
return predictions
|
||||
|
||||
return _inner
|
||||
|
||||
def make_predict(self, helper: Communicator | None = None):
|
||||
def _predict(*data) -> tuple:
|
||||
data = json.dumps(
|
||||
{
|
||||
"data": data,
|
||||
"fn_index": self.fn_index,
|
||||
"session_hash": self.client.session_hash,
|
||||
}
|
||||
)
|
||||
hash_data = json.dumps(
|
||||
{
|
||||
"fn_index": self.fn_index,
|
||||
"session_hash": self.client.session_hash,
|
||||
}
|
||||
)
|
||||
if self.use_ws:
|
||||
result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
|
||||
if "error" in result:
|
||||
raise ValueError(result["error"])
|
||||
else:
|
||||
response = httpx.post(
|
||||
self.client.api_url, headers=self.client.headers, json=data
|
||||
)
|
||||
result = json.loads(response.content.decode("utf-8"))
|
||||
try:
|
||||
output = result["data"]
|
||||
except KeyError as ke:
|
||||
is_public_space = (
|
||||
self.client.space_id
|
||||
and not huggingface_hub.space_info(self.client.space_id).private
|
||||
)
|
||||
if "error" in result and "429" in result["error"] and is_public_space:
|
||||
raise utils.TooManyRequestsError(
|
||||
f"Too many requests to the API, please try again later. To avoid being rate-limited, "
|
||||
f"please duplicate the Space using Client.duplicate({self.client.space_id}) "
|
||||
f"and pass in your Hugging Face token."
|
||||
) from None
|
||||
elif "error" in result:
|
||||
raise ValueError(result["error"]) from None
|
||||
raise KeyError(
|
||||
f"Could not find 'data' key in response. Response received: {result}"
|
||||
) from ke
|
||||
return tuple(output)
|
||||
|
||||
return _predict
|
||||
|
||||
def _predict_resolve(self, *data) -> Any:
|
||||
"""Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
|
||||
outputs = self.make_predict()(*data)
|
||||
if len(self.dependency["outputs"]) == 1:
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
def _upload(
|
||||
self, file_paths: list[str | list[str]]
|
||||
) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
|
||||
if not file_paths:
|
||||
return []
|
||||
# Put all the filepaths in one file
|
||||
# but then keep track of which index in the
|
||||
# original list they came from so we can recreate
|
||||
# the original structure
|
||||
files = []
|
||||
indices = []
|
||||
for i, fs in enumerate(file_paths):
|
||||
if not isinstance(fs, list):
|
||||
fs = [fs]
|
||||
for f in fs:
|
||||
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
|
||||
indices.append(i)
|
||||
r = httpx.post(self.client.upload_url, headers=self.client.headers, files=files)
|
||||
if r.status_code != 200:
|
||||
uploaded = file_paths
|
||||
else:
|
||||
uploaded = []
|
||||
result = r.json()
|
||||
for i, fs in enumerate(file_paths):
|
||||
if isinstance(fs, list):
|
||||
output = [o for ix, o in enumerate(result) if indices[ix] == i]
|
||||
res = [
|
||||
{
|
||||
"is_file": True,
|
||||
"name": o,
|
||||
"orig_name": Path(f).name,
|
||||
"data": None,
|
||||
}
|
||||
for f, o in zip(fs, output)
|
||||
]
|
||||
else:
|
||||
o = next(o for ix, o in enumerate(result) if indices[ix] == i)
|
||||
res = {
|
||||
"is_file": True,
|
||||
"name": o,
|
||||
"orig_name": Path(fs).name,
|
||||
"data": None,
|
||||
}
|
||||
uploaded.append(res)
|
||||
return uploaded
|
||||
|
||||
def _add_uploaded_files_to_data(
|
||||
self,
|
||||
files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]],
|
||||
data: list[Any],
|
||||
) -> None:
|
||||
"""Helper function to modify the input data with the uploaded files."""
|
||||
file_counter = 0
|
||||
for i, t in enumerate(self.input_component_types):
|
||||
if t in ["file", "uploadbutton"]:
|
||||
data[i] = files[file_counter]
|
||||
file_counter += 1
|
||||
|
||||
def insert_state(self, *data) -> tuple:
|
||||
data = list(data)
|
||||
for i, input_component_type in enumerate(self.input_component_types):
|
||||
if input_component_type == utils.STATE_COMPONENT:
|
||||
data.insert(i, None)
|
||||
return tuple(data)
|
||||
|
||||
def remove_skipped_components(self, *data) -> tuple:
|
||||
data = [
|
||||
d
|
||||
for d, oct in zip(data, self.output_component_types)
|
||||
if oct not in utils.SKIP_COMPONENTS
|
||||
]
|
||||
return tuple(data)
|
||||
|
||||
def reduce_singleton_output(self, *data) -> Any:
|
||||
if (
|
||||
len(
|
||||
[
|
||||
oct
|
||||
for oct in self.output_component_types
|
||||
if oct not in utils.SKIP_COMPONENTS
|
||||
]
|
||||
)
|
||||
== 1
|
||||
):
|
||||
return data[0]
|
||||
else:
|
||||
return data
|
||||
|
||||
def serialize(self, *data) -> tuple:
|
||||
if len(data) != len(self.serializers):
|
||||
raise ValueError(
|
||||
f"Expected {len(self.serializers)} arguments, got {len(data)}"
|
||||
)
|
||||
|
||||
files = [
|
||||
f
|
||||
for f, t in zip(data, self.input_component_types)
|
||||
if t in ["file", "uploadbutton"]
|
||||
]
|
||||
uploaded_files = self._upload(files)
|
||||
data = list(data)
|
||||
self._add_uploaded_files_to_data(uploaded_files, data)
|
||||
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
|
||||
return o
|
||||
|
||||
def deserialize(self, *data) -> tuple:
|
||||
if len(data) != len(self.deserializers):
|
||||
raise ValueError(
|
||||
f"Expected {len(self.deserializers)} outputs, got {len(data)}"
|
||||
)
|
||||
outputs = tuple(
|
||||
[
|
||||
s.deserialize(
|
||||
d,
|
||||
save_dir=self.client.output_dir,
|
||||
hf_token=self.client.hf_token,
|
||||
root_url=self.root_url,
|
||||
)
|
||||
for s, d in zip(self.deserializers, data)
|
||||
]
|
||||
)
|
||||
return outputs
|
||||
|
||||
def process_predictions(self, *predictions):
|
||||
if self.client.download_files:
|
||||
predictions = self.deserialize(*predictions)
|
||||
predictions = self.remove_skipped_components(*predictions)
|
||||
predictions = self.reduce_singleton_output(*predictions)
|
||||
return predictions
|
||||
|
||||
def _setup_serializers(
|
||||
self,
|
||||
) -> tuple[list[serializing.Serializable], list[serializing.Serializable]]:
|
||||
inputs = self.dependency["inputs"]
|
||||
serializers = []
|
||||
|
||||
for i in inputs:
|
||||
for component in self.client.config["components"]:
|
||||
if component["id"] == i:
|
||||
component_name = component["type"]
|
||||
self.input_component_types.append(component_name)
|
||||
if component.get("serializer"):
|
||||
serializer_name = component["serializer"]
|
||||
if serializer_name not in serializing.SERIALIZER_MAPPING:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
serializer = serializing.SERIALIZER_MAPPING[serializer_name]
|
||||
elif component_name in serializing.COMPONENT_MAPPING:
|
||||
serializer = serializing.COMPONENT_MAPPING[component_name]
|
||||
else:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown component: {component_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
serializers.append(serializer()) # type: ignore
|
||||
|
||||
outputs = self.dependency["outputs"]
|
||||
deserializers = []
|
||||
for i in outputs:
|
||||
for component in self.client.config["components"]:
|
||||
if component["id"] == i:
|
||||
component_name = component["type"]
|
||||
self.output_component_types.append(component_name)
|
||||
if component.get("serializer"):
|
||||
serializer_name = component["serializer"]
|
||||
if serializer_name not in serializing.SERIALIZER_MAPPING:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
|
||||
elif component_name in utils.SKIP_COMPONENTS:
|
||||
deserializer = serializing.SimpleSerializable
|
||||
elif component_name in serializing.COMPONENT_MAPPING:
|
||||
deserializer = serializing.COMPONENT_MAPPING[component_name]
|
||||
else:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown component: {component_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
deserializers.append(deserializer()) # type: ignore
|
||||
|
||||
return serializers, deserializers
|
||||
|
||||
def _use_websocket(self, dependency: dict) -> bool:
|
||||
queue_enabled = self.client.config.get("enable_queue", False)
|
||||
queue_uses_websocket = version.parse(
|
||||
self.client.config.get("version", "2.0")
|
||||
) >= version.Version("3.2")
|
||||
dependency_uses_queue = dependency.get("queue", False) is not False
|
||||
return queue_enabled and queue_uses_websocket and dependency_uses_queue
|
||||
|
||||
async def _ws_fn(self, data, hash_data, helper: Communicator):
|
||||
async with websockets.connect( # type: ignore
|
||||
self.client.ws_url,
|
||||
open_timeout=10,
|
||||
extra_headers=self.client.headers,
|
||||
max_size=1024 * 1024 * 1024,
|
||||
) as websocket:
|
||||
return await utils.get_pred_from_ws(websocket, data, hash_data, helper)
|
||||
|
||||
|
||||
@document("result", "outputs", "status")
|
||||
class Job(Future):
|
||||
"""
|
||||
|
327
client/python/gradio_client/compatibility.py
Normal file
327
client/python/gradio_client/compatibility.py
Normal file
@ -0,0 +1,327 @@
|
||||
""" This module contains the EndpointV3Compatibility class, which is used to connect to Gradio apps running 3.x.x versions of Gradio."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import httpx
|
||||
import huggingface_hub
|
||||
import websockets
|
||||
from packaging import version
|
||||
|
||||
from gradio_client import serializing, utils
|
||||
from gradio_client.exceptions import SerializationSetupError
|
||||
from gradio_client.utils import (
|
||||
Communicator,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio_client import Client
|
||||
|
||||
|
||||
class EndpointV3Compatibility:
|
||||
"""Endpoint class for connecting to v3 endpoints. Backwards compatibility."""
|
||||
|
||||
def __init__(self, client: Client, fn_index: int, dependency: dict, *_args):
|
||||
self.client: Client = client
|
||||
self.fn_index = fn_index
|
||||
self.dependency = dependency
|
||||
api_name = dependency.get("api_name")
|
||||
self.api_name: str | Literal[False] | None = (
|
||||
"/" + api_name if isinstance(api_name, str) else api_name
|
||||
)
|
||||
self.use_ws = self._use_websocket(self.dependency)
|
||||
self.protocol = "ws" if self.use_ws else "http"
|
||||
self.input_component_types = []
|
||||
self.output_component_types = []
|
||||
self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
|
||||
self.is_continuous = dependency.get("types", {}).get("continuous", False)
|
||||
try:
|
||||
# Only a real API endpoint if backend_fn is True (so not just a frontend function), serializers are valid,
|
||||
# and api_name is not False (meaning that the developer has explicitly disabled the API endpoint)
|
||||
self.serializers, self.deserializers = self._setup_serializers()
|
||||
self.is_valid = self.dependency["backend_fn"] and self.api_name is not False
|
||||
except SerializationSetupError:
|
||||
self.is_valid = False
|
||||
self.backend_fn = dependency.get("backend_fn")
|
||||
self.show_api = True
|
||||
|
||||
def __repr__(self):
|
||||
return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def make_end_to_end_fn(self, helper: Communicator | None = None):
|
||||
_predict = self.make_predict(helper)
|
||||
|
||||
def _inner(*data):
|
||||
if not self.is_valid:
|
||||
raise utils.InvalidAPIEndpointError()
|
||||
data = self.insert_state(*data)
|
||||
if self.client.upload_files:
|
||||
data = self.serialize(*data)
|
||||
predictions = _predict(*data)
|
||||
predictions = self.process_predictions(*predictions)
|
||||
# Append final output only if not already present
|
||||
# for consistency between generators and not generators
|
||||
if helper:
|
||||
with helper.lock:
|
||||
if not helper.job.outputs:
|
||||
helper.job.outputs.append(predictions)
|
||||
return predictions
|
||||
|
||||
return _inner
|
||||
|
||||
def make_predict(self, helper: Communicator | None = None):
|
||||
def _predict(*data) -> tuple:
|
||||
data = json.dumps(
|
||||
{
|
||||
"data": data,
|
||||
"fn_index": self.fn_index,
|
||||
"session_hash": self.client.session_hash,
|
||||
}
|
||||
)
|
||||
hash_data = json.dumps(
|
||||
{
|
||||
"fn_index": self.fn_index,
|
||||
"session_hash": self.client.session_hash,
|
||||
}
|
||||
)
|
||||
if self.use_ws:
|
||||
result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
|
||||
if "error" in result:
|
||||
raise ValueError(result["error"])
|
||||
else:
|
||||
response = httpx.post(
|
||||
self.client.api_url, headers=self.client.headers, json=data
|
||||
)
|
||||
result = json.loads(response.content.decode("utf-8"))
|
||||
try:
|
||||
output = result["data"]
|
||||
except KeyError as ke:
|
||||
is_public_space = (
|
||||
self.client.space_id
|
||||
and not huggingface_hub.space_info(self.client.space_id).private
|
||||
)
|
||||
if "error" in result and "429" in result["error"] and is_public_space:
|
||||
raise utils.TooManyRequestsError(
|
||||
f"Too many requests to the API, please try again later. To avoid being rate-limited, "
|
||||
f"please duplicate the Space using Client.duplicate({self.client.space_id}) "
|
||||
f"and pass in your Hugging Face token."
|
||||
) from None
|
||||
elif "error" in result:
|
||||
raise ValueError(result["error"]) from None
|
||||
raise KeyError(
|
||||
f"Could not find 'data' key in response. Response received: {result}"
|
||||
) from ke
|
||||
return tuple(output)
|
||||
|
||||
return _predict
|
||||
|
||||
def _predict_resolve(self, *data) -> Any:
|
||||
"""Needed for gradio.load(), which has a slightly different signature for serializing/deserializing"""
|
||||
outputs = self.make_predict()(*data)
|
||||
if len(self.dependency["outputs"]) == 1:
|
||||
return outputs[0]
|
||||
return outputs
|
||||
|
||||
def _upload(
|
||||
self, file_paths: list[str | list[str]]
|
||||
) -> list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]]:
|
||||
if not file_paths:
|
||||
return []
|
||||
# Put all the filepaths in one file
|
||||
# but then keep track of which index in the
|
||||
# original list they came from so we can recreate
|
||||
# the original structure
|
||||
files = []
|
||||
indices = []
|
||||
for i, fs in enumerate(file_paths):
|
||||
if not isinstance(fs, list):
|
||||
fs = [fs]
|
||||
for f in fs:
|
||||
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
|
||||
indices.append(i)
|
||||
r = httpx.post(self.client.upload_url, headers=self.client.headers, files=files)
|
||||
if r.status_code != 200:
|
||||
uploaded = file_paths
|
||||
else:
|
||||
uploaded = []
|
||||
result = r.json()
|
||||
for i, fs in enumerate(file_paths):
|
||||
if isinstance(fs, list):
|
||||
output = [o for ix, o in enumerate(result) if indices[ix] == i]
|
||||
res = [
|
||||
{
|
||||
"is_file": True,
|
||||
"name": o,
|
||||
"orig_name": Path(f).name,
|
||||
"data": None,
|
||||
}
|
||||
for f, o in zip(fs, output)
|
||||
]
|
||||
else:
|
||||
o = next(o for ix, o in enumerate(result) if indices[ix] == i)
|
||||
res = {
|
||||
"is_file": True,
|
||||
"name": o,
|
||||
"orig_name": Path(fs).name,
|
||||
"data": None,
|
||||
}
|
||||
uploaded.append(res)
|
||||
return uploaded
|
||||
|
||||
def _add_uploaded_files_to_data(
|
||||
self,
|
||||
files: list[str | list[str]] | list[dict[str, Any] | list[dict[str, Any]]],
|
||||
data: list[Any],
|
||||
) -> None:
|
||||
"""Helper function to modify the input data with the uploaded files."""
|
||||
file_counter = 0
|
||||
for i, t in enumerate(self.input_component_types):
|
||||
if t in ["file", "uploadbutton"]:
|
||||
data[i] = files[file_counter]
|
||||
file_counter += 1
|
||||
|
||||
def insert_state(self, *data) -> tuple:
|
||||
data = list(data)
|
||||
for i, input_component_type in enumerate(self.input_component_types):
|
||||
if input_component_type == utils.STATE_COMPONENT:
|
||||
data.insert(i, None)
|
||||
return tuple(data)
|
||||
|
||||
def remove_skipped_components(self, *data) -> tuple:
|
||||
data = [
|
||||
d
|
||||
for d, oct in zip(data, self.output_component_types)
|
||||
if oct not in utils.SKIP_COMPONENTS
|
||||
]
|
||||
return tuple(data)
|
||||
|
||||
def reduce_singleton_output(self, *data) -> Any:
|
||||
if (
|
||||
len(
|
||||
[
|
||||
oct
|
||||
for oct in self.output_component_types
|
||||
if oct not in utils.SKIP_COMPONENTS
|
||||
]
|
||||
)
|
||||
== 1
|
||||
):
|
||||
return data[0]
|
||||
else:
|
||||
return data
|
||||
|
||||
def serialize(self, *data) -> tuple:
|
||||
if len(data) != len(self.serializers):
|
||||
raise ValueError(
|
||||
f"Expected {len(self.serializers)} arguments, got {len(data)}"
|
||||
)
|
||||
|
||||
files = [
|
||||
f
|
||||
for f, t in zip(data, self.input_component_types)
|
||||
if t in ["file", "uploadbutton"]
|
||||
]
|
||||
uploaded_files = self._upload(files)
|
||||
data = list(data)
|
||||
self._add_uploaded_files_to_data(uploaded_files, data)
|
||||
o = tuple([s.serialize(d) for s, d in zip(self.serializers, data)])
|
||||
return o
|
||||
|
||||
def deserialize(self, *data) -> tuple:
|
||||
if len(data) != len(self.deserializers):
|
||||
raise ValueError(
|
||||
f"Expected {len(self.deserializers)} outputs, got {len(data)}"
|
||||
)
|
||||
outputs = tuple(
|
||||
[
|
||||
s.deserialize(
|
||||
d,
|
||||
save_dir=self.client.output_dir,
|
||||
hf_token=self.client.hf_token,
|
||||
root_url=self.root_url,
|
||||
)
|
||||
for s, d in zip(self.deserializers, data)
|
||||
]
|
||||
)
|
||||
return outputs
|
||||
|
||||
def process_predictions(self, *predictions):
|
||||
if self.client.download_files:
|
||||
predictions = self.deserialize(*predictions)
|
||||
predictions = self.remove_skipped_components(*predictions)
|
||||
predictions = self.reduce_singleton_output(*predictions)
|
||||
return predictions
|
||||
|
||||
def _setup_serializers(
|
||||
self,
|
||||
) -> tuple[list[serializing.Serializable], list[serializing.Serializable]]:
|
||||
inputs = self.dependency["inputs"]
|
||||
serializers = []
|
||||
|
||||
for i in inputs:
|
||||
for component in self.client.config["components"]:
|
||||
if component["id"] == i:
|
||||
component_name = component["type"]
|
||||
self.input_component_types.append(component_name)
|
||||
if component.get("serializer"):
|
||||
serializer_name = component["serializer"]
|
||||
if serializer_name not in serializing.SERIALIZER_MAPPING:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
serializer = serializing.SERIALIZER_MAPPING[serializer_name]
|
||||
elif component_name in serializing.COMPONENT_MAPPING:
|
||||
serializer = serializing.COMPONENT_MAPPING[component_name]
|
||||
else:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown component: {component_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
serializers.append(serializer()) # type: ignore
|
||||
|
||||
outputs = self.dependency["outputs"]
|
||||
deserializers = []
|
||||
for i in outputs:
|
||||
for component in self.client.config["components"]:
|
||||
if component["id"] == i:
|
||||
component_name = component["type"]
|
||||
self.output_component_types.append(component_name)
|
||||
if component.get("serializer"):
|
||||
serializer_name = component["serializer"]
|
||||
if serializer_name not in serializing.SERIALIZER_MAPPING:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
|
||||
elif component_name in utils.SKIP_COMPONENTS:
|
||||
deserializer = serializing.SimpleSerializable
|
||||
elif component_name in serializing.COMPONENT_MAPPING:
|
||||
deserializer = serializing.COMPONENT_MAPPING[component_name]
|
||||
else:
|
||||
raise SerializationSetupError(
|
||||
f"Unknown component: {component_name}, you may need to update your gradio_client version."
|
||||
)
|
||||
deserializers.append(deserializer()) # type: ignore
|
||||
|
||||
return serializers, deserializers
|
||||
|
||||
def _use_websocket(self, dependency: dict) -> bool:
|
||||
queue_enabled = self.client.config.get("enable_queue", False)
|
||||
queue_uses_websocket = version.parse(
|
||||
self.client.config.get("version", "2.0")
|
||||
) >= version.Version("3.2")
|
||||
dependency_uses_queue = dependency.get("queue", False) is not False
|
||||
return queue_enabled and queue_uses_websocket and dependency_uses_queue
|
||||
|
||||
async def _ws_fn(self, data, hash_data, helper: Communicator):
|
||||
async with websockets.connect( # type: ignore
|
||||
self.client.ws_url,
|
||||
open_timeout=10,
|
||||
extra_headers=self.client.headers,
|
||||
max_size=1024 * 1024 * 1024,
|
||||
) as websocket:
|
||||
return await utils.get_pred_from_ws(websocket, data, hash_data, helper)
|
@ -246,10 +246,12 @@ class Communicator:
|
||||
########################
|
||||
|
||||
|
||||
def is_http_url_like(possible_url: str) -> bool:
|
||||
def is_http_url_like(possible_url) -> bool:
|
||||
"""
|
||||
Check if the given string looks like an HTTP(S) URL.
|
||||
Check if the given value is a string that looks like an HTTP(S) URL.
|
||||
"""
|
||||
if not isinstance(possible_url, str):
|
||||
return False
|
||||
return possible_url.startswith(("http://", "https://"))
|
||||
|
||||
|
||||
@ -390,7 +392,7 @@ async def get_pred_from_sse_v1_v2(
|
||||
cookies: dict[str, str] | None,
|
||||
pending_messages_per_event: dict[str, list[Message | None]],
|
||||
event_id: str,
|
||||
protocol: Literal["sse_v1", "sse_v2"],
|
||||
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
|
||||
) -> dict[str, Any] | None:
|
||||
done, pending = await asyncio.wait(
|
||||
[
|
||||
@ -510,7 +512,7 @@ async def stream_sse_v1_v2(
|
||||
helper: Communicator,
|
||||
pending_messages_per_event: dict[str, list[Message | None]],
|
||||
event_id: str,
|
||||
protocol: Literal["sse_v1", "sse_v2"],
|
||||
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
pending_messages = pending_messages_per_event[event_id]
|
||||
@ -546,10 +548,10 @@ async def stream_sse_v1_v2(
|
||||
log=log_message,
|
||||
)
|
||||
output = msg.get("output", {}).get("data", [])
|
||||
if (
|
||||
msg["msg"] == ServerMessage.process_generating
|
||||
and protocol == "sse_v2"
|
||||
):
|
||||
if msg["msg"] == ServerMessage.process_generating and protocol in [
|
||||
"sse_v2",
|
||||
"sse_v2.1",
|
||||
]:
|
||||
if pending_responses_for_diffs is None:
|
||||
pending_responses_for_diffs = list(output)
|
||||
else:
|
||||
@ -623,6 +625,20 @@ def apply_diff(obj, diff):
|
||||
########################
|
||||
|
||||
|
||||
def upload_file(
|
||||
file_path: str,
|
||||
upload_url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
cookies: dict[str, str] | None = None,
|
||||
):
|
||||
with open(file_path, "rb") as f:
|
||||
files = [("files", (Path(file_path).name, f))]
|
||||
r = httpx.post(upload_url, headers=headers, cookies=cookies, files=files)
|
||||
r.raise_for_status()
|
||||
result = r.json()
|
||||
return result[0]
|
||||
|
||||
|
||||
def download_file(
|
||||
url_path: str,
|
||||
save_dir: str,
|
||||
@ -898,13 +914,18 @@ def get_type(schema: dict):
|
||||
raise APIInfoParseError(f"Cannot parse type for {schema}")
|
||||
|
||||
|
||||
OLD_FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)"
|
||||
FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool)"
|
||||
FILE_DATA_FORMATS = [
|
||||
"Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)",
|
||||
"Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool)",
|
||||
"Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool, meta: Dict())",
|
||||
]
|
||||
|
||||
CURRENT_FILE_DATA_FORMAT = FILE_DATA_FORMATS[-1]
|
||||
|
||||
|
||||
def json_schema_to_python_type(schema: Any) -> str:
|
||||
type_ = _json_schema_to_python_type(schema, schema.get("$defs"))
|
||||
return type_.replace(FILE_DATA, "filepath")
|
||||
return type_.replace(CURRENT_FILE_DATA_FORMAT, "filepath")
|
||||
|
||||
|
||||
def _json_schema_to_python_type(schema: Any, defs) -> str:
|
||||
@ -980,7 +1001,7 @@ def _json_schema_to_python_type(schema: Any, defs) -> str:
|
||||
raise APIInfoParseError(f"Cannot parse schema {schema}")
|
||||
|
||||
|
||||
def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:
|
||||
def traverse(json_obj: Any, func: Callable, is_root: Callable[..., bool]) -> Any:
|
||||
if is_root(json_obj):
|
||||
return func(json_obj)
|
||||
elif isinstance(json_obj, dict):
|
||||
@ -999,27 +1020,57 @@ def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:
|
||||
|
||||
def value_is_file(api_info: dict) -> bool:
|
||||
info = _json_schema_to_python_type(api_info, api_info.get("$defs"))
|
||||
return FILE_DATA in info or OLD_FILE_DATA in info
|
||||
return any(file_data_format in info for file_data_format in FILE_DATA_FORMATS)
|
||||
|
||||
|
||||
def is_filepath(s):
|
||||
return isinstance(s, str) and Path(s).exists()
|
||||
def is_filepath(s) -> bool:
|
||||
"""
|
||||
Check if the given value is a valid str or Path filepath on the local filesystem, e.g. "path/to/file".
|
||||
"""
|
||||
return isinstance(s, (str, Path)) and Path(s).exists() and Path(s).is_file()
|
||||
|
||||
|
||||
def is_url(s):
|
||||
return isinstance(s, str) and is_http_url_like(s)
|
||||
def is_file_obj(d) -> bool:
|
||||
"""
|
||||
Check if the given value is a valid FileData object dictionary in versions of Gradio<=4.20, e.g.
|
||||
{
|
||||
"path": "path/to/file",
|
||||
}
|
||||
"""
|
||||
return isinstance(d, dict) and "path" in d and isinstance(d["path"], str)
|
||||
|
||||
|
||||
def is_file_obj(d):
|
||||
return isinstance(d, dict) and "path" in d
|
||||
|
||||
|
||||
def is_file_obj_with_url(d):
|
||||
def is_file_obj_with_meta(d) -> bool:
|
||||
"""
|
||||
Check if the given value is a valid FileData object dictionary in newer versions of Gradio
|
||||
where the file objects include a specific "meta" key, e.g.
|
||||
{
|
||||
"path": "path/to/file",
|
||||
"meta": {"_type: "gradio.FileData"}
|
||||
}
|
||||
"""
|
||||
return (
|
||||
isinstance(d, dict) and "path" in d and "url" in d and isinstance(d["url"], str)
|
||||
isinstance(d, dict)
|
||||
and "path" in d
|
||||
and isinstance(d["path"], str)
|
||||
and "meta" in d
|
||||
and d["meta"].get("_type", "") == "gradio.FileData"
|
||||
)
|
||||
|
||||
|
||||
def is_file_obj_with_url(d) -> bool:
|
||||
"""
|
||||
Check if the given value is a valid FileData object dictionary in newer versions of Gradio
|
||||
where the file objects include a specific "meta" key, and ALSO include a "url" key, e.g.
|
||||
{
|
||||
"path": "path/to/file",
|
||||
"url": "/file=path/to/file",
|
||||
"meta": {"_type: "gradio.FileData"}
|
||||
}
|
||||
"""
|
||||
return is_file_obj_with_meta(d) and "url" in d and isinstance(d["url"], str)
|
||||
|
||||
|
||||
SKIP_COMPONENTS = {
|
||||
"state",
|
||||
"row",
|
||||
@ -1034,3 +1085,16 @@ SKIP_COMPONENTS = {
|
||||
"interpretation",
|
||||
"dataset",
|
||||
}
|
||||
|
||||
|
||||
def file(filepath_or_url: str | Path):
|
||||
s = str(filepath_or_url)
|
||||
data = {"path": s, "meta": {"_type": "gradio.FileData"}}
|
||||
if is_http_url_like(s):
|
||||
return {**data, "orig_name": s.split("/")[-1], "url": s}
|
||||
elif Path(s).exists():
|
||||
return {**data, "orig_name": Path(s).name}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"File {s} does not exist on local filesystem and is not a valid URL."
|
||||
)
|
||||
|
@ -19,7 +19,7 @@ from gradio.networking import Server
|
||||
from huggingface_hub import HfFolder
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from gradio_client import Client
|
||||
from gradio_client import Client, file
|
||||
from gradio_client.client import DEFAULT_TEMP_DIR
|
||||
from gradio_client.exceptions import AuthenticationError
|
||||
from gradio_client.utils import (
|
||||
@ -255,7 +255,9 @@ class TestClientPredictions:
|
||||
with connect(video_component) as client:
|
||||
job = client.submit(
|
||||
{
|
||||
"video": "https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
|
||||
"video": file(
|
||||
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
|
||||
)
|
||||
},
|
||||
fn_index=0,
|
||||
)
|
||||
@ -269,7 +271,9 @@ class TestClientPredictions:
|
||||
with connect(video_component, output_dir=temp_dir) as client:
|
||||
job = client.submit(
|
||||
{
|
||||
"video": "https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
|
||||
"video": file(
|
||||
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
|
||||
)
|
||||
},
|
||||
fn_index=0,
|
||||
)
|
||||
@ -370,13 +374,15 @@ class TestClientPredictions:
|
||||
def test_stream_audio(self, stream_audio):
|
||||
with connect(stream_audio) as client:
|
||||
job1 = client.submit(
|
||||
"https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4",
|
||||
file("https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4"),
|
||||
api_name="/predict",
|
||||
)
|
||||
assert Path(job1.result()).exists()
|
||||
|
||||
job2 = client.submit(
|
||||
"https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav",
|
||||
file(
|
||||
"https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav"
|
||||
),
|
||||
api_name="/predict",
|
||||
)
|
||||
assert Path(job2.result()).exists()
|
||||
@ -497,6 +503,13 @@ class TestClientPredictions:
|
||||
ret = client.predict(message, initial_history, api_name="/submit")
|
||||
assert ret == ("", [["", None], ["Hello", "I love you"]])
|
||||
|
||||
def test_does_not_upload_dir(self, stateful_chatbot):
|
||||
with connect(stateful_chatbot) as client:
|
||||
initial_history = [["", None]]
|
||||
message = "Hello"
|
||||
ret = client.predict(message, initial_history, api_name="/submit")
|
||||
assert ret == ("", [["", None], ["Hello", "I love you"]])
|
||||
|
||||
def test_can_call_mounted_app_via_api(self):
|
||||
def greet(name):
|
||||
return "Hello " + name + "!"
|
||||
@ -1004,13 +1017,13 @@ class TestAPIInfo:
|
||||
"description": "",
|
||||
}
|
||||
assert isinstance(inputs[0]["example_input"], list)
|
||||
assert isinstance(inputs[0]["example_input"][0], str)
|
||||
assert isinstance(inputs[0]["example_input"][0], dict)
|
||||
|
||||
assert inputs[1]["python_type"] == {
|
||||
"type": "filepath",
|
||||
"description": "",
|
||||
}
|
||||
assert isinstance(inputs[1]["example_input"], str)
|
||||
assert isinstance(inputs[1]["example_input"], dict)
|
||||
|
||||
assert outputs[0]["python_type"] == {
|
||||
"type": "List[filepath]",
|
||||
@ -1158,43 +1171,6 @@ class TestEndpoints:
|
||||
"file7",
|
||||
]
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_upload_v4(self):
|
||||
client = Client(
|
||||
src="gradio-tests/not-actually-private-file-uploadv4-sse",
|
||||
)
|
||||
response = MagicMock(status_code=200)
|
||||
response.json.return_value = [
|
||||
"file1",
|
||||
"file2",
|
||||
"file3",
|
||||
"file4",
|
||||
"file5",
|
||||
"file6",
|
||||
"file7",
|
||||
]
|
||||
with patch("httpx.post", MagicMock(return_value=response)):
|
||||
with patch("builtins.open", MagicMock()):
|
||||
with patch.object(pathlib.Path, "name") as mock_name:
|
||||
mock_name.side_effect = lambda x: x
|
||||
results = client.endpoints[0]._upload(
|
||||
["pre1", ["pre2", "pre3", "pre4"], ["pre5", "pre6"], "pre7"]
|
||||
)
|
||||
|
||||
res = []
|
||||
for re in results:
|
||||
if isinstance(re, list):
|
||||
res.append([r["path"] for r in re])
|
||||
else:
|
||||
res.append(re["path"])
|
||||
|
||||
assert res == [
|
||||
"file1",
|
||||
["file2", "file3", "file4"],
|
||||
["file5", "file6"],
|
||||
"file7",
|
||||
]
|
||||
|
||||
|
||||
cpu = huggingface_hub.SpaceHardware.CPU_BASIC
|
||||
|
||||
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from gradio_client import file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio.components.base import Component
|
||||
@ -98,7 +99,9 @@ class SimpleImage(Component):
|
||||
return FileData(path=str(value), orig_name=Path(value).name)
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
return file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
)
|
||||
|
||||
def example_value(self) -> Any:
|
||||
return "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
|
@ -121,6 +121,9 @@ class Block:
|
||||
self.state_session_capacity = 10000
|
||||
self.temp_files: set[str] = set()
|
||||
self.GRADIO_CACHE = get_upload_folder()
|
||||
# Keep tracks of files that should not be deleted when the delete_cache parmaeter is set
|
||||
# These files are the default value of the component and files that are used in examples
|
||||
self.keep_in_cache = set()
|
||||
|
||||
if render:
|
||||
self.render()
|
||||
@ -253,9 +256,16 @@ class Block:
|
||||
else:
|
||||
url_or_file_path = str(utils.abspath(url_or_file_path))
|
||||
if not utils.is_in_or_equal(url_or_file_path, self.GRADIO_CACHE):
|
||||
temp_file_path = processing_utils.save_file_to_cache(
|
||||
url_or_file_path, cache_dir=self.GRADIO_CACHE
|
||||
)
|
||||
try:
|
||||
temp_file_path = processing_utils.save_file_to_cache(
|
||||
url_or_file_path, cache_dir=self.GRADIO_CACHE
|
||||
)
|
||||
except FileNotFoundError:
|
||||
# This can happen if when using gr.load() and the file is on a remote Space
|
||||
# but the file is not the `value` of the component. For example, if the file
|
||||
# is the `avatar_image` of the `Chatbot` component. In this case, we skip
|
||||
# copying the file to the cache and just use the remote file path.
|
||||
return url_or_file_path
|
||||
else:
|
||||
temp_file_path = url_or_file_path
|
||||
self.temp_files.add(temp_file_path)
|
||||
@ -753,15 +763,18 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
# targets field
|
||||
_targets = dependency.pop("targets")
|
||||
trigger = dependency.pop("trigger", None)
|
||||
targets = [
|
||||
getattr(
|
||||
original_mapping[
|
||||
target if isinstance(target, int) else target[0]
|
||||
],
|
||||
trigger if isinstance(target, int) else target[1],
|
||||
)
|
||||
for target in _targets
|
||||
]
|
||||
is_then_event = False
|
||||
|
||||
# This assumes that you cannot combine multiple .then() events in a single
|
||||
# gr.on() event, which is true for now. If this changes, we will need to
|
||||
# update this code.
|
||||
if not isinstance(_targets[0], int) and _targets[0][1] == "then":
|
||||
if len(_targets) != 1:
|
||||
raise ValueError(
|
||||
"This logic assumes that .then() events are not combined with other events in a single gr.on() event"
|
||||
)
|
||||
is_then_event = True
|
||||
|
||||
dependency.pop("backend_fn")
|
||||
dependency.pop("documentation", None)
|
||||
dependency["inputs"] = [
|
||||
@ -773,12 +786,30 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
dependency.pop("status_tracker", None)
|
||||
dependency["preprocess"] = False
|
||||
dependency["postprocess"] = False
|
||||
targets = [
|
||||
EventListenerMethod(
|
||||
t.__self__ if t.has_trigger else None, t.event_name
|
||||
if is_then_event:
|
||||
targets = [EventListenerMethod(None, "then")]
|
||||
dependency["trigger_after"] = dependency.pop("trigger_after")
|
||||
dependency["trigger_only_on_success"] = dependency.pop(
|
||||
"trigger_only_on_success"
|
||||
)
|
||||
for t in targets
|
||||
]
|
||||
dependency["no_target"] = True
|
||||
else:
|
||||
targets = [
|
||||
getattr(
|
||||
original_mapping[
|
||||
target if isinstance(target, int) else target[0]
|
||||
],
|
||||
trigger if isinstance(target, int) else target[1],
|
||||
)
|
||||
for target in _targets
|
||||
]
|
||||
targets = [
|
||||
EventListenerMethod(
|
||||
t.__self__ if t.has_trigger else None,
|
||||
t.event_name, # type: ignore
|
||||
)
|
||||
for t in targets
|
||||
]
|
||||
dependency = blocks.set_event_trigger(
|
||||
targets=targets, fn=fn, **dependency
|
||||
)[0]
|
||||
@ -866,7 +897,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
show_progress: whether to show progress animation while running.
|
||||
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None (default), the name of the function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use this event.
|
||||
js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
|
||||
no_target: if True, sets "targets" to [], used for Blocks "load" event
|
||||
no_target: if True, sets "targets" to [], used for the Blocks.load() event and .then() events
|
||||
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
|
||||
batch: whether this function takes in a batch of inputs
|
||||
max_batch_size: the maximum batch size to send to the function
|
||||
@ -884,7 +915,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
# Support for singular parameter
|
||||
_targets = [
|
||||
(
|
||||
target.block._id if target.block and not no_target else None,
|
||||
target.block._id if not no_target and target.block else None,
|
||||
target.event_name,
|
||||
)
|
||||
for target in targets
|
||||
@ -1256,7 +1287,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
serialized_input = client_utils.traverse(
|
||||
inputs[i],
|
||||
format_file,
|
||||
lambda s: client_utils.is_filepath(s) or client_utils.is_url(s),
|
||||
lambda s: client_utils.is_filepath(s)
|
||||
or client_utils.is_http_url_like(s),
|
||||
)
|
||||
else:
|
||||
serialized_input = inputs[i]
|
||||
@ -1735,7 +1767,7 @@ Received outputs:
|
||||
"is_colab": utils.colab_check(),
|
||||
"stylesheets": self.stylesheets,
|
||||
"theme": self.theme.name,
|
||||
"protocol": "sse_v2",
|
||||
"protocol": "sse_v2.1",
|
||||
"body_css": {
|
||||
"body_background_fill": self.theme._get_computed_value(
|
||||
"body_background_fill"
|
||||
|
@ -7,6 +7,7 @@ from typing import Any, List
|
||||
import gradio_client.utils as client_utils
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client import file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio import processing_utils, utils
|
||||
@ -205,7 +206,9 @@ class AnnotatedImage(Component):
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return {
|
||||
"image": "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",
|
||||
"image": file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
),
|
||||
"annotations": [],
|
||||
}
|
||||
|
||||
|
@ -8,6 +8,7 @@ from typing import Any, Callable, Literal
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
from gradio_client import file
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document
|
||||
|
||||
@ -182,7 +183,9 @@ class Audio(
|
||||
)
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"
|
||||
return file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"
|
||||
)
|
||||
|
||||
def example_value(self) -> Any:
|
||||
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"
|
||||
|
@ -190,9 +190,6 @@ class Component(ComponentBase, Block):
|
||||
self.scale = scale
|
||||
self.min_width = min_width
|
||||
self.interactive = interactive
|
||||
# Keep tracks of files that should not be deleted when the delete_cache parmaeter is set
|
||||
# These files are the default value of the component and files that are used in examples
|
||||
self.keep_in_cache = set()
|
||||
|
||||
# load_event is set in the Blocks.attach_load_events method
|
||||
self.load_event: None | dict[str, Any] = None
|
||||
@ -203,6 +200,7 @@ class Component(ComponentBase, Block):
|
||||
initial_value,
|
||||
self, # type: ignore
|
||||
postprocess=True,
|
||||
keep_in_cache=True,
|
||||
)
|
||||
if is_file_obj(self.value):
|
||||
self.keep_in_cache.add(self.value["path"])
|
||||
|
@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
from typing import Any, Literal
|
||||
|
||||
from gradio_client.documentation import document
|
||||
from gradio_client.utils import is_file_obj
|
||||
|
||||
from gradio import processing_utils
|
||||
from gradio.components.base import (
|
||||
@ -95,12 +94,9 @@ class Dataset(Component):
|
||||
# use the previous name to be backwards-compatible with previously-created
|
||||
# custom components
|
||||
example[i] = component.as_example(ex)
|
||||
example[i] = processing_utils.move_files_to_cache(
|
||||
example[i],
|
||||
component,
|
||||
)
|
||||
if is_file_obj(example[i]):
|
||||
self.keep_in_cache.add(example[i]["path"])
|
||||
example[i] = processing_utils.move_files_to_cache(
|
||||
example[i], component, keep_in_cache=True
|
||||
)
|
||||
self.type = type
|
||||
self.label = label
|
||||
if headers is not None:
|
||||
|
@ -6,6 +6,7 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal
|
||||
|
||||
from gradio_client import file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio.components.base import Component
|
||||
@ -99,8 +100,10 @@ class DownloadButton(Component):
|
||||
return None
|
||||
return FileData(path=str(value))
|
||||
|
||||
def example_payload(self) -> str:
|
||||
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
def example_payload(self) -> dict:
|
||||
return file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
)
|
||||
|
||||
def example_value(self) -> str:
|
||||
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
|
@ -8,6 +8,7 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import gradio_client.utils as client_utils
|
||||
from gradio_client import file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio import processing_utils
|
||||
@ -199,10 +200,14 @@ class File(Component):
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
if self.file_count == "single":
|
||||
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
return file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
)
|
||||
else:
|
||||
return [
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
)
|
||||
]
|
||||
|
||||
def example_value(self) -> Any:
|
||||
|
@ -9,6 +9,7 @@ from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client import file
|
||||
from gradio_client.documentation import document
|
||||
from gradio_client.utils import is_http_url_like
|
||||
|
||||
@ -223,7 +224,9 @@ class Gallery(Component):
|
||||
def example_payload(self) -> Any:
|
||||
return [
|
||||
{
|
||||
"image": "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
"image": file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
)
|
||||
},
|
||||
]
|
||||
|
||||
|
@ -8,6 +8,7 @@ from typing import Any, Literal, cast
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client import file
|
||||
from gradio_client.documentation import document
|
||||
from PIL import ImageOps
|
||||
|
||||
@ -209,7 +210,9 @@ class Image(StreamingInput, Component):
|
||||
)
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
return file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
)
|
||||
|
||||
def example_value(self) -> Any:
|
||||
return "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
|
@ -9,6 +9,7 @@ from typing import Any, Iterable, List, Literal, Optional, TypedDict, Union, cas
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from gradio_client import file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio import image_utils, utils
|
||||
@ -216,9 +217,7 @@ class ImageEditor(Component):
|
||||
) -> np.ndarray | PIL.Image.Image | str | None:
|
||||
if file is None:
|
||||
return None
|
||||
|
||||
im = PIL.Image.open(file.path)
|
||||
|
||||
if file.orig_name:
|
||||
p = Path(file.orig_name)
|
||||
name = p.stem
|
||||
@ -317,7 +316,9 @@ class ImageEditor(Component):
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return {
|
||||
"background": "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png",
|
||||
"background": file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
|
||||
),
|
||||
"layers": [],
|
||||
"composite": None,
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from gradio_client import file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio.components.base import Component
|
||||
@ -118,7 +119,9 @@ class Model3D(Component):
|
||||
return Path(input_data).name if input_data else ""
|
||||
|
||||
def example_payload(self):
|
||||
return "https://raw.githubusercontent.com/gradio-app/gradio/main/demo/model3D/files/Fox.gltf"
|
||||
return file(
|
||||
"https://raw.githubusercontent.com/gradio-app/gradio/main/demo/model3D/files/Fox.gltf"
|
||||
)
|
||||
|
||||
def example_value(self):
|
||||
return "https://raw.githubusercontent.com/gradio-app/gradio/main/demo/model3D/files/Fox.gltf"
|
||||
|
@ -8,6 +8,7 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import gradio_client.utils as client_utils
|
||||
from gradio_client import file
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio import processing_utils
|
||||
@ -114,10 +115,14 @@ class UploadButton(Component):
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
if self.file_count == "single":
|
||||
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
return file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
)
|
||||
else:
|
||||
return [
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf"
|
||||
)
|
||||
]
|
||||
|
||||
def example_value(self) -> Any:
|
||||
|
@ -7,6 +7,7 @@ import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
from gradio_client import file
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document
|
||||
|
||||
@ -353,7 +354,9 @@ class Video(Component):
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return {
|
||||
"video": "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4",
|
||||
"video": file(
|
||||
"https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
|
||||
),
|
||||
}
|
||||
|
||||
def example_value(self) -> Any:
|
||||
|
@ -164,6 +164,7 @@ class FileData(GradioModel):
|
||||
orig_name: Optional[str] = None # original filename
|
||||
mime_type: Optional[str] = None
|
||||
is_stream: bool = False
|
||||
meta: dict = {"_type": "gradio.FileData"}
|
||||
|
||||
@property
|
||||
def is_none(self):
|
||||
|
@ -418,7 +418,13 @@ def from_spaces(
|
||||
|
||||
|
||||
def from_spaces_blocks(space: str, hf_token: str | None) -> Blocks:
|
||||
client = Client(space, hf_token=hf_token, download_files=False)
|
||||
client = Client(
|
||||
space,
|
||||
hf_token=hf_token,
|
||||
upload_files=False,
|
||||
download_files=False,
|
||||
_skip_components=False,
|
||||
)
|
||||
# We set deserialize to False to avoid downloading output files from the server.
|
||||
# Instead, we serve them as URLs using the /proxy/ endpoint directly from the server.
|
||||
|
||||
|
@ -241,6 +241,7 @@ def move_files_to_cache(
|
||||
block: Block,
|
||||
postprocess: bool = False,
|
||||
check_in_upload_folder=False,
|
||||
keep_in_cache=False,
|
||||
) -> dict:
|
||||
"""Move any files in `data` to cache and (optionally), adds URL prefixes (/file=...) needed to access the cached file.
|
||||
Also handles the case where the file is on an external Gradio app (/proxy=...).
|
||||
@ -252,6 +253,7 @@ def move_files_to_cache(
|
||||
block: The component whose data is being processed
|
||||
postprocess: Whether its running from postprocessing
|
||||
check_in_upload_folder: If True, instead of moving the file to cache, checks if the file is in already in cache (exception if not).
|
||||
keep_in_cache: If True, the file will not be deleted from cache when the server is shut down.
|
||||
"""
|
||||
|
||||
def _move_to_cache(d: dict):
|
||||
@ -278,6 +280,8 @@ def move_files_to_cache(
|
||||
if temp_file_path is None:
|
||||
raise ValueError("Did not determine a file path for the resource.")
|
||||
payload.path = temp_file_path
|
||||
if keep_in_cache:
|
||||
block.keep_in_cache.add(payload.path)
|
||||
|
||||
url_prefix = "/stream/" if payload.is_stream else "/file="
|
||||
if block.proxy_url:
|
||||
|
@ -11,10 +11,10 @@ Using the `gradio_client` library, we can easily use the Gradio as an API to tra
|
||||
Here's the entire code to do it:
|
||||
|
||||
```python
|
||||
from gradio_client import Client
|
||||
from gradio_client import Client, file
|
||||
|
||||
client = Client("abidlabs/whisper")
|
||||
client.predict("audio_sample.wav")
|
||||
client.predict(file("audio_sample.wav"))
|
||||
|
||||
>> "This is a test of the whisper speech recognition model."
|
||||
```
|
||||
@ -25,12 +25,12 @@ The Gradio client works with any hosted Gradio app, whether it be an image gener
|
||||
|
||||
## Installation
|
||||
|
||||
If you already have a recent version of `gradio`, then the `gradio_client` is included as a dependency.
|
||||
If you already have a recent version of `gradio`, then the `gradio_client` is included as a dependency. But note that this documentation reflects the latest version of the `gradio_client`, so upgrade if you're not sure!
|
||||
|
||||
Otherwise, the lightweight `gradio_client` package can be installed from pip (or pip3) and is tested to work with Python versions 3.9 or higher:
|
||||
The lightweight `gradio_client` package can be installed from pip (or pip3) and is tested to work with Python versions 3.9 or higher:
|
||||
|
||||
```bash
|
||||
$ pip install gradio_client
|
||||
$ pip install --upgrade gradio_client
|
||||
```
|
||||
|
||||
## Connecting to a running Gradio App
|
||||
@ -62,12 +62,12 @@ The `gradio_client` includes a class method: `Client.duplicate()` to make this p
|
||||
|
||||
```python
|
||||
import os
|
||||
from gradio_client import Client
|
||||
from gradio_client import Client, file
|
||||
|
||||
HF_TOKEN = os.environ.get("HF_TOKEN")
|
||||
|
||||
client = Client.duplicate("abidlabs/whisper", hf_token=HF_TOKEN)
|
||||
client.predict("audio_sample.wav")
|
||||
client.predict(file("audio_sample.wav"))
|
||||
|
||||
>> "This is a test of the whisper speech recognition model."
|
||||
```
|
||||
@ -130,13 +130,13 @@ client.predict(4, "add", 5)
|
||||
>> 9.0
|
||||
```
|
||||
|
||||
For certain inputs, such as images, you should pass in the filepath or URL to the file. Likewise, for the corresponding output types, you will get a filepath or URL returned.
|
||||
For when working with files (e.g. image files), you should pass in the filepath or URL to the file enclosed within `gradio_client.file()`.
|
||||
|
||||
```python
|
||||
from gradio_client import Client
|
||||
from gradio_client import Client, file
|
||||
|
||||
client = Client("abidlabs/whisper")
|
||||
client.predict("https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3")
|
||||
client.predict(file("https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3"))
|
||||
|
||||
>> "My thought I have nobody by a beauty and will as you poured. Mr. Rochester is serve in that so don't find simpus, and devoted abode, to at might in a r—"
|
||||
```
|
||||
@ -202,8 +202,8 @@ The `Job` class also has a `.cancel()` instance method that cancels jobs that ha
|
||||
|
||||
```py
|
||||
client = Client("abidlabs/whisper")
|
||||
job1 = client.submit("audio_sample1.wav")
|
||||
job2 = client.submit("audio_sample2.wav")
|
||||
job1 = client.submit(file("audio_sample1.wav"))
|
||||
job2 = client.submit(file("audio_sample2.wav"))
|
||||
job1.cancel() # will return False, assuming the job has started
|
||||
job2.cancel() # will return True, indicating that the job has been canceled
|
||||
```
|
||||
|
@ -1,10 +1,19 @@
|
||||
<script lang="ts">
|
||||
import type { ComponentMeta, Dependency } from "../types";
|
||||
import CopyButton from "./CopyButton.svelte";
|
||||
import { represent_value } from "./utils";
|
||||
import { represent_value, is_potentially_nested_file_data } from "./utils";
|
||||
import { Block } from "@gradio/atoms";
|
||||
import EndpointDetail from "./EndpointDetail.svelte";
|
||||
|
||||
interface EndpointParameter {
|
||||
label: string;
|
||||
type: string;
|
||||
python_type: { type: string };
|
||||
component: string;
|
||||
example_input: string;
|
||||
serializer: string;
|
||||
}
|
||||
|
||||
export let dependency: Dependency;
|
||||
export let dependency_index: number;
|
||||
export let root: string;
|
||||
@ -18,19 +27,12 @@
|
||||
let python_code: HTMLElement;
|
||||
let js_code: HTMLElement;
|
||||
|
||||
let has_file_path = endpoint_parameters.some((param: EndpointParameter) =>
|
||||
is_potentially_nested_file_data(param.example_input)
|
||||
);
|
||||
let blob_components = ["Audio", "File", "Image", "Video"];
|
||||
let blob_examples: any[] = endpoint_parameters.filter(
|
||||
(param: {
|
||||
label: string;
|
||||
type: string;
|
||||
python_type: {
|
||||
type: string;
|
||||
description: string;
|
||||
};
|
||||
component: string;
|
||||
example_input: string;
|
||||
serializer: string;
|
||||
}) => blob_components.includes(param.component)
|
||||
(param: EndpointParameter) => blob_components.includes(param.component)
|
||||
);
|
||||
</script>
|
||||
|
||||
@ -47,7 +49,7 @@
|
||||
<CopyButton code={python_code?.innerText} />
|
||||
</div>
|
||||
<div bind:this={python_code}>
|
||||
<pre>from gradio_client import Client
|
||||
<pre>from gradio_client import Client{#if has_file_path}, file{/if}
|
||||
|
||||
client = Client(<span class="token string">"{root}"</span>)
|
||||
result = client.predict(<!--
|
||||
@ -64,7 +66,8 @@ result = client.predict(<!--
|
||||
-->{/if}<!--
|
||||
--><span class="desc"
|
||||
><!--
|
||||
--> # {python_type.type} {#if python_type.description}({python_type.description}){/if}<!----> in '{label}' <!--
|
||||
--> # {python_type.type} {#if python_type.description}({python_type.description})
|
||||
{/if}<!---->in '{label}' <!--
|
||||
-->{component} component<!--
|
||||
--></span
|
||||
><!--
|
||||
@ -84,7 +87,7 @@ print(result)</pre>
|
||||
<pre>import { client } from "@gradio/client";
|
||||
{#each blob_examples as { label, type, python_type, component, example_input, serializer }, i}<!--
|
||||
-->
|
||||
const response_{i} = await fetch("{example_input}");
|
||||
const response_{i} = await fetch("{example_input.url}");
|
||||
const example{component} = await response_{i}.blob();
|
||||
{/each}<!--
|
||||
-->
|
||||
|
@ -7,7 +7,7 @@ export function represent_value(
|
||||
if (type === undefined) {
|
||||
return lang === "py" ? "None" : null;
|
||||
}
|
||||
if (type === "string" || type === "str" || type == "filepath") {
|
||||
if (type === "string" || type === "str") {
|
||||
return lang === null ? value : '"' + value + '"';
|
||||
} else if (type === "number") {
|
||||
return lang === null ? parseFloat(value) : value;
|
||||
@ -35,5 +35,72 @@ export function represent_value(
|
||||
}
|
||||
return value;
|
||||
}
|
||||
return JSON.stringify(value);
|
||||
if (lang === "py") {
|
||||
value = replace_file_data_with_file_function(value);
|
||||
}
|
||||
return stringify_except_file_function(value);
|
||||
}
|
||||
|
||||
export function is_potentially_nested_file_data(obj: any): boolean {
|
||||
if (typeof obj === "object" && obj !== null) {
|
||||
if (obj.hasOwnProperty("path") && obj.hasOwnProperty("meta")) {
|
||||
if (
|
||||
typeof obj.meta === "object" &&
|
||||
obj.meta !== null &&
|
||||
obj.meta._type === "gradio.FileData"
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (typeof obj === "object" && obj !== null) {
|
||||
for (let key in obj) {
|
||||
if (typeof obj[key] === "object") {
|
||||
let result = is_potentially_nested_file_data(obj[key]);
|
||||
if (result) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function replace_file_data_with_file_function(obj: any): any {
|
||||
if (typeof obj === "object" && obj !== null && !Array.isArray(obj)) {
|
||||
if (
|
||||
"path" in obj &&
|
||||
"meta" in obj &&
|
||||
obj.meta?._type === "gradio.FileData"
|
||||
) {
|
||||
return `file('${obj.path}')`;
|
||||
}
|
||||
}
|
||||
if (Array.isArray(obj)) {
|
||||
obj.forEach((item, index) => {
|
||||
if (typeof item === "object" && item !== null) {
|
||||
obj[index] = replace_file_data_with_file_function(item); // Recurse and update array elements
|
||||
}
|
||||
});
|
||||
} else if (typeof obj === "object" && obj !== null) {
|
||||
Object.keys(obj).forEach((key) => {
|
||||
obj[key] = replace_file_data_with_file_function(obj[key]); // Recurse and update object properties
|
||||
});
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
|
||||
function stringify_except_file_function(obj: any): string {
|
||||
const jsonString = JSON.stringify(obj, (key, value) => {
|
||||
if (
|
||||
typeof value === "string" &&
|
||||
value.startsWith("file(") &&
|
||||
value.endsWith(")")
|
||||
) {
|
||||
return `UNQUOTED${value}`; // Flag the special strings
|
||||
}
|
||||
return value;
|
||||
});
|
||||
const regex = /"UNQUOTEDfile\(([^)]*)\)"/g;
|
||||
return jsonString.replace(regex, (match, p1) => `file(${p1})`);
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ from gradio.components.dataframe import DataframeData
|
||||
from gradio.components.file_explorer import FileExplorerData
|
||||
from gradio.components.image_editor import EditorData
|
||||
from gradio.components.video import VideoData
|
||||
from gradio.data_classes import FileData, ListFiles
|
||||
from gradio.data_classes import FileData, GradioModel, GradioRootModel, ListFiles
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -1543,6 +1543,7 @@ class TestVideo:
|
||||
"size": None,
|
||||
"url": None,
|
||||
"is_stream": False,
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
"subtitles": None,
|
||||
}
|
||||
@ -1555,6 +1556,7 @@ class TestVideo:
|
||||
"size": None,
|
||||
"url": None,
|
||||
"is_stream": False,
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
"subtitles": {
|
||||
"path": "s1.srt",
|
||||
@ -1563,6 +1565,7 @@ class TestVideo:
|
||||
"size": None,
|
||||
"url": None,
|
||||
"is_stream": False,
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
}
|
||||
postprocessed_video["video"]["path"] = os.path.basename(
|
||||
@ -2262,6 +2265,7 @@ class TestGallery:
|
||||
"size": None,
|
||||
"url": url,
|
||||
"is_stream": False,
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
"caption": None,
|
||||
}
|
||||
@ -2290,6 +2294,7 @@ class TestGallery:
|
||||
"size": None,
|
||||
"url": None,
|
||||
"is_stream": False,
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
"caption": "foo_caption",
|
||||
},
|
||||
@ -2301,6 +2306,7 @@ class TestGallery:
|
||||
"size": None,
|
||||
"url": None,
|
||||
"is_stream": False,
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
"caption": "bar_caption",
|
||||
},
|
||||
@ -2312,6 +2318,7 @@ class TestGallery:
|
||||
"size": None,
|
||||
"url": None,
|
||||
"is_stream": False,
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
"caption": None,
|
||||
},
|
||||
@ -2323,6 +2330,7 @@ class TestGallery:
|
||||
"size": None,
|
||||
"url": None,
|
||||
"is_stream": False,
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
"caption": None,
|
||||
},
|
||||
@ -2977,3 +2985,25 @@ def test_component_example_values(io_components):
|
||||
else:
|
||||
c: Component = component()
|
||||
c.postprocess(c.example_value())
|
||||
|
||||
|
||||
def test_component_example_payloads(io_components):
|
||||
for component in io_components:
|
||||
if component == PDF:
|
||||
continue
|
||||
elif component in [gr.BarPlot, gr.LinePlot, gr.ScatterPlot]:
|
||||
c: Component = component(x="x", y="y")
|
||||
else:
|
||||
c: Component = component()
|
||||
data = c.example_payload()
|
||||
data = processing_utils.move_files_to_cache(
|
||||
data,
|
||||
c,
|
||||
check_in_upload_folder=False,
|
||||
)
|
||||
if getattr(c, "data_model", None) and data is not None:
|
||||
if issubclass(c.data_model, GradioModel): # type: ignore
|
||||
data = c.data_model(**data) # type: ignore
|
||||
elif issubclass(c.data_model, GradioRootModel): # type: ignore
|
||||
data = c.data_model(root=data) # type: ignore
|
||||
c.preprocess(data)
|
||||
|
@ -397,6 +397,9 @@ class TestLoadInterfaceWithExamples:
|
||||
demo = gr.load("spaces/gradio-tests/test-calculator-2v4-sse")
|
||||
assert demo(2, "add", 4) == 6
|
||||
|
||||
def test_loading_chatbot_with_avatar_images_does_not_raise_errors(self):
|
||||
gr.load("gradio/chatbot_multimodal", src="spaces")
|
||||
|
||||
|
||||
def test_get_tabular_examples_replaces_nan_with_str_nan():
|
||||
readme = """
|
||||
|
@ -339,10 +339,12 @@ def test_add_root_url():
|
||||
"file": {
|
||||
"path": "path",
|
||||
"url": "/file=path",
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
"file2": {
|
||||
"path": "path2",
|
||||
"url": "https://www.gradio.app",
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
}
|
||||
root_url = "http://localhost:7860"
|
||||
@ -350,10 +352,12 @@ def test_add_root_url():
|
||||
"file": {
|
||||
"path": "path",
|
||||
"url": f"{root_url}/file=path",
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
"file2": {
|
||||
"path": "path2",
|
||||
"url": "https://www.gradio.app",
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
}
|
||||
assert processing_utils.add_root_url(data, root_url, None) == expected
|
||||
@ -362,10 +366,12 @@ def test_add_root_url():
|
||||
"file": {
|
||||
"path": "path",
|
||||
"url": f"{new_root_url}/file=path",
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
"file2": {
|
||||
"path": "path2",
|
||||
"url": "https://www.gradio.app",
|
||||
"meta": {"_type": "gradio.FileData"},
|
||||
},
|
||||
}
|
||||
assert (
|
||||
|
Loading…
x
Reference in New Issue
Block a user