Add root_url to serializers in gradio_client (#3736)

* Add root_url to serializers

* Add url fix

* Respect fn parameter

* Fix docstring

* Add other test

* Pass to method
This commit is contained in:
Freddy Boulton 2023-04-04 10:35:12 -07:00 committed by GitHub
parent 1981c010c6
commit 46ee226d8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 6 deletions

View File

@ -44,13 +44,14 @@ class Client:
)
if src.startswith("http://") or src.startswith("https://"):
self.src = src
_src = src
else:
self.src = self._space_name_to_src(src)
if self.src is None:
_src = self._space_name_to_src(src)
if _src is None:
raise ValueError(
f"Could not find Space: {src}. If it is a private Space, please provide an hf_token."
)
self.src = _src
print(f"Loaded as API: {self.src}")
self.api_url = utils.API_URL.format(self.src)
@ -308,6 +309,7 @@ class Endpoint:
self.use_ws = self._use_websocket(self.dependency)
self.input_component_types = []
self.output_component_types = []
self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
try:
self.serializers, self.deserializers = self._setup_serializers()
self.is_valid = self.dependency[
@ -465,7 +467,7 @@ class Endpoint:
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
return tuple(
[
s.deserialize(d, hf_token=self.client.hf_token)
s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
for s, d, oct in zip(
self.deserializers, data, self.output_component_types
)

View File

@ -164,7 +164,10 @@ class FileSerializable(Serializable):
"""
if x is None or x == "":
return None
filename = str(Path(load_dir) / x)
if utils.is_valid_url(x):
filename = x
else:
filename = str(Path(load_dir) / x)
return {
"name": filename,
"data": utils.encode_url_or_file_to_base64(filename),
@ -185,7 +188,7 @@ class FileSerializable(Serializable):
Parameters:
x: Base64 representation of file to deserialize into a string filepath
save_dir: Path to directory to save the deserialized file to
root_url: If this component is loaded from an external Space, this is the URL of the Space
root_url: If this component is loaded from an external Space, this is the URL of the Space.
hf_token: If this component is loaded from an external private Space, this is the access token for the Space
"""
if x is None:

View File

@ -1,5 +1,6 @@
import json
import os
import pathlib
import time
from datetime import datetime, timedelta
from unittest.mock import patch
@ -87,6 +88,15 @@ class TestPredictionsFromSpaces:
statuses.append(job.status())
assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
@pytest.mark.flaky
def test_job_output_video(self):
client = Client(src="gradio/video_component")
job = client.predict(
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4",
fn_index=0,
)
assert pathlib.Path(job.result()).exists()
class TestStatusUpdates:
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
@ -244,6 +254,12 @@ class TestStatusUpdates:
class TestAPIInfo:
@pytest.mark.parametrize("trailing_char", ["/", ""])
def test_test_endpoint_src(self, trailing_char):
src = "https://gradio-calculator.hf.space" + trailing_char
client = Client(src=src)
assert client.endpoints[0].root_url == "https://gradio-calculator.hf.space/"
@pytest.mark.flaky
def test_numerical_to_label_space(self):
client = Client("gradio-tests/titanic-survival")