mirror of https://github.com/gradio-app/gradio.git (synced 2025-04-06 12:30:29 +08:00)
parent 5e3a8969ff
commit b456feb5c4
```diff
@@ -51,19 +51,23 @@ class Client:
         src: str,
         hf_token: str | None = None,
         max_workers: int = 40,
+        verbose: bool = True,
     ):
         """
         Parameters:
             src: Either the name of the Hugging Face Space to load, (e.g. "abidlabs/whisper-large-v2") or the full URL (including "http" or "https") of the hosted Gradio app to load (e.g. "http://mydomain.com/app" or "https://bec81a83-5b5c-471e.gradio.live/").
             hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI.
             max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
+            verbose: Whether the client should print statements to the console.
         """
+        self.verbose = verbose
         self.hf_token = hf_token
         self.headers = build_hf_headers(
             token=hf_token,
             library_name="gradio_client",
             library_version=utils.__version__,
         )
+        self.space_id = None
 
         if src.startswith("http://") or src.startswith("https://"):
             _src = src
@@ -73,8 +77,10 @@ class Client:
                 raise ValueError(
                     f"Could not find Space: {src}. If it is a private Space, please provide an hf_token."
                 )
+            self.space_id = src
         self.src = _src
-        print(f"Loaded as API: {self.src} ✔")
+        if self.verbose:
+            print(f"Loaded as API: {self.src} ✔")
 
         self.api_url = utils.API_URL.format(self.src)
         self.ws_url = utils.WS_URL.format(self.src).replace("http", "ws", 1)
```
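A minimal usage sketch of the constructor changed above (not part of the commit): the Space name comes from the docstring example, and `verbose=False` silences the "Loaded as API" message that the second hunk now guards.

```python
from gradio_client import Client

# Sketch: connect to a public Space; verbose=False suppresses the
# "Loaded as API: ... ✔" message now wrapped in `if self.verbose`.
client = Client("abidlabs/whisper-large-v2", verbose=False)

# Private Spaces work the same way, just with a token (placeholder shown):
# private_client = Client("me/my-private-space", hf_token="hf_...", verbose=False)
```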
```diff
@@ -151,7 +157,9 @@ class Client:
         end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
         future = self.executor.submit(end_to_end_fn, *args)
 
-        job = Job(future, communicator=helper)
+        job = Job(
+            future, communicator=helper, verbose=self.verbose, space_id=self.space_id
+        )
 
         if result_callbacks:
             if isinstance(result_callbacks, Callable):
```
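To show where the widened `Job(...)` call matters, here is a hedged sketch of submitting a prediction; the input file and the `/predict` endpoint name are assumptions for illustration only, not part of this commit.

```python
from gradio_client import Client

client = Client("abidlabs/whisper-large-v2")

# Hypothetical endpoint and input: submit() now builds the Job with
# verbose and space_id taken from the Client itself.
job = client.submit("audio_sample.wav", api_name="/predict")
print(job.result())  # blocks until the worker thread finishes the prediction
```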
```diff
@@ -635,15 +643,21 @@ class Job(Future):
         self,
         future: Future,
         communicator: Communicator | None = None,
+        verbose: bool = True,
+        space_id: str | None = None,
     ):
         """
         Parameters:
             future: The future object that represents the prediction call, created by the Client.submit() method
             communicator: The communicator object that is used to communicate between the client and the background thread running the job
+            verbose: Whether to print any status-related messages to the console
+            space_id: The space ID corresponding to the Client object that created this Job object
         """
         self.future = future
         self.communicator = communicator
         self._counter = 0
+        self.verbose = verbose
+        self.space_id = space_id
 
     def __iter__(self) -> Job:
         return self
```
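Because `Client.submit()` (previous hunk) forwards its own `verbose` and `space_id` into this constructor, a quiet client yields quiet jobs. A small sketch, reusing the hypothetical endpoint from above:

```python
from gradio_client import Client

quiet_client = Client("abidlabs/whisper-large-v2", verbose=False)
quiet_job = quiet_client.submit("audio_sample.wav", api_name="/predict")

# The Job inherits both new attributes from the Client that created it.
assert quiet_job.verbose is False
assert quiet_job.space_id == "abidlabs/whisper-large-v2"
```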
```diff
@@ -772,6 +786,12 @@ class Job(Future):
             )
         else:
             with self.communicator.lock:
+                eta = self.communicator.job.latest_status.eta
+                if self.verbose and self.space_id and eta and eta > 30:
+                    print(
+                        f"Due to heavy traffic on this app, the prediction will take approximately {int(eta)} seconds."
+                        f"For faster predictions without waiting in queue, you may duplicate the space: {utils.DUPLICATE_URL.format(self.space_id)}"
+                    )
                 return self.communicator.job.latest_status
 
     def __getattr__(self, name):
```
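A sketch of the behaviour added above, assuming the same hypothetical endpoint: polling a queued job's status will also print the heavy-traffic hint once the reported ETA exceeds 30 seconds.

```python
import time

from gradio_client import Client

client = Client("abidlabs/whisper-large-v2")
job = client.submit("audio_sample.wav", api_name="/predict")

# Poll the queue a few times; once the reported ETA exceeds 30 seconds,
# status() itself prints the heavy-traffic / duplicate-Space hint from the hunk above.
for _ in range(5):
    print(job.status())  # latest StatusUpdate reported by the queue
    time.sleep(1)

print(job.result())      # block until the prediction finishes
```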
```diff
@@ -20,6 +20,7 @@ from websockets.legacy.protocol import WebSocketCommonProtocol
 
 API_URL = "{}/api/predict/"
 WS_URL = "{}/queue/join"
+DUPLICATE_URL = "https://huggingface.co/spaces/{}?duplicate=true"
 STATE_COMPONENT = "state"
 
 __version__ = (pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip()
```
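The new constant is only a format string; filling it with a `space_id` yields the duplicate link embedded in the heavy-traffic message, for example:

```python
DUPLICATE_URL = "https://huggingface.co/spaces/{}?duplicate=true"

# e.g. for the Space used in the earlier examples:
print(DUPLICATE_URL.format("abidlabs/whisper-large-v2"))
# https://huggingface.co/spaces/abidlabs/whisper-large-v2?duplicate=true
```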