Fix gr.load for file-based Spaces (#7350)

* changes

* add changeset

* fixes

* changes

* fix

* add test

* add changeset

* improve test

* Fixing `gr.load()` part II (#7358)

* audio

* changes

* add changeset

* changes

* changes

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>

* Delete .changeset/fresh-gifts-worry.md

* add changeset

* format

* upload

* add changeset

* changes

* backend

* print

* add changeset

* changes

* client

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-02-09 11:24:00 -08:00 committed by GitHub
parent a7fa47a175
commit 7302a6e151
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 84 additions and 30 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---
fix:Fix `gr.load` for file-based Spaces

View File

@ -71,26 +71,36 @@ class Client:
src: str,
hf_token: str | None = None,
max_workers: int = 40,
serialize: bool = True,
serialize: bool | None = None,
output_dir: str | Path = DEFAULT_TEMP_DIR,
verbose: bool = True,
auth: tuple[str, str] | None = None,
*,
headers: dict[str, str] | None = None,
upload_files: bool = True,
download_files: bool = True,
):
"""
Parameters:
src: Either the name of the Hugging Face Space to load, (e.g. "abidlabs/whisper-large-v2") or the full URL (including "http" or "https") of the hosted Gradio app to load (e.g. "http://mydomain.com/app" or "https://bec81a83-5b5c-471e.gradio.live/").
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
serialize: Whether the client should serialize the inputs and deserialize the outputs of the remote API. If set to False, the client will pass the inputs and outputs as-is, without serializing/deserializing them. E.g. you if you set this to False, you'd submit an image in base64 format instead of a filepath, and you'd get back an image in base64 format from the remote API instead of a filepath.
serialize: Deprecated. Please use the equivalent `upload_files` parameter instead.
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
verbose: Whether the client should print statements to the console.
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
upload_files: Whether the client should treat input string filepaths as files and upload them to the remote server. If False, the client will always treat input string filepaths as strings and not modify them.
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead.
"""
self.verbose = verbose
self.hf_token = hf_token
self.serialize = serialize
if serialize is not None:
warnings.warn(
"The `serialize` parameter is deprecated and will be removed. Please use the equivalent `upload_files` parameter instead."
)
upload_files = serialize
self.upload_files = upload_files
self.download_files = download_files
self.headers = build_hf_headers(
token=hf_token,
library_name="gradio_client",
@ -463,11 +473,10 @@ class Client:
return job
def _get_api_info(self):
if self.serialize:
if self.upload_files:
api_info_url = urllib.parse.urljoin(self.src, utils.API_INFO_URL)
else:
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)
if self.app_version > version.Version("3.36.1"):
r = httpx.get(api_info_url, headers=self.headers, cookies=self.cookies)
if r.is_success:
@ -477,7 +486,10 @@ class Client:
else:
fetch = httpx.post(
utils.SPACE_FETCHER_URL,
json={"config": json.dumps(self.config), "serialize": self.serialize},
json={
"config": json.dumps(self.config),
"serialize": self.upload_files,
},
)
if fetch.is_success:
info = fetch.json()["api"]
@ -955,7 +967,11 @@ class Endpoint:
@staticmethod
def value_is_file(component: dict) -> bool:
# Hacky for now
# This is still hacky as it does not tell us which part of the payload is a file.
# If a component has a complex payload, part of which is a file, this will simply
# return True, which means that all parts of the payload will be uploaded as files
# if they are valid file paths. The better approach would be to traverse the
# component's api_info and figure out exactly which part of the payload is a file.
if "api_info" not in component:
return False
return utils.value_is_file(component["api_info"])
@ -973,7 +989,7 @@ class Endpoint:
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
data = self.insert_state(*data)
if self.client.serialize:
if self.client.upload_files:
data = self.serialize(*data)
predictions = _predict(*data)
predictions = self.process_predictions(*predictions)
@ -1117,6 +1133,9 @@ class Endpoint:
file_list.append(d)
return ReplaceMe(len(file_list) - 1)
def handle_url(s):
return {"path": s, "orig_name": s.split("/")[-1]}
new_data = []
for i, d in enumerate(data):
if self.input_component_types[i].value_is_file:
@ -1126,6 +1145,8 @@ class Endpoint:
d = utils.traverse(
d, get_file, lambda s: utils.is_file_obj(s) or utils.is_filepath(s)
)
# Handle URLs here since we don't upload them
d = utils.traverse(d, handle_url, lambda s: utils.is_url(s))
new_data.append(d)
return file_list, new_data
@ -1146,11 +1167,6 @@ class Endpoint:
uploaded_files = self._upload(files)
data = list(new_data)
data = self._add_uploaded_files_to_data(data, uploaded_files)
data = utils.traverse(
data,
lambda s: {"path": s},
utils.is_url,
)
o = tuple(data)
return o
@ -1182,12 +1198,12 @@ class Endpoint:
def deserialize(self, *data) -> tuple:
data_ = list(data)
data_: list[Any] = utils.traverse(data_, self.download_file, utils.is_file_obj)
return tuple(data_)
def process_predictions(self, *predictions):
predictions = self.deserialize(*predictions)
if self.client.download_files:
predictions = self.deserialize(*predictions)
predictions = self.remove_skipped_components(*predictions)
predictions = self.reduce_singleton_output(*predictions)
return predictions
@ -1258,7 +1274,7 @@ class EndpointV3Compatibility:
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
data = self.insert_state(*data)
if self.client.serialize:
if self.client.upload_files:
data = self.serialize(*data)
predictions = _predict(*data)
predictions = self.process_predictions(*predictions)
@ -1449,7 +1465,7 @@ class EndpointV3Compatibility:
return outputs
def process_predictions(self, *predictions):
if self.client.serialize:
if self.client.download_files:
predictions = self.deserialize(*predictions)
predictions = self.remove_skipped_components(*predictions)
predictions = self.reduce_singleton_output(*predictions)

View File

@ -895,6 +895,7 @@ def get_type(schema: dict):
raise APIInfoParseError(f"Cannot parse type for {schema}")
OLD_FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)"
FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool)"
@ -995,7 +996,7 @@ def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:
def value_is_file(api_info: dict) -> bool:
info = _json_schema_to_python_type(api_info, api_info.get("$defs"))
return FILE_DATA in info
return FILE_DATA in info or OLD_FILE_DATA in info
def is_filepath(s):

