mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-15 02:11:15 +08:00
Allow accepting user-provided-tokens in gr.load
(#9807)
* load * add changeset * lint * external * lint * changes * format * changes * renamed * external --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
90d9d14518
commit
5e89b6d23a
5
.changeset/shaggy-mirrors-nail.md
Normal file
5
.changeset/shaggy-mirrors-nail.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Allow accepting user-provided-tokens in `gr.load`
|
@ -41,6 +41,7 @@ def load(
|
||||
| None = None,
|
||||
token: str | None = None,
|
||||
hf_token: str | None = None,
|
||||
accept_token: bool = False,
|
||||
**kwargs,
|
||||
) -> Blocks:
|
||||
"""
|
||||
@ -49,6 +50,7 @@ def load(
|
||||
name: the name of the model (e.g. "google/vit-base-patch16-224") or Space (e.g. "flax-community/spanish-gpt2"). This is the first parameter passed into the `src` function. Can also be formatted as {src}/{repo name} (e.g. "models/google/vit-base-patch16-224") if `src` is not provided.
|
||||
src: function that accepts a string model `name` and a string or None `token` and returns a Gradio app. Alternatively, this parameter takes one of two strings for convenience: "models" (for loading a Hugging Face model through the Inference API) or "spaces" (for loading a Hugging Face Space). If None, uses the prefix of the `name` parameter to determine `src`.
|
||||
token: optional token that is passed as the second parameter to the `src` function. For Hugging Face repos, uses the local HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find HF tokens here: https://huggingface.co/settings/tokens.
|
||||
accept_token: if True, a Textbox component is first rendered to allow the user to provide a token, which will be used instead of the `token` parameter when calling the loaded model or Space.
|
||||
kwargs: additional keyword parameters to pass into the `src` function. If `src` is "models" or "Spaces", these parameters are passed into the `gr.Interface` or `gr.ChatInterface` constructor.
|
||||
Returns:
|
||||
a Gradio Blocks app for the given model
|
||||
@ -64,24 +66,45 @@ def load(
|
||||
)
|
||||
if src is None:
|
||||
# Separate the repo type (e.g. "model") from repo name (e.g. "google/vit-base-patch16-224")
|
||||
tokens = name.split("/")
|
||||
if len(tokens) <= 1:
|
||||
parts = name.split("/")
|
||||
if len(parts) <= 1:
|
||||
raise ValueError(
|
||||
"Either `src` parameter must be provided, or `name` must be formatted as {src}/{repo name}"
|
||||
)
|
||||
src = tokens[0] # type: ignore
|
||||
name = "/".join(tokens[1:])
|
||||
if src in ["huggingface", "models", "spaces"]:
|
||||
return load_blocks_from_huggingface(
|
||||
name=name, src=src, hf_token=token, **kwargs
|
||||
)
|
||||
elif isinstance(src, Callable):
|
||||
return src(name, token, **kwargs)
|
||||
else:
|
||||
src = parts[0] # type: ignore
|
||||
name = "/".join(parts[1:])
|
||||
assert src is not None # noqa: S101
|
||||
if not isinstance(src, Callable) and src not in ["models", "spaces", "huggingface"]:
|
||||
raise ValueError(
|
||||
"The `src` parameter must be one of 'huggingface', 'models', 'spaces', or a function that accepts a model name (and optionally, a token), and returns a Gradio app."
|
||||
)
|
||||
|
||||
if not accept_token:
|
||||
if isinstance(src, Callable):
|
||||
return src(name, token, **kwargs)
|
||||
return load_blocks_from_huggingface(
|
||||
name=name, src=src, hf_token=token, **kwargs
|
||||
)
|
||||
else:
|
||||
import gradio as gr
|
||||
|
||||
with gr.Blocks(fill_height=True) as demo:
|
||||
textbox = gr.Textbox(
|
||||
type="password",
|
||||
label="Token",
|
||||
info="Enter your token and press enter.",
|
||||
)
|
||||
|
||||
@gr.render(inputs=[textbox], triggers=[textbox.submit])
|
||||
def create(token_value):
|
||||
if isinstance(src, Callable):
|
||||
return src(name, token_value, **kwargs)
|
||||
return load_blocks_from_huggingface(
|
||||
name=name, src=src, hf_token=token_value, **kwargs
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
def load_blocks_from_huggingface(
|
||||
name: str,
|
||||
|
Loading…
Reference in New Issue
Block a user