Allow accepting user-provided tokens in gr.load (#9807)

* load

* add changeset

* lint

* external

* lint

* changes

* format

* changes

* renamed

* external

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-10-23 11:58:46 -07:00 committed by GitHub
parent 90d9d14518
commit 5e89b6d23a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 11 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Allow accepting user-provided tokens in `gr.load`

View File

@ -41,6 +41,7 @@ def load(
| None = None,
token: str | None = None,
hf_token: str | None = None,
accept_token: bool = False,
**kwargs,
) -> Blocks:
"""
@ -49,6 +50,7 @@ def load(
name: the name of the model (e.g. "google/vit-base-patch16-224") or Space (e.g. "flax-community/spanish-gpt2"). This is the first parameter passed into the `src` function. Can also be formatted as {src}/{repo name} (e.g. "models/google/vit-base-patch16-224") if `src` is not provided.
src: function that accepts a string model `name` and a string or None `token` and returns a Gradio app. Alternatively, this parameter takes one of two strings for convenience: "models" (for loading a Hugging Face model through the Inference API) or "spaces" (for loading a Hugging Face Space). If None, uses the prefix of the `name` parameter to determine `src`.
token: optional token that is passed as the second parameter to the `src` function. For Hugging Face repos, uses the local HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find HF tokens here: https://huggingface.co/settings/tokens.
accept_token: if True, a Textbox component is first rendered to allow the user to provide a token, which will be used instead of the `token` parameter when calling the loaded model or Space.
kwargs: additional keyword parameters to pass into the `src` function. If `src` is "models" or "Spaces", these parameters are passed into the `gr.Interface` or `gr.ChatInterface` constructor.
Returns:
a Gradio Blocks app for the given model
@ -64,24 +66,45 @@ def load(
)
if src is None:
# Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224")
tokens = name.split("/")
if len(tokens) <= 1:
parts = name.split("/")
if len(parts) <= 1:
raise ValueError(
"Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
)
src = tokens[0] # type: ignore
name = "/".join(tokens[1:])
if src in ["huggingface", "models", "spaces"]:
return load_blocks_from_huggingface(
name=name, src=src, hf_token=token, **kwargs
)
elif isinstance(src, Callable):
return src(name, token, **kwargs)
else:
src = parts[0] # type: ignore
name = "/".join(parts[1:])
assert src is not None # noqa: S101
if not isinstance(src, Callable) and src not in ["models", "spaces", "huggingface"]:
raise ValueError(
"The `src` parameter must be one of 'huggingface', 'models', 'spaces', or a function that accepts a model name (and optionally, a token), and returns a Gradio app."
)
if not accept_token:
if isinstance(src, Callable):
return src(name, token, **kwargs)
return load_blocks_from_huggingface(
name=name, src=src, hf_token=token, **kwargs
)
else:
import gradio as gr
with gr.Blocks(fill_height=True) as demo:
textbox = gr.Textbox(
type="password",
label="Token",
info="Enter your token and press enter.",
)
@gr.render(inputs=[textbox], triggers=[textbox.submit])
def create(token_value):
if isinstance(src, Callable):
return src(name, token_value, **kwargs)
return load_blocks_from_huggingface(
name=name, src=src, hf_token=token_value, **kwargs
)
return demo
def load_blocks_from_huggingface(
name: str,