mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Allow setting custom headers in Python Client (#7334)
* add headers * add changeset * fix * test --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
5b45a162b3
commit
b95d0d043c
6
.changeset/twelve-crabs-refuse.md
Normal file
6
.changeset/twelve-crabs-refuse.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
"gradio": minor
|
||||
"gradio_client": minor
|
||||
---
|
||||
|
||||
feat:Allow setting custom headers in Python Client
|
@ -75,6 +75,8 @@ class Client:
|
||||
output_dir: str | Path = DEFAULT_TEMP_DIR,
|
||||
verbose: bool = True,
|
||||
auth: tuple[str, str] | None = None,
|
||||
*,
|
||||
headers: dict[str, str] | None = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@ -84,6 +86,7 @@ class Client:
|
||||
serialize: Whether the client should serialize the inputs and deserialize the outputs of the remote API. If set to False, the client will pass the inputs and outputs as-is, without serializing/deserializing them. E.g. you if you set this to False, you'd submit an image in base64 format instead of a filepath, and you'd get back an image in base64 format from the remote API instead of a filepath.
|
||||
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
|
||||
verbose: Whether the client should print statements to the console.
|
||||
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
|
||||
"""
|
||||
self.verbose = verbose
|
||||
self.hf_token = hf_token
|
||||
@ -93,6 +96,8 @@ class Client:
|
||||
library_name="gradio_client",
|
||||
library_version=utils.__version__,
|
||||
)
|
||||
if headers:
|
||||
self.headers.update(headers)
|
||||
self.space_id = None
|
||||
self.cookies: dict[str, str] = {}
|
||||
self.output_dir = (
|
||||
|
@ -51,6 +51,27 @@ def connect(
|
||||
demo.server.thread.join(timeout=1)
|
||||
|
||||
|
||||
class TestClientInitialization:
|
||||
def test_headers_constructed_correctly(self):
|
||||
client = Client("gradio-tests/titanic-survival", hf_token=HF_TOKEN)
|
||||
assert {"authorization": f"Bearer {HF_TOKEN}"}.items() <= client.headers.items()
|
||||
client = Client(
|
||||
"gradio-tests/titanic-survival",
|
||||
hf_token=HF_TOKEN,
|
||||
headers={"additional": "value"},
|
||||
)
|
||||
assert {
|
||||
"authorization": f"Bearer {HF_TOKEN}",
|
||||
"additional": "value",
|
||||
}.items() <= client.headers.items()
|
||||
client = Client(
|
||||
"gradio-tests/titanic-survival",
|
||||
hf_token=HF_TOKEN,
|
||||
headers={"authorization": "Bearer abcde"},
|
||||
)
|
||||
assert {"authorization": "Bearer abcde"}.items() <= client.headers.items()
|
||||
|
||||
|
||||
class TestClientPredictions:
|
||||
@pytest.mark.flaky
|
||||
def test_raise_error_invalid_state(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user