View File

@ -133,6 +133,28 @@ class TestClientPredictions:
output = client.predict("abc", api_name="/predict")
assert output == "abc"
@pytest.mark.flaky
def test_space_with_files_v4_sse_v2(self):
space_id = "gradio-tests/space_with_files_v4_sse_v2"
client = Client(space_id)
payload = (
"https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3",
{
"video": "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4",
"subtitle": None,
},
"https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3",
)
output = client.predict(*payload, api_name="/predict")
assert output[0].endswith(".wav") # Audio files are converted to wav
assert output[1]["video"].endswith(
"world.mp4"
) # Video files are not converted by default
assert (
output[2]
== "https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3"
) # textbox string should remain exactly the same
def test_state(self, increment_demo):
with connect(increment_demo) as client:
output = client.predict(api_name="/increment_without_queue")

View File

@ -350,4 +350,7 @@ class Video(Component):
return FileData(path=str(subtitle))
def example_inputs(self) -> Any:
return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
return {
"video": "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4",
"subtitles": None,
}

View File

@ -418,12 +418,16 @@ def from_spaces(
def from_spaces_blocks(space: str, hf_token: str | None) -> Blocks:
client = Client(space, hf_token=hf_token)
client = Client(space, hf_token=hf_token, download_files=False)
# We set download_files to False to avoid downloading output files from the server.
# Instead, we serve them as URLs using the /proxy/ endpoint directly from the server.
if client.app_version < version.Version("4.0.0b14"):
raise GradioVersionIncompatibleError(
f"Gradio version 4.x cannot load spaces with versions less than 4.x ({client.app_version})."
"Please downgrade to version 3 to load this space."
)
# Use end_to_end_fn here to properly upload/download all files
predict_fns = []
for fn_index, endpoint in enumerate(client.endpoints):

View File

@ -261,22 +261,24 @@ def move_files_to_cache(
# without it being served from the gradio server
# This makes it so that the URL is not downloaded and speeds up event processing
if payload.url and postprocess:
temp_file_path = payload.url
else:
payload.path = payload.url
elif not block.proxy_url:
# If the file is on a remote server, do not move it to cache.
temp_file_path = move_resource_to_block_cache(payload.path, block)
assert temp_file_path is not None
payload.path = temp_file_path
assert temp_file_path is not None
payload.path = temp_file_path
if add_urls:
url_prefix = "/stream/" if payload.is_stream else "/file="
if block.proxy_url:
url = f"/proxy={block.proxy_url}{url_prefix}{temp_file_path}"
elif client_utils.is_http_url_like(
temp_file_path
) or temp_file_path.startswith(f"{url_prefix}"):
url = temp_file_path
proxy_url = block.proxy_url.rstrip("/")
url = f"/proxy={proxy_url}{url_prefix}{payload.path}"
elif client_utils.is_http_url_like(payload.path) or payload.path.startswith(
f"{url_prefix}"
):
url = payload.path
else:
url = f"{url_prefix}{temp_file_path}"
url = f"{url_prefix}{payload.path}"
payload.url = url
return payload.model_dump()