Fix `gr.load` for file-based Spaces (#7350)

* changes
* add changeset
* fixes
* changes
* fix
* add test
* add changeset
* improve test
* Fixing `gr.load()` part II (#7358)
* audio
* changes
* add changeset
* changes
* changes
* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>

* Delete .changeset/fresh-gifts-worry.md
* add changeset
* format
* upload
* add changeset
* changes
* backend
* print
* add changeset
* changes
* client

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
parent a7fa47a175
commit 7302a6e151

.changeset/tender-lamps-shout.md (new file, 6 lines)
@@ -0,0 +1,6 @@
+---
+"gradio": patch
+"gradio_client": patch
+---
+
+fix:Fix `gr.load` for file-based Spaces
@@ -71,26 +71,36 @@ class Client:
         src: str,
         hf_token: str | None = None,
         max_workers: int = 40,
-        serialize: bool = True,
+        serialize: bool | None = None,
         output_dir: str | Path = DEFAULT_TEMP_DIR,
         verbose: bool = True,
         auth: tuple[str, str] | None = None,
         *,
         headers: dict[str, str] | None = None,
+        upload_files: bool = True,
+        download_files: bool = True,
     ):
         """
         Parameters:
             src: Either the name of the Hugging Face Space to load (e.g. "abidlabs/whisper-large-v2") or the full URL (including "http" or "https") of the hosted Gradio app to load (e.g. "http://mydomain.com/app" or "https://bec81a83-5b5c-471e.gradio.live/").
             hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
             max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
-            serialize: Whether the client should serialize the inputs and deserialize the outputs of the remote API. If set to False, the client will pass the inputs and outputs as-is, without serializing/deserializing them. E.g. if you set this to False, you'd submit an image in base64 format instead of a filepath, and you'd get back an image in base64 format from the remote API instead of a filepath.
+            serialize: Deprecated. Please use the equivalent `upload_files` parameter instead.
             output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
             verbose: Whether the client should print statements to the console.
             headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
+            upload_files: Whether the client should treat input string filepaths as files and upload them to the remote server. If False, the client will always treat input string filepaths as strings and not modify them.
+            download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead.
         """
         self.verbose = verbose
         self.hf_token = hf_token
-        self.serialize = serialize
+        if serialize is not None:
+            warnings.warn(
+                "The `serialize` parameter is deprecated and will be removed. Please use the equivalent `upload_files` parameter instead."
+            )
+            upload_files = serialize
+        self.upload_files = upload_files
+        self.download_files = download_files
         self.headers = build_hf_headers(
             token=hf_token,
             library_name="gradio_client",
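Not part of the diff: a minimal sketch of how the deprecation shim above behaves, using a hypothetical helper (`resolve_upload_files`) that restates the same logic outside the class.

from __future__ import annotations

import warnings


def resolve_upload_files(serialize: bool | None = None, upload_files: bool = True) -> bool:
    # Mirrors the shim in Client.__init__: a legacy `serialize` argument wins,
    # but only after emitting a deprecation warning.
    if serialize is not None:
        warnings.warn(
            "The `serialize` parameter is deprecated and will be removed. "
            "Please use the equivalent `upload_files` parameter instead."
        )
        upload_files = serialize
    return upload_files


assert resolve_upload_files(serialize=False) is False     # old-style call still works
assert resolve_upload_files(upload_files=False) is False  # new-style equivalent, no warning
assert resolve_upload_files() is True                     # default behavior unchanged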
@@ -463,11 +473,10 @@ class Client:
         return job

     def _get_api_info(self):
-        if self.serialize:
+        if self.upload_files:
             api_info_url = urllib.parse.urljoin(self.src, utils.API_INFO_URL)
         else:
             api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)

         if self.app_version > version.Version("3.36.1"):
             r = httpx.get(api_info_url, headers=self.headers, cookies=self.cookies)
             if r.is_success:
@@ -477,7 +486,10 @@ class Client:
         else:
             fetch = httpx.post(
                 utils.SPACE_FETCHER_URL,
-                json={"config": json.dumps(self.config), "serialize": self.serialize},
+                json={
+                    "config": json.dumps(self.config),
+                    "serialize": self.upload_files,
+                },
             )
             if fetch.is_success:
                 info = fetch.json()["api"]
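For clarity (not from the commit): the request body sent to the Space fetcher keeps the old "serialize" key on the wire, but its value is now driven by `upload_files`. A sketch with placeholder values standing in for `self.config` and the fetcher call above:

import json

config = {"version": "3.50.2", "components": []}  # placeholder for self.config
upload_files = True                               # placeholder for self.upload_files

# Same shape as the httpx.post(..., json=...) call above; the key name "serialize"
# is preserved for compatibility with the fetcher service.
body = {
    "config": json.dumps(config),
    "serialize": upload_files,
}
print(body["serialize"])  # True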
@@ -955,7 +967,11 @@ class Endpoint:

     @staticmethod
     def value_is_file(component: dict) -> bool:
-        # Hacky for now
+        # This is still hacky as it does not tell us which part of the payload is a file.
+        # If a component has a complex payload, part of which is a file, this will simply
+        # return True, which means that all parts of the payload will be uploaded as files
+        # if they are valid file paths. The better approach would be to traverse the
+        # component's api_info and figure out exactly which part of the payload is a file.
         if "api_info" not in component:
             return False
         return utils.value_is_file(component["api_info"])
@@ -973,7 +989,7 @@ class Endpoint:
             if not self.is_valid:
                 raise utils.InvalidAPIEndpointError()
             data = self.insert_state(*data)
-            if self.client.serialize:
+            if self.client.upload_files:
                 data = self.serialize(*data)
             predictions = _predict(*data)
             predictions = self.process_predictions(*predictions)
@@ -1117,6 +1133,9 @@ class Endpoint:
                 file_list.append(d)
                 return ReplaceMe(len(file_list) - 1)

+        def handle_url(s):
+            return {"path": s, "orig_name": s.split("/")[-1]}
+
         new_data = []
         for i, d in enumerate(data):
             if self.input_component_types[i].value_is_file:
@@ -1126,6 +1145,8 @@ class Endpoint:
                 d = utils.traverse(
                     d, get_file, lambda s: utils.is_file_obj(s) or utils.is_filepath(s)
                 )
+                # Handle URLs here since we don't upload them
+                d = utils.traverse(d, handle_url, lambda s: utils.is_url(s))
             new_data.append(d)
         return file_list, new_data

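To make the new URL branch concrete, here is a self-contained sketch (a simplified stand-in, not the library's implementation) of a `traverse`-style walk that rewrites URL strings into `{"path": ..., "orig_name": ...}` dicts, which is what `handle_url` does so URLs are passed through rather than uploaded:

from typing import Any, Callable


def traverse(obj: Any, func: Callable[[Any], Any], is_root: Callable[[Any], bool]) -> Any:
    # Simplified stand-in for utils.traverse: apply `func` to every node matching
    # `is_root`, recursing into lists and dicts.
    if is_root(obj):
        return func(obj)
    if isinstance(obj, list):
        return [traverse(o, func, is_root) for o in obj]
    if isinstance(obj, dict):
        return {k: traverse(v, func, is_root) for k, v in obj.items()}
    return obj


def is_url(s: Any) -> bool:
    return isinstance(s, str) and s.startswith(("http://", "https://"))


def handle_url(s: str) -> dict:
    return {"path": s, "orig_name": s.split("/")[-1]}


payload = {"video": "https://example.com/files/world.mp4", "subtitle": None}
print(traverse(payload, handle_url, is_url))
# {'video': {'path': 'https://example.com/files/world.mp4', 'orig_name': 'world.mp4'}, 'subtitle': None}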
@@ -1146,11 +1167,6 @@ class Endpoint:
         uploaded_files = self._upload(files)
         data = list(new_data)
         data = self._add_uploaded_files_to_data(data, uploaded_files)
-        data = utils.traverse(
-            data,
-            lambda s: {"path": s},
-            utils.is_url,
-        )
         o = tuple(data)
         return o

@@ -1182,12 +1198,12 @@ class Endpoint:

     def deserialize(self, *data) -> tuple:
         data_ = list(data)

         data_: list[Any] = utils.traverse(data_, self.download_file, utils.is_file_obj)
         return tuple(data_)

     def process_predictions(self, *predictions):
-        predictions = self.deserialize(*predictions)
+        if self.client.download_files:
+            predictions = self.deserialize(*predictions)
         predictions = self.remove_skipped_components(*predictions)
         predictions = self.reduce_singleton_output(*predictions)
         return predictions
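A hedged usage sketch of the new behavior (the Space name and endpoint are hypothetical): with `download_files=False`, the `deserialize` step above is skipped, so file outputs keep their remote FileData-style payload instead of being downloaded to local filepaths.

from gradio_client import Client

client = Client("user/space-that-returns-a-file", download_files=False)  # hypothetical Space
result = client.predict("hello", api_name="/predict")  # assumes a single-input /predict endpoint

# With download_files=True (the default), `result` would be a local filepath;
# here it keeps the remote file metadata instead.
print(result)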
@@ -1258,7 +1274,7 @@ class EndpointV3Compatibility:
             if not self.is_valid:
                 raise utils.InvalidAPIEndpointError()
             data = self.insert_state(*data)
-            if self.client.serialize:
+            if self.client.upload_files:
                 data = self.serialize(*data)
             predictions = _predict(*data)
             predictions = self.process_predictions(*predictions)
@@ -1449,7 +1465,7 @@ class EndpointV3Compatibility:
         return outputs

     def process_predictions(self, *predictions):
-        if self.client.serialize:
+        if self.client.download_files:
             predictions = self.deserialize(*predictions)
         predictions = self.remove_skipped_components(*predictions)
         predictions = self.reduce_singleton_output(*predictions)

@@ -895,6 +895,7 @@ def get_type(schema: dict):
     raise APIInfoParseError(f"Cannot parse type for {schema}")


+OLD_FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)"
 FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool)"


@@ -995,7 +996,7 @@ def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:

 def value_is_file(api_info: dict) -> bool:
     info = _json_schema_to_python_type(api_info, api_info.get("$defs"))
-    return FILE_DATA in info
+    return FILE_DATA in info or OLD_FILE_DATA in info


 def is_filepath(s):
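The two signature strings differ only by the trailing `is_stream` field, so the widened check lets a 4.x client still recognize file-typed components on Spaces whose api_info was generated before `is_stream` existed. A quick standalone illustration of the substring check (skipping the `_json_schema_to_python_type` rendering step):

OLD_FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)"
FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool)"


def value_is_file(info: str) -> bool:
    # Same check as the updated utils.value_is_file, applied directly to a rendered type string.
    return FILE_DATA in info or OLD_FILE_DATA in info


assert value_is_file(FILE_DATA)      # new-style FileData signature
assert value_is_file(OLD_FILE_DATA)  # old-style signature (pre is_stream) now matches too
assert not value_is_file("str")      # plain string components are left alone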
@@ -133,6 +133,28 @@ class TestClientPredictions:
             output = client.predict("abc", api_name="/predict")
             assert output == "abc"

+    @pytest.mark.flaky
+    def test_space_with_files_v4_sse_v2(self):
+        space_id = "gradio-tests/space_with_files_v4_sse_v2"
+        client = Client(space_id)
+        payload = (
+            "https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3",
+            {
+                "video": "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4",
+                "subtitle": None,
+            },
+            "https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3",
+        )
+        output = client.predict(*payload, api_name="/predict")
+        assert output[0].endswith(".wav")  # Audio files are converted to wav
+        assert output[1]["video"].endswith(
+            "world.mp4"
+        )  # Video files are not converted by default
+        assert (
+            output[2]
+            == "https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3"
+        )  # textbox string should remain exactly the same
+
     def test_state(self, increment_demo):
         with connect(increment_demo) as client:
             output = client.predict(api_name="/increment_without_queue")
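Assuming the repository's usual pytest setup, the new test (network-dependent, hence the flaky marker) can be run on its own by selecting it by name, for example:

import pytest

# Select just the new test by keyword; it talks to a live Space, so it is marked flaky.
pytest.main(["-k", "test_space_with_files_v4_sse_v2"])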
@@ -350,4 +350,7 @@ class Video(Component):
         return FileData(path=str(subtitle))

     def example_inputs(self) -> Any:
-        return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
+        return {
+            "video": "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4",
+            "subtitles": None,
+        }
@@ -418,12 +418,16 @@ def from_spaces(


 def from_spaces_blocks(space: str, hf_token: str | None) -> Blocks:
-    client = Client(space, hf_token=hf_token)
+    client = Client(space, hf_token=hf_token, download_files=False)
+    # We set download_files=False to avoid downloading output files from the server.
+    # Instead, we serve them as URLs using the /proxy/ endpoint directly from the server.
+
     if client.app_version < version.Version("4.0.0b14"):
         raise GradioVersionIncompatibleError(
             f"Gradio version 4.x cannot load spaces with versions less than 4.x ({client.app_version})."
             "Please downgrade to version 3 to load this space."
         )
+
     # Use end_to_end_fn here to properly upload/download all files
     predict_fns = []
     for fn_index, endpoint in enumerate(client.endpoints):
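For reference, this is the code path behind the commit title: `gr.load` on a Space goes through `from_spaces_blocks`, whose internal Client now proxies output files instead of downloading them. A hedged usage sketch, reusing the Space name from the test added earlier in this diff:

import gradio as gr

# Loads the remote Space as a Blocks app; output files are served via the /proxy/
# endpoint rather than copied to the local machine first.
demo = gr.load("spaces/gradio-tests/space_with_files_v4_sse_v2")
demo.launch()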
@@ -261,22 +261,24 @@ def move_files_to_cache(
         # without it being served from the gradio server
         # This makes it so that the URL is not downloaded and speeds up event processing
         if payload.url and postprocess:
-            temp_file_path = payload.url
-        else:
+            payload.path = payload.url
+        elif not block.proxy_url:
+            # If the file is on a remote server, do not move it to cache.
             temp_file_path = move_resource_to_block_cache(payload.path, block)
-        assert temp_file_path is not None
-        payload.path = temp_file_path
+            assert temp_file_path is not None
+            payload.path = temp_file_path

         if add_urls:
             url_prefix = "/stream/" if payload.is_stream else "/file="
             if block.proxy_url:
-                url = f"/proxy={block.proxy_url}{url_prefix}{temp_file_path}"
-            elif client_utils.is_http_url_like(
-                temp_file_path
-            ) or temp_file_path.startswith(f"{url_prefix}"):
-                url = temp_file_path
+                proxy_url = block.proxy_url.rstrip("/")
+                url = f"/proxy={proxy_url}{url_prefix}{payload.path}"
+            elif client_utils.is_http_url_like(payload.path) or payload.path.startswith(
+                f"{url_prefix}"
+            ):
+                url = payload.path
             else:
-                url = f"{url_prefix}{temp_file_path}"
+                url = f"{url_prefix}{payload.path}"
             payload.url = url

         return payload.model_dump()
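A small worked example (made-up values) of the URL the proxied branch now builds from `payload.path` instead of `temp_file_path`:

proxy_url = "https://some-user-some-space.hf.space/".rstrip("/")  # made-up upstream Space URL
url_prefix = "/file="                      # would be "/stream/" if payload.is_stream were True
payload_path = "/tmp/gradio/outputs/audio.wav"

url = f"/proxy={proxy_url}{url_prefix}{payload_path}"
print(url)
# /proxy=https://some-user-some-space.hf.space/file=/tmp/gradio/outputs/audio.wav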