Allow setting custom headers in Python Client (#7334)

* add headers

* add changeset

* fix

* test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-02-07 07:19:53 -08:00 committed by GitHub
parent 5b45a162b3
commit b95d0d043c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 0 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": minor
"gradio_client": minor
---
feat:Allow setting custom headers in Python Client

View File

@ -75,6 +75,8 @@ class Client:
output_dir: str | Path = DEFAULT_TEMP_DIR,
verbose: bool = True,
auth: tuple[str, str] | None = None,
*,
headers: dict[str, str] | None = None,
):
"""
Parameters:
@ -84,6 +86,7 @@ class Client:
serialize: Whether the client should serialize the inputs and deserialize the outputs of the remote API. If set to False, the client will pass the inputs and outputs as-is, without serializing/deserializing them. E.g. you if you set this to False, you'd submit an image in base64 format instead of a filepath, and you'd get back an image in base64 format from the remote API instead of a filepath.
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
verbose: Whether the client should print statements to the console.
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
"""
self.verbose = verbose
self.hf_token = hf_token
@ -93,6 +96,8 @@ class Client:
library_name="gradio_client",
library_version=utils.__version__,
)
if headers:
self.headers.update(headers)
self.space_id = None
self.cookies: dict[str, str] = {}
self.output_dir = (

View File

@ -51,6 +51,27 @@ def connect(
demo.server.thread.join(timeout=1)
class TestClientInitialization:
def test_headers_constructed_correctly(self):
client = Client("gradio-tests/titanic-survival", hf_token=HF_TOKEN)
assert {"authorization": f"Bearer {HF_TOKEN}"}.items() <= client.headers.items()
client = Client(
"gradio-tests/titanic-survival",
hf_token=HF_TOKEN,
headers={"additional": "value"},
)
assert {
"authorization": f"Bearer {HF_TOKEN}",
"additional": "value",
}.items() <= client.headers.items()
client = Client(
"gradio-tests/titanic-survival",
hf_token=HF_TOKEN,
headers={"authorization": "Bearer abcde"},
)
assert {"authorization": "Bearer abcde"}.items() <= client.headers.items()
class TestClientPredictions:
@pytest.mark.flaky
def test_raise_error_invalid_state(self):