mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Add root_url to serializers in gradio_client (#3736)
* Add root_url to serializers * Add url fix * Respect fn parameter * Fix docstring * Add other test * Pass to method
This commit is contained in:
parent
1981c010c6
commit
46ee226d8c
@ -44,13 +44,14 @@ class Client:
|
||||
)
|
||||
|
||||
if src.startswith("http://") or src.startswith("https://"):
|
||||
self.src = src
|
||||
_src = src
|
||||
else:
|
||||
self.src = self._space_name_to_src(src)
|
||||
if self.src is None:
|
||||
_src = self._space_name_to_src(src)
|
||||
if _src is None:
|
||||
raise ValueError(
|
||||
f"Could not find Space: {src}. If it is a private Space, please provide an hf_token."
|
||||
)
|
||||
self.src = _src
|
||||
print(f"Loaded as API: {self.src} ✔")
|
||||
|
||||
self.api_url = utils.API_URL.format(self.src)
|
||||
@ -308,6 +309,7 @@ class Endpoint:
|
||||
self.use_ws = self._use_websocket(self.dependency)
|
||||
self.input_component_types = []
|
||||
self.output_component_types = []
|
||||
self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
|
||||
try:
|
||||
self.serializers, self.deserializers = self._setup_serializers()
|
||||
self.is_valid = self.dependency[
|
||||
@ -465,7 +467,7 @@ class Endpoint:
|
||||
), f"Expected {len(self.deserializers)} outputs, got {len(data)}"
|
||||
return tuple(
|
||||
[
|
||||
s.deserialize(d, hf_token=self.client.hf_token)
|
||||
s.deserialize(d, hf_token=self.client.hf_token, root_url=self.root_url)
|
||||
for s, d, oct in zip(
|
||||
self.deserializers, data, self.output_component_types
|
||||
)
|
||||
|
@ -164,7 +164,10 @@ class FileSerializable(Serializable):
|
||||
"""
|
||||
if x is None or x == "":
|
||||
return None
|
||||
filename = str(Path(load_dir) / x)
|
||||
if utils.is_valid_url(x):
|
||||
filename = x
|
||||
else:
|
||||
filename = str(Path(load_dir) / x)
|
||||
return {
|
||||
"name": filename,
|
||||
"data": utils.encode_url_or_file_to_base64(filename),
|
||||
@ -185,7 +188,7 @@ class FileSerializable(Serializable):
|
||||
Parameters:
|
||||
x: Base64 representation of file to deserialize into a string filepath
|
||||
save_dir: Path to directory to save the deserialized file to
|
||||
root_url: If this component is loaded from an external Space, this is the URL of the Space
|
||||
root_url: If this component is loaded from an external Space, this is the URL of the Space.
|
||||
hf_token: If this component is loaded from an external private Space, this is the access token for the Space
|
||||
"""
|
||||
if x is None:
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
@ -87,6 +88,15 @@ class TestPredictionsFromSpaces:
|
||||
statuses.append(job.status())
|
||||
assert all(s.code in [Status.PROCESSING, Status.FINISHED] for s in statuses)
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_job_output_video(self):
|
||||
client = Client(src="gradio/video_component")
|
||||
job = client.predict(
|
||||
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4",
|
||||
fn_index=0,
|
||||
)
|
||||
assert pathlib.Path(job.result()).exists()
|
||||
|
||||
|
||||
class TestStatusUpdates:
|
||||
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
|
||||
@ -244,6 +254,12 @@ class TestStatusUpdates:
|
||||
|
||||
|
||||
class TestAPIInfo:
|
||||
@pytest.mark.parametrize("trailing_char", ["/", ""])
|
||||
def test_test_endpoint_src(self, trailing_char):
|
||||
src = "https://gradio-calculator.hf.space" + trailing_char
|
||||
client = Client(src=src)
|
||||
assert client.endpoints[0].root_url == "https://gradio-calculator.hf.space/"
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_numerical_to_label_space(self):
|
||||
client = Client("gradio-tests/titanic-survival")
|
||||
|
Loading…
x
Reference in New Issue
Block a user