Fixes components when loading private spaces (#3068)

* file routes

* adding access token

* add reverse proxy

* adding access token

* context

* rewrite

* frontend

* formatting

* changelog

* formatting

* fix tests

* fixed image issue

* fix frontend

* os removal

* Update test_external.py

* fixes to normalise

* version

* fixes so that functions work

* lint

* formatting
This commit is contained in:
Abubakar Abid 2023-02-07 07:55:51 -08:00 committed by GitHub
parent f062c7e1fd
commit f37d17089d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 193 additions and 81 deletions

1
.gitignore vendored
View File

@ -22,6 +22,7 @@ gradio/templates/frontend
gradio/launches.json
flagged/
gradio_cached_examples/
tmp.zip
# Tests
.coverage

View File

@ -41,6 +41,7 @@ By [@maxaudron](https://github.com/maxaudron) in [PR 3075](https://github.com/gr
- Fixes URL resolution on Windows by [@abidlabs](https://github.com/abidlabs) in [PR 3108](https://github.com/gradio-app/gradio/pull/3108)
- Example caching now works with components without a label attribute (e.g. `Column`) by [@abidlabs](https://github.com/abidlabs) in [PR 3123](https://github.com/gradio-app/gradio/pull/3123)
- Ensure the Video component correctly resets the UI state whe a new video source is loaded and reduce choppiness of UI by [@pngwn](https://github.com/abidlabs) in [PR 3117](https://github.com/gradio-app/gradio/pull/3117)
- Fixes loading private Spaces by [@abidlabs](https://github.com/abidlabs) in [PR 3068](https://github.com/gradio-app/gradio/pull/3068)
- Added a warning when attempting to launch an `Interface` via the `%%blocks` jupyter notebook magic command by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3126](https://github.com/gradio-app/gradio/pull/3126)
## Documentation Changes:
@ -420,8 +421,8 @@ No changes to highlight.
No changes to highlight.
## Bug Fixes:
No changes to highlight.
*No changes to highlight.
*
## Documentation Changes:
* Improves documentation of several queuing-related parameters by [@abidlabs](https://github.com/abidlabs) in [PR 2825](https://github.com/gradio-app/gradio/pull/2825)

View File

@ -528,7 +528,10 @@ class Blocks(BlockContext):
@classmethod
def from_config(
cls, config: dict, fns: List[Callable], root_url: str | None = None
cls,
config: dict,
fns: List[Callable],
root_url: str | None = None,
) -> Blocks:
"""
Factory method that creates a Blocks from a config and list of functions.

View File

@ -1678,7 +1678,13 @@ class Image(
)
def as_example(self, input_data: str | None) -> str:
return "" if input_data is None else str(utils.abspath(input_data))
if input_data is None:
return ""
elif (
self.root_url
): # If an externally hosted image, don't convert to absolute path
return input_data
return str(utils.abspath(input_data))
@document("change", "clear", "play", "pause", "stop", "style")
@ -4974,6 +4980,8 @@ class Dataset(Clickable, Component):
[isinstance(c, IOComponent) for c in self.components]
), "All components in a `Dataset` must be subclasses of `IOComponent`"
self.components = [c for c in self.components if isinstance(c, IOComponent)]
for component in self.components:
component.root_url = self.root_url
self.samples = [[]] if samples is None else samples
for example in self.samples:

View File

@ -12,4 +12,7 @@ class Context:
root_block: Blocks | None = None # The current root block that holds all blocks.
block: BlockContext | None = None # The current block that children are added to.
id: int = 0 # Running id to uniquely refer to any block that gets defined
ip_address: str | None = None
ip_address: str | None = None # The IP address of the user.
access_token: str | None = (
None # The HF token that is provided when loading private models or Spaces
)

View File

@ -14,6 +14,7 @@ import requests
import gradio
from gradio import components, utils
from gradio.context import Context
from gradio.exceptions import TooManyRequestsError
from gradio.external_utils import (
cols_to_rows,
@ -59,6 +60,13 @@ def load_blocks_from_repo(
factory_methods.keys()
)
if api_key is not None:
if Context.access_token is not None and Context.access_token != api_key:
warnings.warn(
"""You are loading a model/Space with a different access token than the one you used to load a previous model/Space. This is not recommended, as it may cause unexpected behavior."""
)
Context.access_token = api_key
blocks: gradio.Blocks = factory_methods[src](name, api_key, alias, **kwargs)
return blocks

View File

@ -410,24 +410,39 @@ class TempFileManager:
return full_temp_file_path
def create_tmp_copy_of_file(file_path, dir=None):
def download_tmp_copy_of_file(
url_path: str, access_token: str | None = None, dir: str | None = None
) -> tempfile._TemporaryFileWrapper:
if dir is not None:
os.makedirs(dir, exist_ok=True)
file_name = Path(file_path).name
prefix, extension = file_name, None
if "." in file_name:
prefix = file_name[0 : file_name.index(".")]
extension = file_name[file_name.index(".") + 1 :]
prefix = utils.strip_invalid_filename_characters(prefix)
if extension is None:
file_obj = tempfile.NamedTemporaryFile(delete=False, prefix=prefix, dir=dir)
else:
file_obj = tempfile.NamedTemporaryFile(
delete=False,
prefix=prefix,
suffix="." + extension,
dir=dir,
)
headers = {"Authorization": "Bearer " + access_token} if access_token else {}
prefix = Path(url_path).stem
suffix = Path(url_path).suffix
file_obj = tempfile.NamedTemporaryFile(
delete=False,
prefix=prefix,
suffix=suffix,
dir=dir,
)
with requests.get(url_path, headers=headers, stream=True) as r:
with open(file_obj.name, "wb") as f:
shutil.copyfileobj(r.raw, f)
return file_obj
def create_tmp_copy_of_file(
file_path: str, dir: str | None = None
) -> tempfile._TemporaryFileWrapper:
if dir is not None:
os.makedirs(dir, exist_ok=True)
prefix = Path(file_path).stem
suffix = Path(file_path).suffix
file_obj = tempfile.NamedTemporaryFile(
delete=False,
prefix=prefix,
suffix=suffix,
dir=dir,
)
shutil.copy2(file_path, file_obj.name)
return file_obj

View File

@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional, Type
from urllib.parse import urlparse
import fastapi
import httpx
import markupsafe
import orjson
import pkg_resources
@ -31,12 +32,14 @@ from fastapi.responses import (
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from jinja2.exceptions import TemplateNotFound
from starlette.responses import RedirectResponse
from starlette.background import BackgroundTask
from starlette.responses import RedirectResponse, StreamingResponse
from starlette.websockets import WebSocketState
import gradio
import gradio.ranged_response as ranged_response
from gradio import utils
from gradio.context import Context
from gradio.data_classes import PredictBody, ResetBody
from gradio.documentation import document, set_documentation_group
from gradio.exceptions import Error
@ -85,6 +88,7 @@ def toorjson(value):
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
templates.env.filters["toorjson"] = toorjson
client = httpx.AsyncClient()
###########
# Auth
@ -261,6 +265,23 @@ class App(FastAPI):
else:
return FileResponse(blocks.favicon_path)
@app.head("/proxy={url_path:path}", dependencies=[Depends(login_check)])
@app.get("/proxy={url_path:path}", dependencies=[Depends(login_check)])
async def reverse_proxy(url_path: str):
# Adapted from: https://github.com/tiangolo/fastapi/issues/1788
url = httpx.URL(url_path)
headers = {}
if Context.access_token is not None:
headers["Authorization"] = f"Bearer {Context.access_token}"
rp_req = client.build_request("GET", url, headers=headers)
rp_resp = await client.send(rp_req, stream=True)
return StreamingResponse(
rp_resp.aiter_raw(),
status_code=rp_resp.status_code,
headers=rp_resp.headers, # type: ignore
background=BackgroundTask(rp_resp.aclose),
)
@app.head("/file={path_or_url:path}", dependencies=[Depends(login_check)])
@app.get("/file={path_or_url:path}", dependencies=[Depends(login_check)])
async def file(path_or_url: str, request: fastapi.Request):

View File

@ -1,11 +1,11 @@
from __future__ import annotations
import urllib.parse
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict
from gradio import processing_utils, utils
from gradio.context import Context
class Serializable(ABC):
@ -164,9 +164,11 @@ class FileSerializable(Serializable):
elif isinstance(x, dict):
if x.get("is_file", False):
if root_url is not None:
x["name"] = urllib.parse.urljoin(root_url, "file=" + x["name"])
if utils.validate_url(x["name"]):
file_name = x["name"]
file_name = processing_utils.download_tmp_copy_of_file(
root_url + "file=" + x["name"],
access_token=Context.access_token,
dir=save_dir,
).name
else:
file_name = processing_utils.create_tmp_copy_of_file(
x["name"], dir=save_dir

View File

@ -448,12 +448,12 @@ def async_iteration(iterator):
class AsyncRequest:
"""
The AsyncRequest class is a low-level API that allow you to create asynchronous HTTP requests without a context manager.
Compared to making calls by using httpx directly, AsyncRequest offers more flexibility and control over:
Compared to making calls by using httpx directly, AsyncRequest offers several advantages:
(1) Includes response validation functionality both using validation models and functions.
(2) Since we're still using httpx.Request class by wrapping it, we have all it's functionalities.
(3) Exceptions are handled silently during the request call, which gives us the ability to inspect each one
individually in the case of multiple asynchronous request calls and some of them failing.
(4) Provides HTTP request types with AsyncRequest.Method Enum class for ease of usage
(2) Exceptions are handled silently during the request call, which provides the ability to inspect each one
request call individually in the case where there are multiple asynchronous request calls and some of them fail.
(3) Provides HTTP request types with AsyncRequest.Method Enum class for ease of usage
AsyncRequest also offers some util functions such as has_exception, is_valid and status to inspect get detailed
information about executed request call.

View File

@ -1 +1 @@
3.17.1
3.17.1b2

View File

@ -12,6 +12,7 @@ from fastapi.testclient import TestClient
import gradio
import gradio as gr
from gradio import media_data, utils
from gradio.context import Context
from gradio.exceptions import InvalidApiName
from gradio.external import (
TooManyRequestsError,
@ -298,6 +299,40 @@ class TestLoadInterface:
except TooManyRequestsError:
pass
def test_private_space_audio(self):
api_key = "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" # Intentionally revealing this key for testing purposes
io = gr.Interface.load(
"spaces/gradio-tests/not-actually-private-space-audio", api_key=api_key
)
try:
output = io(media_data.BASE64_AUDIO["name"])
assert output.endswith(".wav")
except TooManyRequestsError:
pass
def test_multiple_spaces_one_private(self):
api_key = "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" # Intentionally revealing this key for testing purposes
with gr.Blocks():
gr.Interface.load(
"spaces/gradio-tests/not-actually-private-space", api_key=api_key
)
gr.Interface.load(
"spaces/gradio/test-loading-examples",
)
assert Context.access_token == api_key
def test_loading_files_via_proxy_works(self):
api_key = "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" # Intentionally revealing this key for testing purposes
io = gr.Interface.load(
"spaces/gradio-tests/test-loading-examples-private", api_key=api_key
)
app, _, _ = io.launch(prevent_thread_lock=True)
test_client = TestClient(app)
r = test_client.get(
"/proxy=https://gradio-tests-test-loading-examples-private.hf.space/file=/tmp/tmprahzj703/Bunnyual53t2x.obj"
)
assert r.status_code == 200
class TestLoadInterfaceWithExamples:
def test_interface_load_examples(self, tmp_path):

View File

@ -319,3 +319,12 @@ class TestVideoProcessing:
)
# If the conversion succeeded it'd be .mp4
assert pathlib.Path(playable_vid).suffix == ".avi"
def test_download_private_file():
url_path = "https://gradio-tests-not-actually-private-space.hf.space/file=lion.jpg"
access_token = "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" # Intentionally revealing this key for testing purposes
file = processing_utils.download_tmp_copy_of_file(
url_path=url_path, access_token=access_token
)
assert file.name.endswith(".jpg")

BIN
tmp.zip Normal file

Binary file not shown.

View File

@ -34,7 +34,7 @@
export let loading_status: LoadingStatus;
let _value: null | FileData;
$: _value = normalise_file(value, root_url ?? root);
$: _value = normalise_file(value, root, root_url);
let dragging: boolean;
</script>

View File

@ -16,7 +16,9 @@
const dispatch = createEventDispatcher<{ click: number }>();
let samples_dir: string = (root_url ?? root) + "file=";
let samples_dir: string = root_url
? "proxy=" + root_url + "file="
: root + "file=";
let page = 0;
$: gallery = components.length < 2;
let paginate = samples.length > samples_per_page;

View File

@ -43,6 +43,10 @@
} catch (e) {
console.error(e);
}
})
.catch((e) => {
loaded_value = value;
loaded = true;
});
}
</script>
@ -55,26 +59,30 @@
on:mouseenter={() => (hovered = true)}
on:mouseleave={() => (hovered = false)}
>
<table class="">
{#each loaded_value.slice(0, 3) as row, i}
<tr>
{#each row.slice(0, 3) as cell, j}
<td>{cell}</td>
{/each}
{#if row.length > 3}
<td></td>
{/if}
</tr>
{/each}
{#if value.length > 3}
<div
class="overlay"
class:odd={index % 2 != 0}
class:even={index % 2 == 0}
class:button={type === "gallery"}
/>
{/if}
</table>
{#if typeof loaded_value === "string"}
{loaded_value}
{:else}
<table class="">
{#each loaded_value.slice(0, 3) as row, i}
<tr>
{#each row.slice(0, 3) as cell, j}
<td>{cell}</td>
{/each}
{#if row.length > 3}
<td></td>
{/if}
</tr>
{/each}
{#if value.length > 3}
<div
class="overlay"
class:odd={index % 2 != 0}
class:even={index % 2 == 0}
class:button={type === "gallery"}
/>
{/if}
</table>
{/if}
</div>
{/if}

View File

@ -2,7 +2,7 @@
import { createEventDispatcher } from "svelte";
import { File, FileUpload } from "@gradio/file";
import type { FileData } from "@gradio/upload";
import { normalise_files } from "@gradio/upload";
import { normalise_file } from "@gradio/upload";
import { Block } from "@gradio/atoms";
import UploadText from "../UploadText.svelte";
@ -27,7 +27,7 @@
export let loading_status: LoadingStatus;
let _value: null | FileData | Array<FileData>;
$: _value = normalise_files(value, root_url ?? root);
$: _value = normalise_file(value, root, root_url);
let dragging = false;

View File

@ -23,7 +23,7 @@
export let show_label: boolean;
let _value: null | FileData;
$: _value = normalise_file(value, root_url ?? root);
$: _value = normalise_file(value, root, root_url);
let dragging = false;
</script>

View File

@ -29,7 +29,7 @@
export let mode: "static" | "dynamic";
let _value: null | FileData;
$: _value = normalise_file(value, root_url ?? root);
$: _value = normalise_file(value, root, root_url);
let dragging = false;

View File

@ -21,8 +21,8 @@
? null
: value.map((img) =>
Array.isArray(img)
? [normalise_file(img[0], root_url ?? root), img[1]]
: [normalise_file(img, root_url ?? root), null]
? [normalise_file(img[0], root, root_url), img[1]]
: [normalise_file(img, root, root_url), null]
);
let prevValue: string[] | FileData[] | null = null;

View File

@ -2,27 +2,19 @@ import type { FileData } from "./types";
export function normalise_file(
file: string | FileData | null,
root: string
): FileData | null {
if (file == null) return null;
if (typeof file === "string") {
return {
name: "file_data",
data: file
};
} else if (Array.isArray(file)) {
for (const x of file) {
normalise_file(x, root);
}
} else if (file.is_file) {
file.data = root + "file=" + file.name;
}
return file;
}
root: string,
root_url: string | null
): FileData | null;
export function normalise_file(
file: Array<FileData> | FileData | null,
root: string,
root_url: string | null
): Array<FileData> | FileData | null;
export function normalise_files(
export function normalise_file(
file: string | FileData | Array<FileData> | null,
root: string
root: string,
root_url: string | null
): FileData | Array<FileData> | null {
if (file == null) return null;
if (typeof file === "string") {
@ -32,10 +24,14 @@ export function normalise_files(
};
} else if (Array.isArray(file)) {
for (const x of file) {
normalise_file(x, root);
normalise_file(x, root, root_url);
}
} else if (file.is_file) {
file.data = root + "file=" + file.name;
if (root_url == null) {
file.data = "file=" + file.name;
} else {
file.data = "proxy=" + root_url + "file=" + file.name;
}
}
return file;
}