mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-24 13:01:18 +08:00
Add support for loading Private Spaces using HF Token (#2568)
* pass token & add typing * updated docstring * reorg * improve docs * test for private space * changelog * addressing review * fix tests * fix websocket headers
This commit is contained in:
parent
17cd9b59bd
commit
e360f159a9
@ -88,7 +88,8 @@ No changes to highlight.
|
||||
No changes to highlight.
|
||||
|
||||
## Full Changelog:
|
||||
No changes to highlight.
|
||||
* Allows loading private Spaces by passing an an `api_key` to `gr.Interface.load()`
|
||||
by [@abidlabs](https://github.com/abidlabs) in [PR 2568](https://github.com/gradio-app/gradio/pull/2568)
|
||||
|
||||
## Contributors Shoutout:
|
||||
No changes to highlight.
|
||||
|
@ -1054,10 +1054,10 @@ class Blocks(BlockContext):
|
||||
|
||||
Instance method: adds event that runs as soon as the demo loads in the browser. Example usage below.
|
||||
Parameters:
|
||||
name: Class Method - the name of the model (e.g. "gpt2"), can include the `src` as prefix (e.g. "models/gpt2")
|
||||
src: Class Method - the source of the model: `models` or `spaces` (or empty if source is provided as a prefix in `name`)
|
||||
api_key: Class Method - optional api key for use with Hugging Face Hub
|
||||
alias: Class Method - optional string used as the name of the loaded model instead of the default name
|
||||
name: Class Method - the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
|
||||
src: Class Method - the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
|
||||
api_key: Class Method - optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
|
||||
alias: Class Method - optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
|
||||
fn: Instance Method - Callable function
|
||||
inputs: Instance Method - input list
|
||||
outputs: Instance Method - output list
|
||||
|
@ -1,4 +1,12 @@
|
||||
class DuplicateBlockError(ValueError):
|
||||
"""Raised when a Blocks contains more than one Block with the same id"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TooManyRequestsError(Exception):
|
||||
"""Raised when the Hugging Face API returns a 429 status code."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
@ -21,19 +21,18 @@ from packaging import version
|
||||
|
||||
import gradio
|
||||
from gradio import components, exceptions, utils
|
||||
from gradio.exceptions import TooManyRequestsError
|
||||
from gradio.processing_utils import to_binary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Blocks
|
||||
from gradio.components import DataframeData
|
||||
from gradio.interface import Interface
|
||||
|
||||
|
||||
class TooManyRequestsError(Exception):
|
||||
"""Raised when the Hugging Face API returns a 429 status code."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def load_blocks_from_repo(name, src=None, api_key=None, alias=None, **kwargs):
|
||||
def load_blocks_from_repo(
|
||||
name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs
|
||||
) -> Blocks:
|
||||
"""Creates and returns a Blocks instance from several kinds of Hugging Face repos:
|
||||
1) A model repo
|
||||
2) A Spaces repo running Gradio 2.x
|
||||
@ -55,7 +54,7 @@ def load_blocks_from_repo(name, src=None, api_key=None, alias=None, **kwargs):
|
||||
return blocks
|
||||
|
||||
|
||||
def get_tabular_examples(model_name) -> Dict[str, List[float]]:
|
||||
def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
|
||||
readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
|
||||
if readme.status_code != 200:
|
||||
warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
|
||||
@ -107,7 +106,7 @@ def rows_to_cols(
|
||||
return {"inputs": {"data": data_column_wise}}
|
||||
|
||||
|
||||
def get_models_interface(model_name, api_key, alias, **kwargs):
|
||||
def get_models_interface(model_name: str, api_key: str | None, alias: str, **kwargs):
|
||||
model_url = "https://huggingface.co/{}".format(model_name)
|
||||
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
|
||||
print("Fetching model from: {}".format(model_url))
|
||||
@ -394,23 +393,37 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
|
||||
return interface
|
||||
|
||||
|
||||
def get_spaces(model_name, api_key, alias, **kwargs):
|
||||
space_url = "https://huggingface.co/spaces/{}".format(model_name)
|
||||
print("Fetching interface from: {}".format(space_url))
|
||||
iframe_url = "https://hf.space/embed/{}/+".format(model_name)
|
||||
def get_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Blocks:
|
||||
space_url = "https://huggingface.co/spaces/{}".format(space_name)
|
||||
print("Fetching Space from: {}".format(space_url))
|
||||
|
||||
headers = {}
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
iframe_url = (
|
||||
requests.get(
|
||||
f"https://huggingface.co/api/spaces/{space_name}/host", headers=headers
|
||||
)
|
||||
.json()
|
||||
.get("host")
|
||||
)
|
||||
|
||||
r = requests.get(iframe_url, headers=headers)
|
||||
|
||||
r = requests.get(iframe_url)
|
||||
result = re.search(
|
||||
r"window.gradio_config = (.*?);[\s]*</script>", r.text
|
||||
) # some basic regex to extract the config
|
||||
try:
|
||||
config = json.loads(result.group(1))
|
||||
except AttributeError:
|
||||
raise ValueError("Could not load the Space: {}".format(model_name))
|
||||
raise ValueError("Could not load the Space: {}".format(space_name))
|
||||
if "allow_flagging" in config: # Create an Interface for Gradio 2.x Spaces
|
||||
return get_spaces_interface(model_name, config, alias, **kwargs)
|
||||
return get_spaces_interface(
|
||||
space_name, config, alias, api_key, iframe_url, **kwargs
|
||||
)
|
||||
else: # Create a Blocks for Gradio 3.x Spaces
|
||||
return get_spaces_blocks(model_name, config)
|
||||
return get_spaces_blocks(space_name, config, api_key, iframe_url)
|
||||
|
||||
|
||||
async def get_pred_from_ws(
|
||||
@ -430,9 +443,11 @@ async def get_pred_from_ws(
|
||||
return resp["output"]
|
||||
|
||||
|
||||
def get_ws_fn(ws_url):
|
||||
def get_ws_fn(ws_url, headers):
|
||||
async def ws_fn(data, hash_data):
|
||||
async with websockets.connect(ws_url, open_timeout=10) as websocket:
|
||||
async with websockets.connect(
|
||||
ws_url, open_timeout=10, extra_headers=headers
|
||||
) as websocket:
|
||||
return await get_pred_from_ws(websocket, data, hash_data)
|
||||
|
||||
return ws_fn
|
||||
@ -447,7 +462,9 @@ def use_websocket(config, dependency):
|
||||
return queue_enabled and queue_uses_websocket and dependency_uses_queue
|
||||
|
||||
|
||||
def get_spaces_blocks(model_name, config):
|
||||
def get_spaces_blocks(
|
||||
model_name: str, config: Dict, api_key: str | None, iframe_url: str
|
||||
) -> Blocks:
|
||||
def streamline_config(config: dict) -> dict:
|
||||
"""Streamlines the blocks config dictionary to fix components that don't render correctly."""
|
||||
# TODO(abidlabs): Need a better way to fix relative paths in dataset component
|
||||
@ -457,11 +474,13 @@ def get_spaces_blocks(model_name, config):
|
||||
return config
|
||||
|
||||
config = streamline_config(config)
|
||||
api_url = "https://hf.space/embed/{}/api/predict/".format(model_name)
|
||||
api_url = "{}/api/predict/".format(iframe_url)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
ws_url = "wss://spaces.huggingface.tech/{}/queue/join".format(model_name)
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
ws_url = "{}/queue/join".format(iframe_url).replace("https", "wss")
|
||||
|
||||
ws_fn = get_ws_fn(ws_url)
|
||||
ws_fn = get_ws_fn(ws_url, headers)
|
||||
|
||||
fns = []
|
||||
for d, dependency in enumerate(config["dependencies"]):
|
||||
@ -504,8 +523,15 @@ def get_spaces_blocks(model_name, config):
|
||||
return gradio.Blocks.from_config(config, fns)
|
||||
|
||||
|
||||
def get_spaces_interface(model_name, config, alias, **kwargs):
|
||||
def streamline_config(config: dict) -> dict:
|
||||
def get_spaces_interface(
|
||||
model_name: str,
|
||||
config: Dict,
|
||||
alias: str,
|
||||
api_key: str | None,
|
||||
iframe_url: str,
|
||||
**kwargs,
|
||||
) -> Interface:
|
||||
def streamline_config(config: Dict) -> Dict:
|
||||
"""Streamlines the interface config dictionary to remove unnecessary keys."""
|
||||
config["inputs"] = [
|
||||
components.get_component_instance(component)
|
||||
@ -528,8 +554,10 @@ def get_spaces_interface(model_name, config, alias, **kwargs):
|
||||
return config
|
||||
|
||||
config = streamline_config(config)
|
||||
api_url = "https://hf.space/embed/{}/api/predict/".format(model_name)
|
||||
api_url = "{}/api/predict/".format(iframe_url)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
# The function should call the API with preprocessed data
|
||||
def fn(*data):
|
||||
@ -571,7 +599,7 @@ factory_methods: Dict[str, Callable] = {
|
||||
}
|
||||
|
||||
|
||||
def load_from_pipeline(pipeline):
|
||||
def load_from_pipeline(pipeline) -> Dict:
|
||||
"""
|
||||
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
|
||||
pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
|
||||
|
@ -92,10 +92,10 @@ class Interface(Blocks):
|
||||
model repos (if src is "models") or Space repos (if src is "spaces"). The input
|
||||
and output components are automatically loaded from the repo.
|
||||
Parameters:
|
||||
name: the name of the model (e.g. "gpt2"), can include the `src` as prefix (e.g. "models/gpt2")
|
||||
src: the source of the model: `models` or `spaces` (or empty if source is provided as a prefix in `name`)
|
||||
api_key: optional api key for use with Hugging Face Hub
|
||||
alias: optional string used as the name of the loaded model instead of the default name
|
||||
name: the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
|
||||
src: the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
|
||||
api_key: optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
|
||||
alias: optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
|
||||
Returns:
|
||||
a Gradio Interface object for the given model
|
||||
Example:
|
||||
|
@ -234,6 +234,17 @@ class TestLoadInterface(unittest.TestCase):
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
def test_private_space(self):
|
||||
api_key = "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" # Intentionally revealing this key for testing purposes
|
||||
io = gr.Interface.load(
|
||||
"spaces/gradio-tests/not-actually-private-space", api_key=api_key
|
||||
)
|
||||
try:
|
||||
output = io("abc")
|
||||
self.assertEqual(output, "abc")
|
||||
except TooManyRequestsError:
|
||||
pass
|
||||
|
||||
|
||||
class TestLoadFromPipeline(unittest.TestCase):
|
||||
def test_text_to_text_model_from_pipeline(self):
|
||||
|
@ -285,8 +285,8 @@ class TestQueueRoutes:
|
||||
"http://127.0.0.1:7861/gradio/gradio/gradio/",
|
||||
),
|
||||
(
|
||||
"wss://huggingface.co.tech/path/queue/join",
|
||||
"https://huggingface.co.tech/path/",
|
||||
"wss://gradio-titanic-survival.hf.space/queue/join",
|
||||
"https://gradio-titanic-survival.hf.space/",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
@ -1,9 +1,11 @@
|
||||
import requests
|
||||
from upload_demos import demos, upload_demo_to_space, AUTH_TOKEN, gradio_version
|
||||
from gradio.networking import url_ok
|
||||
|
||||
|
||||
for demo in demos:
|
||||
space_id = "gradio/" + demo
|
||||
if not url_ok(f"https://hf.space/embed/{space_id}/+"):
|
||||
space_url = requests.get(f"https://huggingface.co/api/spaces/{space_id}/host").json().get("host")
|
||||
if not url_ok(space_url):
|
||||
print(f"{space_id} was down, restarting")
|
||||
upload_demo_to_space(demo_name=demo, space_id=space_id, hf_token=AUTH_TOKEN, gradio_version=gradio_version)
|
||||
|
Loading…
x
Reference in New Issue
Block a user