2
0
mirror of https://github.com/gradio-app/gradio.git synced 2025-04-24 13:01:18 +08:00

Add support for loading Private Spaces using HF Token ()

* pass token & add typing

* updated docstring

* reorg

* improve docs

* test for private space

* changelog

* addressing review

* fix tests

* fix websocket headers
This commit is contained in:
Abubakar Abid 2022-10-31 20:29:13 -07:00 committed by GitHub
parent 17cd9b59bd
commit e360f159a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 89 additions and 39 deletions

@ -88,7 +88,8 @@ No changes to highlight.
No changes to highlight.
## Full Changelog:
No changes to highlight.
* Allows loading private Spaces by passing an `api_key` to `gr.Interface.load()`
by [@abidlabs](https://github.com/abidlabs) in [PR 2568](https://github.com/gradio-app/gradio/pull/2568)
## Contributors Shoutout:
No changes to highlight.

@ -1054,10 +1054,10 @@ class Blocks(BlockContext):
Instance method: adds event that runs as soon as the demo loads in the browser. Example usage below.
Parameters:
name: Class Method - the name of the model (e.g. "gpt2"), can include the `src` as prefix (e.g. "models/gpt2")
src: Class Method - the source of the model: `models` or `spaces` (or empty if source is provided as a prefix in `name`)
api_key: Class Method - optional api key for use with Hugging Face Hub
alias: Class Method - optional string used as the name of the loaded model instead of the default name
name: Class Method - the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
src: Class Method - the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
api_key: Class Method - optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
alias: Class Method - optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
fn: Instance Method - Callable function
inputs: Instance Method - input list
outputs: Instance Method - output list

@ -1,4 +1,12 @@
class DuplicateBlockError(ValueError):
    """Raised when a Blocks contains more than one Block with the same id"""


class TooManyRequestsError(Exception):
    """Raised when the Hugging Face API returns a 429 status code."""

@ -21,19 +21,18 @@ from packaging import version
import gradio
from gradio import components, exceptions, utils
from gradio.exceptions import TooManyRequestsError
from gradio.processing_utils import to_binary
if TYPE_CHECKING:
from gradio.blocks import Blocks
from gradio.components import DataframeData
from gradio.interface import Interface
class TooManyRequestsError(Exception):
"""Raised when the Hugging Face API returns a 429 status code."""
pass
def load_blocks_from_repo(name, src=None, api_key=None, alias=None, **kwargs):
def load_blocks_from_repo(
name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs
) -> Blocks:
"""Creates and returns a Blocks instance from several kinds of Hugging Face repos:
1) A model repo
2) A Spaces repo running Gradio 2.x
@ -55,7 +54,7 @@ def load_blocks_from_repo(name, src=None, api_key=None, alias=None, **kwargs):
return blocks
def get_tabular_examples(model_name) -> Dict[str, List[float]]:
def get_tabular_examples(model_name: str) -> Dict[str, List[float]]:
readme = requests.get(f"https://huggingface.co/{model_name}/resolve/main/README.md")
if readme.status_code != 200:
warnings.warn(f"Cannot load examples from README for {model_name}", UserWarning)
@ -107,7 +106,7 @@ def rows_to_cols(
return {"inputs": {"data": data_column_wise}}
def get_models_interface(model_name, api_key, alias, **kwargs):
def get_models_interface(model_name: str, api_key: str | None, alias: str, **kwargs):
model_url = "https://huggingface.co/{}".format(model_name)
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
print("Fetching model from: {}".format(model_url))
@ -394,23 +393,37 @@ def get_models_interface(model_name, api_key, alias, **kwargs):
return interface
def get_spaces(space_name: str, api_key: str | None, alias: str, **kwargs) -> Blocks:
    """Creates a Blocks (or Interface) from a Hugging Face Space repo.

    Parameters:
        space_name: repo id of the Space, e.g. "flax-community/spanish-gpt2"
        api_key: optional HF access token, required for private Spaces
        alias: optional name to use for the loaded demo instead of the default
        **kwargs: forwarded to the Interface constructor for Gradio 2.x Spaces
    Returns:
        a Blocks instance (Gradio 3.x Spaces) or Interface (Gradio 2.x Spaces)
    Raises:
        ValueError: if the Space's gradio config cannot be extracted
    """
    space_url = "https://huggingface.co/spaces/{}".format(space_name)
    print("Fetching Space from: {}".format(space_url))

    # The token must be sent on every request so that private Spaces resolve.
    headers = {}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    # Resolve the Space's actual host via the Hub API rather than hard-coding
    # the legacy hf.space embed URL.
    iframe_url = (
        requests.get(
            f"https://huggingface.co/api/spaces/{space_name}/host", headers=headers
        )
        .json()
        .get("host")
    )

    r = requests.get(iframe_url, headers=headers)
    result = re.search(
        r"window.gradio_config = (.*?);[\s]*</script>", r.text
    )  # some basic regex to extract the config
    try:
        config = json.loads(result.group(1))
    except AttributeError:
        # re.search returned None (no config found on the page).
        raise ValueError("Could not load the Space: {}".format(space_name))
    if "allow_flagging" in config:  # Create an Interface for Gradio 2.x Spaces
        return get_spaces_interface(
            space_name, config, alias, api_key, iframe_url, **kwargs
        )
    else:  # Create a Blocks for Gradio 3.x Spaces
        return get_spaces_blocks(space_name, config, api_key, iframe_url)
async def get_pred_from_ws(
@ -430,9 +443,11 @@ async def get_pred_from_ws(
return resp["output"]
def get_ws_fn(ws_url, headers):
    """Builds an async prediction function that talks to a Space's queue
    over a websocket.

    Parameters:
        ws_url: the Space's `wss://.../queue/join` endpoint
        headers: HTTP headers to send on the websocket handshake (carries the
            Authorization header so private Spaces accept the connection)
    Returns:
        an async callable (data, hash_data) -> prediction output
    """
    async def ws_fn(data, hash_data):
        # Forward auth headers on the upgrade request; without them a
        # private Space rejects the websocket handshake.
        async with websockets.connect(
            ws_url, open_timeout=10, extra_headers=headers
        ) as websocket:
            return await get_pred_from_ws(websocket, data, hash_data)

    return ws_fn
@ -447,7 +462,9 @@ def use_websocket(config, dependency):
return queue_enabled and queue_uses_websocket and dependency_uses_queue
def get_spaces_blocks(model_name, config):
def get_spaces_blocks(
model_name: str, config: Dict, api_key: str | None, iframe_url: str
) -> Blocks:
def streamline_config(config: dict) -> dict:
"""Streamlines the blocks config dictionary to fix components that don't render correctly."""
# TODO(abidlabs): Need a better way to fix relative paths in dataset component
@ -457,11 +474,13 @@ def get_spaces_blocks(model_name, config):
return config
config = streamline_config(config)
api_url = "https://hf.space/embed/{}/api/predict/".format(model_name)
api_url = "{}/api/predict/".format(iframe_url)
headers = {"Content-Type": "application/json"}
ws_url = "wss://spaces.huggingface.tech/{}/queue/join".format(model_name)
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
ws_url = "{}/queue/join".format(iframe_url).replace("https", "wss")
ws_fn = get_ws_fn(ws_url)
ws_fn = get_ws_fn(ws_url, headers)
fns = []
for d, dependency in enumerate(config["dependencies"]):
@ -504,8 +523,15 @@ def get_spaces_blocks(model_name, config):
return gradio.Blocks.from_config(config, fns)
def get_spaces_interface(model_name, config, alias, **kwargs):
def streamline_config(config: dict) -> dict:
def get_spaces_interface(
model_name: str,
config: Dict,
alias: str,
api_key: str | None,
iframe_url: str,
**kwargs,
) -> Interface:
def streamline_config(config: Dict) -> Dict:
"""Streamlines the interface config dictionary to remove unnecessary keys."""
config["inputs"] = [
components.get_component_instance(component)
@ -528,8 +554,10 @@ def get_spaces_interface(model_name, config, alias, **kwargs):
return config
config = streamline_config(config)
api_url = "https://hf.space/embed/{}/api/predict/".format(model_name)
api_url = "{}/api/predict/".format(iframe_url)
headers = {"Content-Type": "application/json"}
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
# The function should call the API with preprocessed data
def fn(*data):
@ -571,7 +599,7 @@ factory_methods: Dict[str, Callable] = {
}
def load_from_pipeline(pipeline):
def load_from_pipeline(pipeline) -> Dict:
"""
Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface

@ -92,10 +92,10 @@ class Interface(Blocks):
model repos (if src is "models") or Space repos (if src is "spaces"). The input
and output components are automatically loaded from the repo.
Parameters:
name: the name of the model (e.g. "gpt2"), can include the `src` as prefix (e.g. "models/gpt2")
src: the source of the model: `models` or `spaces` (or empty if source is provided as a prefix in `name`)
api_key: optional api key for use with Hugging Face Hub
alias: optional string used as the name of the loaded model instead of the default name
name: the name of the model (e.g. "gpt2" or "facebook/bart-base") or space (e.g. "flax-community/spanish-gpt2"), can include the `src` as prefix (e.g. "models/facebook/bart-base")
src: the source of the model: `models` or `spaces` (or leave empty if source is provided as a prefix in `name`)
api_key: optional access token for loading private Hugging Face Hub models or spaces. Find your token here: https://huggingface.co/settings/tokens
alias: optional string used as the name of the loaded model instead of the default name (only applies if loading a Space running Gradio 2.x)
Returns:
a Gradio Interface object for the given model
Example:

@ -234,6 +234,17 @@ class TestLoadInterface(unittest.TestCase):
except TooManyRequestsError:
pass
    def test_private_space(self):
        # End-to-end check that gr.Interface.load can access a private Space
        # when an api_key is supplied. The token below is deliberately public
        # test credentials (see inline note), not a leaked secret.
        api_key = "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"  # Intentionally revealing this key for testing purposes
        io = gr.Interface.load(
            "spaces/gradio-tests/not-actually-private-space", api_key=api_key
        )
        try:
            # The test Space echoes its input; tolerate HF rate limiting
            # (429) rather than failing CI on it.
            output = io("abc")
            self.assertEqual(output, "abc")
        except TooManyRequestsError:
            pass
class TestLoadFromPipeline(unittest.TestCase):
def test_text_to_text_model_from_pipeline(self):

@ -285,8 +285,8 @@ class TestQueueRoutes:
"http://127.0.0.1:7861/gradio/gradio/gradio/",
),
(
"wss://huggingface.co.tech/path/queue/join",
"https://huggingface.co.tech/path/",
"wss://gradio-titanic-survival.hf.space/queue/join",
"https://gradio-titanic-survival.hf.space/",
),
],
)

@ -1,9 +1,11 @@
import requests
from upload_demos import demos, upload_demo_to_space, AUTH_TOKEN, gradio_version

from gradio.networking import url_ok

# Health-check each demo Space via the Hub's /host endpoint (the canonical
# way to resolve a Space's serving URL) and re-upload any that are down.
for demo in demos:
    space_id = "gradio/" + demo
    space_url = (
        requests.get(f"https://huggingface.co/api/spaces/{space_id}/host")
        .json()
        .get("host")
    )
    if not url_ok(space_url):
        print(f"{space_id} was down, restarting")
        upload_demo_to_space(
            demo_name=demo,
            space_id=space_id,
            hf_token=AUTH_TOKEN,
            gradio_version=gradio_version,
        )